[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Describe the bug**\nA clear and concise description of what the bug is.\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1. Go to '...'\n2. Click on '....'\n3. Scroll down to '....'\n4. See error\n\n**Expected behavior**\nA clear and concise description of what you expected to happen.\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Desktop (please complete the following information):**\n - OS: [e.g. iOS]\n - Browser [e.g. chrome, safari]\n - Version [e.g. 22]\n\n**Smartphone (please complete the following information):**\n - Device: [e.g. iPhone6]\n - OS: [e.g. iOS8.1]\n - Browser [e.g. stock browser, safari]\n - Version [e.g. 22]\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/custom.md",
    "content": "---\nname: Custom issue template\nabout: Describe this issue template's purpose here.\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest an idea for this project\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Is your feature request related to a problem? Please describe.**\nA clear and concise description of what the problem is. Ex. I'm always frustrated when [...]\n\n**Describe the solution you'd like**\nA clear and concise description of what you want to happen.\n\n**Describe alternatives you've considered**\nA clear and concise description of any alternative solutions or features you've considered.\n\n**Additional context**\nAdd any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\ndocs/build/*\ndocs/pytorch_sphinx_theme/*\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n.idea/\n\nexp/*\ntrash/*\nexamples/domain_adaptation/digits/logs/*\nexamples/domain_adaptation/digits/data/*\n+.DS_Store\n*/.DS_Store\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "## Contributing to Transfer-Learning-Library\n\nAll kinds of contributions are welcome, including but not limited to the following.\n\n- Fix typo or bugs\n- Add documentation\n- Add new features and components\n\n### Workflow\n\n1. fork and pull the latest Transfer-Learning-Library repository\n2. checkout a new branch (do not use master branch for PRs)\n3. commit your changes\n4. create a PR\n\n```{note}\nIf you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.\n```\n"
  },
  {
    "path": "DATASETS.md",
    "content": "## Notice (2023-08-01)\n\n### Transfer-Learning-Library Dataset Link Failure Issue\nDear users,\n\nWe sincerely apologize to inform you that the dataset links of Transfer-Learning-Library have become invalid due to a cloud storage failure, resulting in many users being unable to download the datasets properly recently.\n\nWe are working diligently to resolve this issue and plan to restore the links as soon as possible. Currently, we have already restored some dataset links, and they have been updated on the master branch. You can obtain the latest version by running \"git pull.\"\n\nAs the version on PyPI has not been updated yet, please temporarily uninstall the old version by running \"pip uninstall tllib\" before installing the new one.\n\nIn the future, we are planning to store the datasets on both Baidu Cloud and Google Cloud to provide more stable download links.\n\nAdditionally, a small portion of datasets that were backed up on our local server have also been lost due to a hard disk failure. For these datasets, we need to re-download and verify them, which might take longer to restore the links.\n\nWithin this week, we will release the updated dataset and confirm the list of datasets without backup. For datasets without backup, if you have previously downloaded them locally, please contact us via email. Your support is highly appreciated.\n\nOnce again, we apologize for any inconvenience caused and thank you for your understanding.\n\nSincerely,\n\nThe Transfer-Learning-Library Team\n\n## Update (2023-08-09)\n\nMost of the dataset links have been restored at present. The confirmed datasets without backups are as follows:\n\n- Classification\n  - COCO70\n  - EuroSAT\n  - PACS\n  - PatchCamelyon\n  - [Partial Domain Adaptation]\n    - CaltechImageNet\n\n- Keypoint Detection\n  - Hand3DStudio\n  - LSP\n  - SURREAL\n\n- Object Detection\n  - Comic\n\n- Re-Identification\n  - PersonX\n  - UnrealPerson\n\n**For these datasets, if you had previously downloaded them locally, please contact us via email. We greatly appreciate everyone's support.**\n\n## Notice (2023-08-01)\n\n### Transfer-Learning-Library数据集链接失效问题\n\n各位使用者，我们很抱歉通知大家，最近Transfer-Learning-Library的数据集链接因为云盘故障而失效，导致很多使用者无法正常下载数据集。\n\n我们正在全力以赴解决这一问题，并计划在最短的时间内恢复链接。目前我们已经恢复了部分数据集链接，更新在master分支上，您可以通过git pull来获取最新的版本。\n\n由于pypi上的版本还未更新，暂时请首先通过pip uninstall tllib卸载旧版本。\n\n日后我们计划将数据集存储在百度云和谷歌云上，提供更加稳定的下载链接。\n\n另外，小部分数据集在我们本地服务器上的备份也由于硬盘故障而丢失，对于这些数据集我们需要重新下载并验证，可能需要更长的时间来恢复链接。\n\n我们会在本周内发布已经更新的数据集和确认无备份的数据集列表，对于无备份的数据集，如果您之前有下载到本地，请通过邮件联系我们，非常感谢大家的支持。\n\n再次向您表达我们的歉意，并感谢您的理解。\n\nTransfer-Learning-Library团队\n\n## Update (2023-08-09)\n\n目前大部分数据集的链接已经恢复，确认无备份的数据集如下：\n\n- Classification\n  - COCO70\n  - EuroSAT\n  - PACS\n  - PatchCamelyon\n  - [Partial Domain Adaptation]\n    - CaltechImageNet\n\n- Keypoint Detection\n  - Hand3DStudio\n  - LSP\n  - SURREAL\n\n- Object Detection\n  - Comic\n\n- Re-Identification\n  - PersonX\n  - UnrealPerson\n\n**对于这些数据集，如果您之前有下载到本地，请通过邮件联系我们，非常感谢大家的支持。**\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2018 The Python Packaging Authority\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "<div align='center' margin-bottom:40px> <img src=\"logo.png\" width=200/> </div>\n\n# Transfer Learning Library\n\n- [Introduction](#introduction)\n- [Updates](#updates)\n- [Supported Methods](#supported-methods)\n- [Installation](#installation)\n- [Documentation](#documentation)\n- [Contact](#contact)\n- [Citation](#citation)\n\n## Update (2024-03-15)\n\nWe upload an offline version of documentation [here](/docs/html.zip). You can download and unzip it to view the documentation.\n\n## Notice (2023-08-09)\n\nA note on broken dataset links can be found here: [DATASETS.md](DATASETS.md).\n\n## Introduction\n*TLlib* is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consistent with torchvision. You can easily develop new algorithms, or readily apply existing algorithms.\n\nOur _API_ is divided by methods, which include: \n- domain alignment methods (tllib.aligment)\n- domain translation methods (tllib.translation)\n- self-training methods (tllib.self_training)\n- regularization methods (tllib.regularization)\n- data reweighting/resampling methods (tllib.reweight)\n- model ranking/selection methods (tllib.ranking)\n- normalization-based methods (tllib.normalization)\n\n<img src=\"Tllib.png\">\n\nWe provide many example codes in the directory _examples_, which is divided by learning setups. Currently, the supported learning setups include:\n- DA (domain adaptation)\n- TA (task adaptation, also known as finetune)\n- OOD (out-of-distribution generalization, also known as DG / domain generalization)\n- SSL (semi-supervised learning)\n- Model Selection \n\nOur supported tasks include: classification, regression, object detection, segmentation, keypoint detection, and so on.\n\n## Updates \n\n### 2022.9\n\nWe support installing *TLlib* via `pip`, which is experimental currently.\n\n```shell\npip install -i https://test.pypi.org/simple/ tllib==0.4\n```\n\n### 2022.8\nWe release `v0.4` of *TLlib*. Previous versions of *TLlib* can be found [here](https://github.com/thuml/Transfer-Learning-Library/releases). 
In `v0.4`, we add implementations of \nthe following methods:\n- Domain Adaptation for Object Detection [[Code]](/examples/domain_adaptation/object_detection) [[API]](/tllib/alignment/d_adapt)\n- Pre-trained Model Selection [[Code]](/examples/model_selection) [[API]](/tllib/ranking)\n- Semi-supervised Learning for Classification [[Code]](/examples/semi_supervised_learning/image_classification/) [[API]](/tllib/self_training)\n\nBesides, we maintain a collection of **_awesome papers in Transfer Learning_** in another repo [_A Roadmap for Transfer Learning_](https://github.com/thuml/A-Roadmap-for-Transfer-Learning).\n\n### 2022.2\nWe adjusted our API following our survey [Transferablity in Deep Learning](https://arxiv.org/abs/2201.05867).\n\n## Supported Methods\nThe currently supported algorithms include:\n\n##### Domain Adaptation for Classification [[Code]](/examples/domain_adaptation/image_classification)\n- **DANN** - Unsupervised Domain Adaptation by Backpropagation [[ICML 2015]](http://proceedings.mlr.press/v37/ganin15.pdf) [[Code]](/examples/domain_adaptation/image_classification/dann.py)\n- **DAN** - Learning Transferable Features with Deep Adaptation Networks [[ICML 2015]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-adaptation-networks-icml15.pdf) [[Code]](/examples/domain_adaptation/image_classification/dan.py)\n- **JAN** - Deep Transfer Learning with Joint Adaptation Networks [[ICML 2017]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/joint-adaptation-networks-icml17.pdf) [[Code]](/examples/domain_adaptation/image_classification/jan.py)\n- **ADDA** - Adversarial Discriminative Domain Adaptation [[CVPR 2017]](http://openaccess.thecvf.com/content_cvpr_2017/papers/Tzeng_Adversarial_Discriminative_Domain_CVPR_2017_paper.pdf) [[Code]](/examples/domain_adaptation/image_classification/adda.py)\n- **CDAN** - Conditional Adversarial Domain Adaptation [[NIPS 2018]](http://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation) [[Code]](/examples/domain_adaptation/image_classification/cdan.py) \n- **MCD** - Maximum Classifier Discrepancy for Unsupervised Domain Adaptation [[CVPR 2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Saito_Maximum_Classifier_Discrepancy_CVPR_2018_paper.pdf) [[Code]](/examples/domain_adaptation/image_classification/mcd.py)\n- **MDD** - Bridging Theory and Algorithm for Domain Adaptation [[ICML 2019]](http://proceedings.mlr.press/v97/zhang19i/zhang19i.pdf) [[Code]](/examples/domain_adaptation/image_classification/mdd.py) \n- **BSP** - Transferability vs. 
Discriminability: Batch Spectral Penalization for Adversarial Domain Adaptation [[ICML 2019]](http://proceedings.mlr.press/v97/chen19i/chen19i.pdf) [[Code]](/examples/domain_adaptation/image_classification/bsp.py) \n- **MCC** - Minimum Class Confusion for Versatile Domain Adaptation [[ECCV 2020]](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123660460.pdf) [[Code]](/examples/domain_adaptation/image_classification/mcc.py)\n\n##### Domain Adaptation for Object Detection [[Code]](/examples/domain_adaptation/object_detection)\n- **CycleGAN** - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [[ICCV 2017]](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf) [[Code]](/examples/domain_adaptation/object_detection/cycle_gan.py)\n- **D-adapt** - Decoupled Adaptation for Cross-Domain Object Detection [[ICLR 2022]](https://openreview.net/pdf?id=VNqaB1g9393) [[Code]](/examples/domain_adaptation/object_detection/d_adapt)\n\n##### Domain Adaptation for Semantic Segmentation [[Code]](/examples/domain_adaptation/semantic_segmentation/)\n- **CycleGAN** - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [[ICCV 2017]](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf) [[Code]](/examples/domain_adaptation/semantic_segmentation/cycle_gan.py)\n- **CyCADA** - Cycle-Consistent Adversarial Domain Adaptation [[ICML 2018]](http://proceedings.mlr.press/v80/hoffman18a.html) [[Code]](/examples/domain_adaptation/semantic_segmentation/cycada.py)\n- **ADVENT** - Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation [[CVPR 2019]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Vu_ADVENT_Adversarial_Entropy_Minimization_for_Domain_Adaptation_in_Semantic_Segmentation_CVPR_2019_paper.pdf) [[Code]](/examples/domain_adaptation/semantic_segmentation/advent.py)\n- **FDA** - Fourier Domain Adaptation for Semantic Segmentation [[CVPR 2020]](https://arxiv.org/abs/2004.05498) [[Code]](/examples/domain_adaptation/semantic_segmentation/fda.py)\n\n##### Domain Adaptation for Keypoint Detection [[Code]](/examples/domain_adaptation/keypoint_detection)\n- **RegDA** - Regressive Domain Adaptation for Unsupervised Keypoint Detection [[CVPR 2021]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/regressive-domain-adaptation-cvpr21.pdf) [[Code]](/examples/domain_adaptation/keypoint_detection)\n\n##### Domain Adaptation for Person Re-identification [[Code]](/examples/domain_adaptation/re_identification/)\n- **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)\n- **MMT** - Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification [[ICLR 2020]](https://arxiv.org/abs/2001.01526) [[Code]](/examples/domain_adaptation/re_identification/mmt.py)\n- **SPGAN** - Similarity Preserving Generative Adversarial Network [[CVPR 2018]](https://arxiv.org/pdf/1811.10551.pdf) [[Code]](/examples/domain_adaptation/re_identification/spgan.py)\n\n##### Partial Domain Adaptation [[Code]](/examples/domain_adaptation/partial_domain_adaptation)\n- **IWAN** - Importance Weighted Adversarial Nets for Partial Domain Adaptation[[CVPR 2018]](https://arxiv.org/abs/1803.09210) [[Code]](/examples/domain_adaptation/partial_domain_adaptation/iwan.py)\n- **AFN** - 
Larger Norm More Transferable: An Adaptive Feature Norm Approach for\nUnsupervised Domain Adaptation [[ICCV 2019]](https://arxiv.org/pdf/1811.07456v2.pdf) [[Code]](/examples/domain_adaptation/partial_domain_adaptation/afn.py)\n\n##### Open-set Domain Adaptation [[Code]](/examples/domain_adaptation/openset_domain_adaptation)\n- **OSBP** - Open Set Domain Adaptation by Backpropagation [[ECCV 2018]](https://arxiv.org/abs/1804.10427) [[Code]](/examples/domain_adaptation/openset_domain_adaptation/osbp.py)\n\n##### Domain Generalization for Classification [[Code]](/examples/domain_generalization/image_classification/)\n- **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)\n- **MixStyle** - Domain Generalization with MixStyle [[ICLR 2021]](https://arxiv.org/abs/2104.02008) [[Code]](/examples/domain_generalization/image_classification/mixstyle.py)\n- **MLDG** - Learning to Generalize: Meta-Learning for Domain Generalization [[AAAI 2018]](https://arxiv.org/pdf/1710.03463.pdf) [[Code]](/examples/domain_generalization/image_classification/mldg.py)\n- **IRM** - Invariant Risk Minimization [[ArXiv]](https://arxiv.org/abs/1907.02893) [[Code]](/examples/domain_generalization/image_classification/irm.py)\n- **VREx** - Out-of-Distribution Generalization via Risk Extrapolation [[ICML 2021]](https://arxiv.org/abs/2003.00688) [[Code]](/examples/domain_generalization/image_classification/vrex.py)\n- **GroupDRO** - Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization [[ArXiv]](https://arxiv.org/abs/1911.08731) [[Code]](/examples/domain_generalization/image_classification/groupdro.py)\n- **Deep CORAL** - Correlation Alignment for Deep Domain Adaptation [[ECCV 2016]](https://arxiv.org/abs/1607.01719) [[Code]](/examples/domain_generalization/image_classification/coral.py)\n\n##### Domain Generalization for Person Re-identification [[Code]](/examples/domain_generalization/re_identification/)\n- **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)\n- **MixStyle** - Domain Generalization with MixStyle [[ICLR 2021]](https://arxiv.org/abs/2104.02008) [[Code]](/examples/domain_generalization/re_identification/mixstyle.py)\n\n##### Task Adaptation (Fine-Tuning) for Image Classification [[Code]](/examples/task_adaptation/image_classification/)\n- **L2-SP** - Explicit inductive bias for transfer learning with convolutional networks [[ICML 2018]]((https://arxiv.org/abs/1802.01483)) [[Code]](/examples/task_adaptation/image_classification/delta.py)\n- **BSS** - Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning [[NIPS 2019]](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf) [[Code]](/examples/task_adaptation/image_classification/bss.py)\n- **DELTA** - DEep Learning Transfer using Fea- ture Map with Attention for convolutional networks [[ICLR 2019]](https://openreview.net/pdf?id=rkgbwsAcYm) [[Code]](/examples/task_adaptation/image_classification/delta.py)\n- **Co-Tuning** - Co-Tuning for Transfer Learning [[NIPS 2020]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf) 
[[Code]](/examples/task_adaptation/image_classification/co_tuning.py)\n- **StochNorm** - Stochastic Normalization [[NIPS 2020]](https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf) [[Code]](/examples/task_adaptation/image_classification/stochnorm.py)\n- **LWF** - Learning Without Forgetting [[ECCV 2016]](https://arxiv.org/abs/1606.09282) [[Code]](/examples/task_adaptation/image_classification/lwf.py)\n- **Bi-Tuning** - Bi-tuning of Pre-trained Representations [[ArXiv]](https://arxiv.org/abs/2011.06182?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29) [[Code]](/examples/task_adaptation/image_classification/bi_tuning.py)\n\n##### Pre-trained Model Selection [[Code]](/examples/model_selection)\n\n- **H-Score** - An Information-theoretic Approach to Transferability in Task Transfer Learning [[ICIP 2019]](http://yangli-feasibility.com/home/media/icip-19.pdf) [[Code]](/examples/model_selection/hscore.py)\n- **NCE** - Negative Conditional Entropy in `Transferability and Hardness of Supervised Classification Tasks [[ICCV 2019]](https://arxiv.org/pdf/1908.08142v1.pdf) [[Code]](/examples/model_selection/nce.py)\n- **LEEP** - LEEP: A New Measure to Evaluate Transferability of Learned Representations [[ICML 2020]](http://proceedings.mlr.press/v119/nguyen20b/nguyen20b.pdf) [[Code]](/examples/model_selection/leep.py)\n- **LogME** - Log Maximum Evidence in `LogME: Practical Assessment of Pre-trained Models for Transfer Learning [[ICML 2021]](https://arxiv.org/pdf/2102.11005.pdf) [[Code]](/examples/model_selection/logme.py)\n\n##### Semi-Supervised Learning for Classification [[Code]](/examples/semi_supervised_learning/image_classification/)\n- **Pseudo Label** - Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks [[ICML 2013]](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf) [[Code]](/examples/semi_supervised_learning/image_classification/pseudo_label.py)\n- **Pi Model** - Temporal Ensembling for Semi-Supervised Learning [[ICLR 2017]](https://arxiv.org/abs/1610.02242) [[Code]](/examples/semi_supervised_learning/image_classification/pi_model.py)\n- **Mean Teacher** - Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results [[NIPS 2017]](https://arxiv.org/abs/1703.01780) [[Code]](/examples/semi_supervised_learning/image_classification/mean_teacher.py)\n- **Noisy Student** - Self-Training With Noisy Student Improves ImageNet Classification [[CVPR 2020]](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_Self-Training_With_Noisy_Student_Improves_ImageNet_Classification_CVPR_2020_paper.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/noisy_student.py)\n- **UDA** - Unsupervised Data Augmentation for Consistency Training [[NIPS 2020]](https://arxiv.org/pdf/1904.12848v4.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/uda.py)\n- **FixMatch** - Simplifying Semi-Supervised Learning with Consistency and Confidence [[NIPS 2020]](https://arxiv.org/abs/2001.07685) [[Code]](/examples/semi_supervised_learning/image_classification/fixmatch.py)\n- **Self-Tuning** - Self-Tuning for Data-Efficient Deep Learning [[ICML 2021]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/self_tuning.py)\n- **FlexMatch** - FlexMatch: 
Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling [[NIPS 2021]](https://arxiv.org/abs/2110.08263) [[Code]](/examples/semi_supervised_learning/image_classification/flexmatch.py)\n- **DebiasMatch** - Debiased Learning From Naturally Imbalanced Pseudo-Labels [[CVPR 2022]](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Debiased_Learning_From_Naturally_Imbalanced_Pseudo-Labels_CVPR_2022_paper.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/debiasmatch.py)\n- **DST** - Debiased Self-Training for Semi-Supervised Learning [[NIPS 2022 Oral]](https://arxiv.org/abs/2202.07136) [[Code]](/examples/semi_supervised_learning/image_classification/dst.py)\n\n## Installation\n\n##### Install from Source Code\n\n- Please git clone the library first. Then, run the following commands to install `tllib` and all the dependency.\n```shell\npython setup.py install\npip install -r requirements.txt\n```\n##### Install via `pip`\n\n- Installing via `pip` is currently experimental.\n\n```shell\npip install -i https://test.pypi.org/simple/ tllib==0.4\n```\n\n\n## Documentation\nYou can find the API documentation on the website: [Documentation](http://tl.thuml.ai/).\n\n## Usage\nYou can find examples in the directory `examples`. A typical usage is \n```shell script\n# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.\n# Assume you have put the datasets under the path `data/office-31`, \n# or you are glad to download the datasets automatically from the Internet to this path\npython dann.py data/office31 -d Office31 -s A -t W -a resnet50  --epochs 20\n```\n\n## Contributing\nWe appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us. \n\n## Disclaimer on Datasets\n\nThis is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have licenses to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.\n\nIf you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!\n\n\n## Contact\nIf you have any problem with our code or have some suggestions, including the future feature, feel free to contact \n- Baixu Chen (cbx_99_hasta@outlook.com)\n- Junguang Jiang (JiangJunguang1123@outlook.com)\n- Mingsheng Long (longmingsheng@gmail.com)\n\nor describe it in Issues.\n\nFor Q&A in Chinese, you can choose to ask questions here before sending an email. [迁移学习算法库答疑专区](https://zhuanlan.zhihu.com/p/248104070)\n\n## Citation\n\nIf you use this toolbox or benchmark in your research, please cite this project. 
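The Python API mirrors this structure. Below is a minimal sketch of computing the DANN transfer loss with `tllib.alignment` (illustrative only; the 1024-dimensional random features stand in for real backbone features):\n\n```python\nimport torch\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.alignment.dann import DomainAdversarialLoss\n\n# a discriminator that tells source features from target features\ndiscriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)\ndann_loss = DomainAdversarialLoss(discriminator)\n\n# f_s, f_t: features extracted from source/target batches by a shared backbone\n# (random tensors here, for illustration only)\nf_s, f_t = torch.randn(20, 1024), torch.randn(20, 1024)\ntransfer_loss = dann_loss(f_s, f_t)  # add this to the classification loss\n```\n\n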
## Contributing\nWe appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.\n\n## Disclaimer on Datasets\n\nThis is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have licenses to use the datasets. It is your responsibility to determine whether you have permission to use a dataset under its license.\n\nIf you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!\n\n## Contact\nIf you have any problem with our code or any suggestions, including future features, feel free to contact\n- Baixu Chen (cbx_99_hasta@outlook.com)\n- Junguang Jiang (JiangJunguang1123@outlook.com)\n- Mingsheng Long (longmingsheng@gmail.com)\n\nor describe it in a GitHub issue.\n\nFor Q&A in Chinese, you can ask questions here before sending an email: [迁移学习算法库答疑专区](https://zhuanlan.zhihu.com/p/248104070).\n\n## Citation\n\nIf you use this toolbox or benchmark in your research, please cite this project.\n\n```latex\n@misc{jiang2022transferability,\n      title={Transferability in Deep Learning: A Survey},\n      author={Junguang Jiang and Yang Shu and Jianmin Wang and Mingsheng Long},\n      year={2022},\n      eprint={2201.05867},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG}\n}\n\n@misc{tllib,\n    author = {Junguang Jiang and Baixu Chen and Bo Fu and Mingsheng Long},\n    title = {Transfer-Learning-Library},\n    year = {2020},\n    publisher = {GitHub},\n    journal = {GitHub repository},\n    howpublished = {\\url{https://github.com/thuml/Transfer-Learning-Library}},\n}\n```\n\n## Acknowledgment\n\nWe would like to thank the School of Software, Tsinghua University, and the National Engineering Laboratory for Big Data Software for providing an excellent ML research platform.\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSPHINXPROJ    = PyTorchSphinxTheme\nSOURCEDIR     = .\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/conf.py",
    "content": "import sys\nimport os\n\nsys.path.append(os.path.abspath('..'))\nsys.path.append(os.path.abspath('./demo/'))\n\nfrom pytorch_sphinx_theme import __version__\nimport pytorch_sphinx_theme\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#sys.path.insert(0, os.path.abspath('.'))\n\n# -- General configuration -----------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. They can be extensions\n# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.\nextensions = [\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.autodoc',\n    'sphinx.ext.viewcode',\n    'sphinxcontrib.httpdomain',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.autosectionlabel',\n    'sphinx.ext.napoleon',\n]\n\n# build the templated autosummary files\nautosummary_generate = True\nnumpydoc_show_class_members = False\n\n# autosectionlabel throws warnings if section names are duplicated.\n# The following tells autosectionlabel to not throw a warning for\n# duplicated section names that are in different documents.\nautosectionlabel_prefix_document = True\n\n\nnapoleon_use_ivar = True\n\n# Do not warn about external images (status badges in README.rst)\nsuppress_warnings = ['image.nonlocal_uri']\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix of source filenames.\nsource_suffix = '.rst'\n\n# The encoding of source files.\n#source_encoding = 'utf-8-sig'\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# General information about the project.\nproject = u'Transfer Learning Library'\ncopyright = u'THUML Group'\n\n# The version info for the project you're documenting, acts as replacement for\n# |version| and |release|, also used in various other places throughout the\n# built documents.\n#\n# The short X.Y version.\nversion = __version__\n# The full version, including alpha/beta/rc tags.\nrelease = __version__\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\nlanguage = 'en'\n\n# There are two options for replacing |today|: either, you set today to some\n# non-false value, then it is used:\n#today = ''\n# Else, today_fmt is used as the format for a strftime call.\n#today_fmt = '%B %d, %Y'\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\nexclude_patterns = []\n\n# The reST default role (used for this markup: `text`) to use for all documents.\n#default_role = None\n\n# If true, '()' will be appended to :func: etc. cross-reference text.\n#add_function_parentheses = True\n\n# If true, the current module name will be prepended to all description\n# unit titles (such as .. function::).\n#add_module_names = True\n\n# If true, sectionauthor and moduleauthor directives will be shown in the\n# output. 
They are ignored by default.\n#show_authors = False\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = 'sphinx'\n\n\n# A list of ignored prefixes for module index sorting.\n#modindex_common_prefix = []\n\nintersphinx_mapping = {\n    'rtd': ('https://docs.readthedocs.io/en/latest/', None),\n    'python': ('https://docs.python.org/3', None),\n    'numpy': ('https://numpy.org/doc/stable', None),\n    'torch': ('https://pytorch.org/docs/stable', None),\n    'torchvision': ('https://pytorch.org/vision/stable', None),\n    'PIL': ('https://pillow.readthedocs.io/en/stable/', None)\n}\n\n\n# -- Options for HTML output ---------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\nhtml_theme = 'pytorch_sphinx_theme'\nhtml_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  For a list of options available for each theme, see the\n# documentation.\nhtml_theme_options = {\n    'canonical_url': '',\n    'analytics_id': '',\n    'logo_only': False,\n    'display_version': False,\n    'prev_next_buttons_location': 'bottom',\n    'style_external_links': False,\n\n    # Toc options\n    'collapse_navigation': True,\n    'sticky_navigation': False,\n    'navigation_depth': 4,\n    'includehidden': True,\n    'titles_only': False\n}\n\n\n# The name for this set of Sphinx documents.  If None, it defaults to\n# \"<project> v<release> documentation\".\n#html_title = None\n\n# A shorter title for the navigation bar.  Default is the same as html_title.\n#html_short_title = None\n\n# The name of an image file (relative to this directory) to place at the top\n# of the sidebar.\nhtml_logo = \"_static/images/TransLearn.png\"\n\n# The name of an image file (within the static path) to use as favicon of the\n# docs.  This file should be a Windows icon file (.ico) being 16x16 or 32x32\n# pixels large.\n#html_favicon = None\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n\n# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,\n# using the given strftime format.\n#html_last_updated_fmt = '%b %d, %Y'\n\n# If true, SmartyPants will be used to convert quotes and dashes to\n# typographically correct entities.\n#html_use_smartypants = True\n\n# Custom sidebar templates, maps document names to template names.\n#html_sidebars = {}\n\n# Additional templates that should be rendered to pages, maps page names to\n# template names.\n#html_additional_pages = {}\n\n# If false, no module index is generated.\n#html_domain_indices = True\n\n# If false, no index is generated.\n#html_use_index = True\n\n# If true, the index is split into individual pages for each letter.\n#html_split_index = False\n\n# If true, links to the reST sources are added to the pages.\nhtml_show_sourcelink = True\n\n# If true, \"Created using Sphinx\" is shown in the HTML footer. Default is True.\n#html_show_sphinx = True\n\n# If true, \"(C) Copyright ...\" is shown in the HTML footer. Default is True.\n#html_show_copyright = True\n\n# If true, an OpenSearch description file will be output, and all pages will\n# contain a <link> tag referring to it.  
The value of this option must be the\n# base URL from which the finished HTML is served.\n#html_use_opensearch = ''\n\n# This is the file name suffix for HTML files (e.g. \".xhtml\").\n#html_file_suffix = None\n\n\n# Disable displaying type annotations, these can be very verbose\nautodoc_typehints = 'none'\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'TransferLearningLibrary'\n\n\n# -- Options for LaTeX output --------------------------------------------------\n\nlatex_elements = {\n# The paper size ('letterpaper' or 'a4paper').\n#'papersize': 'letterpaper',\n\n# The font size ('10pt', '11pt' or '12pt').\n#'pointsize': '10pt',\n\n# Additional stuff for the LaTeX preamble.\n#'preamble': '',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title, author, documentclass [howto/manual]).\nlatex_documents = [\n  ('index', 'TransferLearningLibrary.tex', u'Transfer Learning Library Documentation',\n   u'THUML', 'manual'),\n]\n\n# The name of an image file (relative to this directory) to place at the top of\n# the title page.\n#latex_logo = None\n\n# For \"manual\" documents, if this is true, then toplevel headings are parts,\n# not chapters.\n#latex_use_parts = False\n\n# If true, show page references after internal links.\n#latex_show_pagerefs = False\n\n# If true, show URL addresses after external links.\n#latex_show_urls = False\n\n# Documents to append as an appendix to all manuals.\n#latex_appendices = []\n\n# If false, no module index is generated.\n#latex_domain_indices = True\n\n\n# -- Options for manual page output --------------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [\n    ('index', 'Transfer Learning Library', u'Transfer Learning Library Documentation',\n     [u'THUML'], 1)\n]\n\n# If true, show URL addresses after external links.\n#man_show_urls = False\n\n\n# -- Options for Texinfo output ------------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n  ('index', 'Transfer Learning Library', u'Transfer Learning Library Documentation',\n   u'THUML', 'Transfer Learning Library',\n   'One line description of project.', 'Miscellaneous'),\n]\n\n# Documents to append as an appendix to all manuals.\n#texinfo_appendices = []\n\n# If false, no module index is generated.\n#texinfo_domain_indices = True\n\n# How to display URL addresses: 'footnote', 'no', or 'inline'.\n#texinfo_show_urls = 'footnote'\n\n"
  },
  {
    "path": "docs/index.rst",
    "content": "=====================================\nTransfer Learning\n=====================================\n\n.. toctree::\n    :maxdepth: 2\n    :caption: Transfer Learning API\n    :titlesonly:\n\n    tllib/modules\n    tllib/alignment/index\n    tllib/translation\n    tllib/self_training\n    tllib/reweight\n    tllib/normalization\n    tllib/regularization\n    tllib/ranking\n\n\n.. toctree::\n    :maxdepth: 2\n    :caption: Common API\n    :titlesonly:\n\n    tllib/vision/index\n    tllib/utils/index\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=python -msphinx\n)\nset SPHINXOPTS=\nset SPHINXBUILD=sphinx-build\nset SOURCEDIR=.\nset BUILDDIR=build\nset SPHINXPROJ=PyTorchSphinxTheme\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n\techo.\n\techo.The Sphinx module was not found. Make sure you have Sphinx installed,\n\techo.then set the SPHINXBUILD environment variable to point to the full\n\techo.path of the 'sphinx-build' executable. Alternatively you may add the\n\techo.Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.http://sphinx-doc.org/\n\texit /b 1\n)\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%\n\n:end\npopd\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "sphinxcontrib-httpdomain\nsphinx\n"
  },
  {
    "path": "docs/tllib/alignment/domain_adversarial.rst",
    "content": "==========================================\nDomain Adversarial Training\n==========================================\n\n\n.. _DANN:\n\nDANN: Domain Adversarial Neural Network\n----------------------------------------\n\n.. autoclass:: tllib.alignment.dann.DomainAdversarialLoss\n\n.. _CDAN:\n\nCDAN: Conditional Domain Adversarial Network\n-----------------------------------------------\n\n.. autoclass:: tllib.alignment.cdan.ConditionalDomainAdversarialLoss\n\n\n.. autoclass:: tllib.alignment.cdan.RandomizedMultiLinearMap\n\n\n.. autoclass:: tllib.alignment.cdan.MultiLinearMap\n\n\n.. _ADDA:\n\nADDA: Adversarial Discriminative Domain Adaptation\n-----------------------------------------------------\n\n.. autoclass:: tllib.alignment.adda.DomainAdversarialLoss\n\n.. note::\n    ADDAgrl is also implemented and benchmarked. You can find code\n    `here <https://github.com/thuml/Transfer-Learning-Library/blob/master/examples/domain_adaptation/image_classification/addagrl.py>`_.\n\n\n.. _BSP:\n\nBSP: Batch Spectral Penalization\n-----------------------------------\n\n.. autoclass:: tllib.alignment.bsp.BatchSpectralPenalizationLoss\n\n\n.. _OSBP:\n\nOSBP: Open Set Domain Adaptation by Backpropagation\n----------------------------------------------------\n\n.. autoclass:: tllib.alignment.osbp.UnknownClassBinaryCrossEntropy\n\n\n.. _ADVENT:\n\nADVENT: Adversarial Entropy Minimization for Semantic Segmentation\n------------------------------------------------------------------\n\n.. autoclass:: tllib.alignment.advent.Discriminator\n\n.. autoclass:: tllib.alignment.advent.DomainAdversarialEntropyLoss\n    :members:\n\n\n.. _DADAPT:\n\nD-adapt: Decoupled Adaptation for Cross-Domain Object Detection\n----------------------------------------------------------------\n`Origin Paper <https://openreview.net/pdf?id=VNqaB1g9393>`_.\n\n.. autoclass:: tllib.alignment.d_adapt.proposal.Proposal\n\n.. autoclass:: tllib.alignment.d_adapt.proposal.PersistentProposalList\n\n.. autoclass:: tllib.alignment.d_adapt.proposal.ProposalDataset\n\n.. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledGeneralizedRCNN\n\n.. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledRetinaNet\n\n"
  },
  {
    "path": "docs/tllib/alignment/hypothesis_adversarial.rst",
    "content": "==========================================\nHypothesis Adversarial Learning\n==========================================\n\n\n\n.. _MCD:\n\nMCD: Maximum Classifier Discrepancy\n--------------------------------------------\n\n.. autofunction:: tllib.alignment.mcd.classifier_discrepancy\n\n.. autofunction:: tllib.alignment.mcd.entropy\n\n.. autoclass:: tllib.alignment.mcd.ImageClassifierHead\n\n\n.. _MDD:\n\n\nMDD: Margin Disparity Discrepancy\n--------------------------------------------\n\n\n.. autoclass:: tllib.alignment.mdd.MarginDisparityDiscrepancy\n\n\n**MDD for Classification**\n\n\n.. autoclass:: tllib.alignment.mdd.ClassificationMarginDisparityDiscrepancy\n\n\n.. autoclass:: tllib.alignment.mdd.ImageClassifier\n    :members:\n\n.. autofunction:: tllib.alignment.mdd.shift_log\n\n\n**MDD for Regression**\n\n.. autoclass:: tllib.alignment.mdd.RegressionMarginDisparityDiscrepancy\n\n.. autoclass:: tllib.alignment.mdd.ImageRegressor\n\n\n.. _RegDA:\n\nRegDA: Regressive Domain Adaptation\n--------------------------------------------\n\n.. autoclass:: tllib.alignment.regda.PseudoLabelGenerator2d\n\n.. autoclass:: tllib.alignment.regda.RegressionDisparity\n\n.. autoclass:: tllib.alignment.regda.PoseResNet2d\n"
  },
  {
    "path": "docs/tllib/alignment/index.rst",
    "content": "=====================================\nFeature Alignment\n=====================================\n\n.. toctree::\n    :maxdepth: 3\n    :caption: Feature Alignment\n    :titlesonly:\n\n    statistics_matching\n    domain_adversarial\n    hypothesis_adversarial\n"
  },
  {
    "path": "docs/tllib/alignment/statistics_matching.rst",
    "content": "=====================\nStatistics Matching\n=====================\n\n\n.. _DAN:\n\nDAN: Deep Adaptation Network\n-----------------------------\n\n.. autoclass:: tllib.alignment.dan.MultipleKernelMaximumMeanDiscrepancy\n\n\n.. _CORAL:\n\nDeep CORAL: Correlation Alignment for Deep Domain Adaptation\n--------------------------------------------------------------\n\n.. autoclass:: tllib.alignment.coral.CorrelationAlignmentLoss\n\n\n.. _JAN:\n\nJAN: Joint Adaptation Network\n------------------------------\n\n.. autoclass:: tllib.alignment.jan.JointMultipleKernelMaximumMeanDiscrepancy\n\n"
  },
  {
    "path": "docs/tllib/modules.rst",
    "content": "=====================\nModules\n=====================\n\n\nClassifier\n-------------------------------\n.. autoclass:: tllib.modules.classifier.Classifier\n    :members:\n\nRegressor\n-------------------------------\n.. autoclass:: tllib.modules.regressor.Regressor\n    :members:\n\nDomain Discriminator\n-------------------------------\n.. autoclass:: tllib.modules.domain_discriminator.DomainDiscriminator\n    :members:\n\nGRL: Gradient Reverse Layer\n-----------------------------\n.. autoclass:: tllib.modules.grl.WarmStartGradientReverseLayer\n    :members:\n\nGaussian Kernels\n------------------------\n.. autoclass:: tllib.modules.kernels.GaussianKernel\n\n\nEntropy\n------------------------\n.. autofunction:: tllib.modules.entropy.entropy\n\n\nKnowledge Distillation Loss\n-------------------------------\n.. autoclass:: tllib.modules.loss.KnowledgeDistillationLoss\n    :members:\n\n\n"
  },
  {
    "path": "docs/tllib/normalization.rst",
    "content": "=====================\nNormalization\n=====================\n\n\n\n.. _AFN:\n\nAFN: Adaptive Feature Norm\n-----------------------------\n\n.. autoclass:: tllib.normalization.afn.AdaptiveFeatureNorm\n\n.. autoclass:: tllib.normalization.afn.Block\n\n.. autoclass:: tllib.normalization.afn.ImageClassifier\n\n\nStochNorm: Stochastic Normalization\n------------------------------------------\n\n.. autoclass:: tllib.normalization.stochnorm.StochNorm1d\n\n.. autoclass:: tllib.normalization.stochnorm.StochNorm2d\n\n.. autoclass:: tllib.normalization.stochnorm.StochNorm3d\n\n.. autofunction:: tllib.normalization.stochnorm.convert_model\n\n\n.. _IBN:\n\nIBN-Net: Instance-Batch Normalization Network\n------------------------------------------------\n\n.. autoclass:: tllib.normalization.ibn.InstanceBatchNorm2d\n\n.. autoclass:: tllib.normalization.ibn.IBNNet\n    :members:\n\n.. automodule:: tllib.normalization.ibn\n   :members:\n\n\n.. _MIXSTYLE:\n\nMixStyle: Domain Generalization with MixStyle\n-------------------------------------------------\n\n.. autoclass:: tllib.normalization.mixstyle.MixStyle\n\n.. note::\n    MixStyle is only activated during `training` stage, with some probability :math:`p`.\n\n.. automodule:: tllib.normalization.mixstyle.resnet\n    :members:\n"
  },
  {
    "path": "docs/tllib/ranking.rst",
    "content": "=====================\nRanking\n=====================\n\n\n\n.. _H_score:\n\nH-score\n-------------------------------------------\n\n.. autofunction:: tllib.ranking.hscore.h_score\n\n\n.. _LEEP:\n\nLEEP: Log Expected Empirical Prediction\n-------------------------------------------\n\n.. autofunction:: tllib.ranking.leep.log_expected_empirical_prediction\n\n\n.. _NCE:\n\nNCE: Negative Conditional Entropy\n-------------------------------------------\n\n.. autofunction:: tllib.ranking.nce.negative_conditional_entropy\n\n\n.. _LogME:\n\nLogME: Log Maximum Evidence\n-------------------------------------------\n\n.. autofunction:: tllib.ranking.logme.log_maximum_evidence\n\n"
  },
  {
    "path": "docs/tllib/regularization.rst",
    "content": "===========================================\nRegularization\n===========================================\n\n.. _L2:\n\nL2\n------\n\n.. autoclass:: tllib.regularization.delta.L2Regularization\n\n\n.. _L2SP:\n\nL2-SP\n------\n\n.. autoclass:: tllib.regularization.delta.SPRegularization\n\n\n.. _DELTA:\n\nDELTA: DEep Learning Transfer using Feature Map with Attention\n-------------------------------------------------------------------------------------\n\n.. autoclass:: tllib.regularization.delta.BehavioralRegularization\n\n.. autoclass:: tllib.regularization.delta.AttentionBehavioralRegularization\n\n.. autoclass:: tllib.regularization.delta.IntermediateLayerGetter\n\n\n.. _LWF:\n\nLWF: Learning without Forgetting\n------------------------------------------\n\n.. autoclass:: tllib.regularization.lwf.Classifier\n\n\n\n.. _CoTuning:\n\nCo-Tuning\n------------------------------------------\n\n.. autoclass:: tllib.regularization.co_tuning.CoTuningLoss\n\n.. autoclass:: tllib.regularization.co_tuning.Relationship\n\n\n.. _StochNorm:\n\n\n.. _BiTuning:\n\nBi-Tuning\n------------------------------------------\n\n.. autoclass:: tllib.regularization.bi_tuning.BiTuning\n\n\n.. _BSS:\n\nBSS: Batch Spectral Shrinkage\n------------------------------------------\n\n.. autoclass:: tllib.regularization.bss.BatchSpectralShrinkage\n\n"
  },
  {
    "path": "docs/tllib/reweight.rst",
    "content": "=======================================\nRe-weighting\n=======================================\n\n\n.. _PADA:\n\nPADA: Partial Adversarial Domain Adaptation\n---------------------------------------------\n\n.. autoclass:: tllib.reweight.pada.ClassWeightModule\n\n.. autoclass:: tllib.reweight.pada.AutomaticUpdateClassWeightModule\n    :members:\n\n.. autofunction::  tllib.reweight.pada.collect_classification_results\n\n\n.. _IWAN:\n\nIWAN: Importance Weighted Adversarial Nets\n---------------------------------------------\n\n.. autoclass:: tllib.reweight.iwan.ImportanceWeightModule\n    :members:\n\n\n\n.. _GroupDRO:\n\nGroupDRO: Group Distributionally robust optimization\n------------------------------------------------------\n\n.. autoclass:: tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule\n    :members:\n"
  },
  {
    "path": "docs/tllib/self_training.rst",
    "content": "=======================================\nSelf Training Methods\n=======================================\n\n\n.. _PseudoLabel:\n\nPseudo Label\n-----------------------------\n\n.. autoclass:: tllib.self_training.pseudo_label.ConfidenceBasedSelfTrainingLoss\n\n.. _PiModel:\n\n:math:`\\Pi` Model\n-----------------------------\n\n.. autoclass:: tllib.self_training.pi_model.ConsistencyLoss\n\n\n.. autoclass:: tllib.self_training.pi_model.L2ConsistencyLoss\n\n\n.. _MeanTeacher:\n\nMean Teacher\n-----------------------------\n\n.. autoclass:: tllib.self_training.mean_teacher.EMATeacher\n\n\n.. _SelfEnsemble:\n\nSelf Ensemble\n-----------------------------\n\n.. autoclass:: tllib.self_training.self_ensemble.ClassBalanceLoss\n\n\n.. _UDA:\n\nUDA\n-----------------------------\n\n.. autoclass:: tllib.self_training.uda.StrongWeakConsistencyLoss\n\n\n.. _MCC:\n\nMCC: Minimum Class Confusion\n-----------------------------\n\n.. autoclass:: tllib.self_training.mcc.MinimumClassConfusionLoss\n\n\n.. _MMT:\n\nMMT: Mutual Mean-Teaching\n--------------------------\n`Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised\nDomain Adaptation on Person Re-identification (ICLR 2020) <https://arxiv.org/pdf/2001.01526.pdf>`_\n\nState of the art unsupervised domain adaptation methods utilize clustering algorithms to generate pseudo labels on target\ndomain, which are noisy and thus harmful for training. Inspired by the teacher-student approaches, MMT framework\nprovides robust soft pseudo labels in an on-line peer-teaching manner.\n\nWe denote two networks as :math:`f_1,f_2`, their parameters as :math:`\\theta_1,\\theta_2`. The authors also\npropose to use the temporally average model of each network :math:`\\text{ensemble}(f_1),\\text{ensemble}(f_2)` to generate more reliable\nsoft pseudo labels for supervising the other network. Specifically, the parameters of the temporally\naverage models of the two networks at current iteration :math:`T` are denoted as :math:`E^{(T)}[\\theta_1]` and\n:math:`E^{(T)}[\\theta_2]` respectively, which can be calculated as\n\n.. math::\n    E^{(T)}[\\theta_1] = \\alpha E^{(T-1)}[\\theta_1] + (1-\\alpha)\\theta_1\n.. math::\n    E^{(T)}[\\theta_2] = \\alpha E^{(T-1)}[\\theta_2] + (1-\\alpha)\\theta_2\n\nwhere :math:`E^{(T-1)}[\\theta_1],E^{(T-1)}[\\theta_2]` indicate the temporal average parameters of the two networks in\nthe previous iteration :math:`(T-1)`, the initial temporal average parameters are\n:math:`E^{(0)}[\\theta_1]=\\theta_1,E^{(0)}[\\theta_2]=\\theta_2` and :math:`\\alpha` is the momentum.\n\nThese two networks cooperate with each other in three ways:\n\n- When running clustering algorithm, we average features produced by :math:`\\text{ensemble}(f_1)` and\n    :math:`\\text{ensemble}(f_2)` instead of only considering one of them.\n- A **soft triplet loss** is optimized between :math:`f_1` and :math:`\\text{ensemble}(f_2)` and vice versa\n    to force one network to learn from temporally average of another network.\n- A **cross entropy loss** is optimized between :math:`f_1` and :math:`\\text{ensemble}(f_2)` and vice versa\n    to force one network to learn from temporally average of another network.\n\nThe above mentioned loss functions are listed below, more details can be found in training scripts.\n\n.. autoclass:: tllib.vision.models.reid.loss.SoftTripletLoss\n\n.. autoclass:: tllib.vision.models.reid.loss.CrossEntropyLoss\n\n\n.. _SelfTuning:\n\nSelf Tuning\n-----------------------------\n\n.. 
These two networks cooperate with each other in three ways:\n\n- When running the clustering algorithm, we average the features produced by :math:`\\text{ensemble}(f_1)` and\n  :math:`\\text{ensemble}(f_2)` instead of only considering one of them.\n- A **soft triplet loss** is optimized between :math:`f_1` and :math:`\\text{ensemble}(f_2)` and vice versa\n  to force one network to learn from the temporal average of the other network.\n- A **cross entropy loss** is optimized between :math:`f_1` and :math:`\\text{ensemble}(f_2)` and vice versa\n  to force one network to learn from the temporal average of the other network.\n\nThe above-mentioned loss functions are listed below; more details can be found in the training scripts.\n\n.. autoclass:: tllib.vision.models.reid.loss.SoftTripletLoss\n\n.. autoclass:: tllib.vision.models.reid.loss.CrossEntropyLoss\n\n\n.. _SelfTuning:\n\nSelf Tuning\n-----------------------------\n\n.. autoclass:: tllib.self_training.self_tuning.Classifier\n\n.. autoclass:: tllib.self_training.self_tuning.SelfTuning\n\n\n.. _FlexMatch:\n\nFlexMatch\n-----------------------------\n\n.. autoclass:: tllib.self_training.flexmatch.DynamicThresholdingModule\n    :members:\n\n.. _DST:\n\nDebiased Self-Training\n-----------------------------\n\n.. autoclass:: tllib.self_training.dst.ImageClassifier\n\n.. autoclass:: tllib.self_training.dst.WorstCaseEstimationLoss\n"
  },
  {
    "path": "docs/tllib/translation.rst",
    "content": "=======================================\nDomain Translation\n=======================================\n\n\n.. _CycleGAN:\n\n------------------------------------------------\nCycleGAN: Cycle-Consistent Adversarial Networks\n------------------------------------------------\n\nDiscriminator\n--------------\n\n.. autofunction:: tllib.translation.cyclegan.pixel\n\n.. autofunction:: tllib.translation.cyclegan.patch\n\nGenerator\n--------------\n\n.. autofunction:: tllib.translation.cyclegan.resnet_9\n\n.. autofunction:: tllib.translation.cyclegan.resnet_6\n\n.. autofunction:: tllib.translation.cyclegan.unet_256\n\n.. autofunction:: tllib.translation.cyclegan.unet_128\n\n\nGAN Loss\n--------------\n\n.. autoclass:: tllib.translation.cyclegan.LeastSquaresGenerativeAdversarialLoss\n\n.. autoclass:: tllib.translation.cyclegan.VanillaGenerativeAdversarialLoss\n\n.. autoclass:: tllib.translation.cyclegan.WassersteinGenerativeAdversarialLoss\n\nTranslation\n--------------\n\n.. autoclass:: tllib.translation.cyclegan.Translation\n\n\nUtil\n----------------\n\n.. autoclass:: tllib.translation.cyclegan.util.ImagePool\n    :members:\n\n.. autofunction:: tllib.translation.cyclegan.util.set_requires_grad\n\n\n\n\n.. _Cycada:\n\n--------------------------------------------------------------\nCyCADA: Cycle-Consistent Adversarial Domain Adaptation\n--------------------------------------------------------------\n\n.. autoclass:: tllib.translation.cycada.SemanticConsistency\n\n\n\n.. _SPGAN:\n\n-----------------------------------------------------------\nSPGAN: Similarity Preserving Generative Adversarial Network\n-----------------------------------------------------------\n`Image-Image Domain Adaptation with Preserved Self-Similarity and Domain-Dissimilarity for Person Re-identification\n<https://arxiv.org/pdf/1711.07027.pdf>`_. SPGAN is based on CycleGAN. An additional Siamese network is adopted to force\nthe generator to produce images different from identities in target dataset.\n\nSiamese Network\n-------------------\n\n.. autoclass:: tllib.translation.spgan.siamese.SiameseNetwork\n\nContrastive Loss\n-------------------\n\n.. autoclass:: tllib.translation.spgan.loss.ContrastiveLoss\n\n\n.. _FDA:\n\n------------------------------------------------\nFDA: Fourier Domain Adaptation\n------------------------------------------------\n\n.. autoclass:: tllib.translation.fourier_transform.FourierTransform\n\n.. autofunction:: tllib.translation.fourier_transform.low_freq_mutate\n\n\n\n\n\n"
  },
  {
    "path": "docs/tllib/utils/analysis.rst",
    "content": "==============\nAnalysis Tools\n==============\n\n\n.. autofunction:: tllib.utils.analysis.collect_feature\n\n\n.. autofunction:: tllib.utils.analysis.a_distance.calculate\n\n\n.. autofunction:: tllib.utils.analysis.tsne.visualize\n\n"
  },
  {
    "path": "docs/tllib/utils/base.rst",
    "content": "Generic Tools\n==============\n\n\nAverage Meter\n---------------------------------\n\n.. autoclass:: tllib.utils.meter.AverageMeter\n   :members:\n\nProgress Meter\n---------------------------------\n\n.. autoclass:: tllib.utils.meter.ProgressMeter\n   :members:\n\nMeter\n---------------------------------\n\n.. autoclass:: tllib.utils.meter.Meter\n   :members:\n\nData\n---------------------------------\n\n.. autoclass:: tllib.utils.data.ForeverDataIterator\n   :members:\n\n.. autoclass:: tllib.utils.data.CombineDataset\n   :members:\n\n.. autofunction:: tllib.utils.data.send_to_device\n\n.. autofunction:: tllib.utils.data.concatenate\n\nLogger\n-----------\n\n.. autoclass:: tllib.utils.logger.TextLogger\n   :members:\n\n\n.. autoclass:: tllib.utils.logger.CompleteLogger\n   :members:\n\n"
  },
  {
    "path": "docs/tllib/utils/index.rst",
    "content": "=====================================\nUtilities\n=====================================\n\n.. toctree::\n    :maxdepth: 2\n    :caption: Utilities\n    :titlesonly:\n\n    base\n    metric\n    analysis"
  },
  {
    "path": "docs/tllib/utils/metric.rst",
    "content": "===========\nMetrics\n===========\n\nClassification & Segmentation\n==============================\n\n\nAccuracy\n---------------------------------\n\n.. autofunction:: tllib.utils.metric.accuracy\n\n\nConfusionMatrix\n---------------------------------\n\n.. autoclass:: tllib.utils.metric.ConfusionMatrix\n   :members:\n"
  },
  {
    "path": "docs/tllib/vision/datasets.rst",
    "content": "Datasets\n=============================\n\nCross-Domain Classification\n---------------------------------------------------------\n\n\n--------------------------------------\nImageList\n--------------------------------------\n\n.. autoclass:: tllib.vision.datasets.imagelist.ImageList\n   :members:\n\n-------------------------------------\nOffice-31\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.office31.Office31\n   :members:\n   :inherited-members:\n\n---------------------------------------\nOffice-Caltech\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.officecaltech.OfficeCaltech\n   :members:\n   :inherited-members:\n\n---------------------------------------\nOffice-Home\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.officehome.OfficeHome\n   :members:\n   :inherited-members:\n\n--------------------------------------\nVisDA-2017\n--------------------------------------\n\n.. autoclass:: tllib.vision.datasets.visda2017.VisDA2017\n   :members:\n   :inherited-members:\n\n--------------------------------------\nDomainNet\n--------------------------------------\n\n.. autoclass:: tllib.vision.datasets.domainnet.DomainNet\n   :members:\n   :inherited-members:\n\n--------------------------------------\nPACS\n--------------------------------------\n\n.. autoclass:: tllib.vision.datasets.pacs.PACS\n   :members:\n\n\n--------------------------------------\nMNIST\n--------------------------------------\n\n.. autoclass:: tllib.vision.datasets.digits.MNIST\n   :members:\n\n\n--------------------------------------\nUSPS\n--------------------------------------\n\n.. autoclass:: tllib.vision.datasets.digits.USPS\n   :members:\n\n\n--------------------------------------\nSVHN\n--------------------------------------\n\n.. autoclass:: tllib.vision.datasets.digits.SVHN\n   :members:\n\n\nPartial Cross-Domain Classification\n----------------------------------------------------\n\n---------------------------------------\nPartial Wrapper\n---------------------------------------\n\n.. autofunction:: tllib.vision.datasets.partial.partial\n\n.. autofunction:: tllib.vision.datasets.partial.default_partial\n\n\n---------------------------------------\nCaltech-256->ImageNet-1k\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.partial.caltech_imagenet.CaltechImageNet\n   :members:\n\n\n---------------------------------------\nImageNet-1k->Caltech-256\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.partial.imagenet_caltech.ImageNetCaltech\n   :members:\n\n\nOpen Set Cross-Domain Classification\n------------------------------------------------------\n\n---------------------------------------\nOpen Set Wrapper\n---------------------------------------\n\n.. autofunction:: tllib.vision.datasets.openset.open_set\n\n.. autofunction:: tllib.vision.datasets.openset.default_open_set\n\n\nCross-Domain Regression\n------------------------------------------------------\n\n---------------------------------------\nImageRegression\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.regression.image_regression.ImageRegression\n   :members:\n\n---------------------------------------\nDSprites\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.regression.dsprites.DSprites\n   :members:\n\n---------------------------------------\nMPI3D\n---------------------------------------\n.. 
autoclass:: tllib.vision.datasets.regression.mpi3d.MPI3D\n   :members:\n\n\nCross-Domain Segmentation\n-----------------------------------------------\n\n---------------------------------------\nSegmentationList\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.segmentation.segmentation_list.SegmentationList\n   :members:\n\n---------------------------------------\nCityscapes\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.segmentation.cityscapes.Cityscapes\n\n---------------------------------------\nGTA5\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.segmentation.gta5.GTA5\n\n---------------------------------------\nSynthia\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.segmentation.synthia.Synthia\n\n\n---------------------------------------\nFoggy Cityscapes\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.segmentation.cityscapes.FoggyCityscapes\n\n\nCross-Domain Keypoint Detection\n-----------------------------------------------\n\n---------------------------------------\nDataset Base for Keypoint Detection\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.KeypointDataset\n   :members:\n\n.. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.Body16KeypointDataset\n   :members:\n\n.. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.Hand21KeypointDataset\n   :members:\n\n---------------------------------------\nRendered Handpose Dataset\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.keypoint_detection.rendered_hand_pose.RenderedHandPose\n   :members:\n\n---------------------------------------\nHand-3d-Studio Dataset\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.keypoint_detection.hand_3d_studio.Hand3DStudio\n   :members:\n\n---------------------------------------\nFreiHAND Dataset\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.keypoint_detection.freihand.FreiHand\n   :members:\n\n---------------------------------------\nSurreal Dataset\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.keypoint_detection.surreal.SURREAL\n   :members:\n\n---------------------------------------\nLSP Dataset\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.keypoint_detection.lsp.LSP\n   :members:\n\n---------------------------------------\nHuman3.6M Dataset\n---------------------------------------\n.. autoclass:: tllib.vision.datasets.keypoint_detection.human36m.Human36M\n   :members:\n\nCross-Domain ReID\n------------------------------------------------------\n\n---------------------------------------\nMarket1501\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.reid.market1501.Market1501\n   :members:\n\n---------------------------------------\nDukeMTMC-reID\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.reid.dukemtmc.DukeMTMC\n   :members:\n\n---------------------------------------\nMSMT17\n---------------------------------------\n\n.. autoclass:: tllib.vision.datasets.reid.msmt17.MSMT17\n   :members:\n\n\nNatural Object Recognition\n---------------------------------------------------------\n\n\n-------------------------------------\nStanford Dogs\n-------------------------------------\n\n.. 
autoclass:: tllib.vision.datasets.stanford_dogs.StanfordDogs\n   :members:\n\n-------------------------------------\nStanford Cars\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.stanford_cars.StanfordCars\n   :members:\n\n-------------------------------------\nCUB-200-2011\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.cub200.CUB200\n   :members:\n\n-------------------------------------\nFGVC Aircraft\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.aircrafts.Aircraft\n   :members:\n\n-------------------------------------\nOxford-IIIT Pets\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.oxfordpets.OxfordIIITPets\n   :members:\n\n-------------------------------------\nCOCO-70\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.coco70.COCO70\n   :members:\n\n-------------------------------------\nDTD\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.dtd.DTD\n   :members:\n\n-------------------------------------\nOxfordFlowers102\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.oxfordflowers.OxfordFlowers102\n   :members:\n\n-------------------------------------\nCaltech101\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.caltech101.Caltech101\n   :members:\n\n\nSpecialized Image Classification\n--------------------------------\n\n-------------------------------------\nPatchCamelyon\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.patchcamelyon.PatchCamelyon\n   :members:\n\n-------------------------------------\nRetinopathy\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.retinopathy.Retinopathy\n   :members:\n\n-------------------------------------\nEuroSAT\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.eurosat.EuroSAT\n   :members:\n\n-------------------------------------\nResisc45\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.resisc45.Resisc45\n   :members:\n\n-------------------------------------\nFood-101\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.food101.Food101\n   :members:\n\n-------------------------------------\nSUN397\n-------------------------------------\n\n.. autoclass:: tllib.vision.datasets.sun397.SUN397\n   :members:\n"
  },
  {
    "path": "docs/tllib/vision/index.rst",
    "content": "=====================================\nVision\n=====================================\n\n.. toctree::\n    :maxdepth: 2\n    :caption: Vision\n    :titlesonly:\n\n    datasets\n    models\n    transforms"
  },
  {
    "path": "docs/tllib/vision/models.rst",
    "content": "Models\n===========================\n\n------------------------------\nImage Classification\n------------------------------\n\nResNets\n---------------------------------\n\n.. automodule:: tllib.vision.models.resnet\n   :members:\n\nLeNet\n--------------------------\n\n.. automodule:: tllib.vision.models.digits.lenet\n   :members:\n\nDTN\n--------------------------\n\n.. automodule:: tllib.vision.models.digits.dtn\n   :members:\n\n----------------------------------\nObject Detection\n----------------------------------\n\n.. autoclass:: tllib.vision.models.object_detection.meta_arch.TLGeneralizedRCNN\n   :members:\n\n.. autoclass:: tllib.vision.models.object_detection.meta_arch.TLRetinaNet\n   :members:\n\n.. autoclass:: tllib.vision.models.object_detection.proposal_generator.rpn.TLRPN\n\n.. autoclass:: tllib.vision.models.object_detection.roi_heads.TLRes5ROIHeads\n    :members:\n\n.. autoclass:: tllib.vision.models.object_detection.roi_heads.TLStandardROIHeads\n    :members:\n\n----------------------------------\nSemantic Segmentation\n----------------------------------\n\n.. autofunction:: tllib.vision.models.segmentation.deeplabv2.deeplabv2_resnet101\n\n\n----------------------------------\nKeypoint Detection\n----------------------------------\n\nPoseResNet\n--------------------------\n\n.. autofunction:: tllib.vision.models.keypoint_detection.pose_resnet.pose_resnet101\n\n.. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.PoseResNet\n\n.. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.Upsampling\n\n\nJoint Loss\n----------------------------------\n\n.. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsMSELoss\n\n.. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsKLLoss\n\n\n-----------------------------------\nRe-Identification\n-----------------------------------\n\nModels\n---------------\n.. autoclass:: tllib.vision.models.reid.resnet.ReidResNet\n\n.. automodule:: tllib.vision.models.reid.resnet\n    :members:\n\n.. autoclass:: tllib.vision.models.reid.identifier.ReIdentifier\n    :members:\n\nLoss\n-----------------------------------\n.. autoclass:: tllib.vision.models.reid.loss.TripletLoss\n\nSampler\n-----------------------------------\n.. autoclass:: tllib.utils.data.RandomMultipleGallerySampler\n"
  },
  {
    "path": "docs/tllib/vision/transforms.rst",
    "content": "Transforms\n=============================\n\n\nClassification\n---------------------------------\n\n.. automodule:: tllib.vision.transforms\n   :members:\n\n\nSegmentation\n---------------------------------\n\n\n.. automodule:: tllib.vision.transforms.segmentation\n   :members:\n\n\nKeypoint Detection\n---------------------------------\n\n\n.. automodule:: tllib.vision.transforms.keypoint_detection\n   :members:\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/README.md",
    "content": "# Unsupervised Domain Adaptation for Image Classification\n\n## Installation\n\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You\nalso need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [MNIST](http://yann.lecun.com/exdb/mnist/), [SVHN](http://ufldl.stanford.edu/housenumbers/)\n  , [USPS](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps)\n- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)\n- [OfficeCaltech](https://www.cc.gatech.edu/~judy/domainadapt/)\n- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)\n- [VisDA2017](http://ai.bu.edu/visda-2017/)\n- [DomainNet](http://ai.bu.edu/M3SDA/)\n\nYou need to prepare following datasets manually if you want to use them:\n\n- [ImageNet](https://www.image-net.org/)\n- [ImageNetR](https://github.com/hendrycks/imagenet-r)\n- [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch)\n\nand prepare them following [Documentation for ImageNetR](/common/vision/datasets/imagenet_r.py)\nand [ImageNet-Sketch](/common/vision/datasets/imagenet_sketch.py).\n\n## Supported Methods\n\nSupported methods include:\n\n- [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818)\n- [Deep Adaptation Network (DAN)](https://arxiv.org/pdf/1502.02791)\n- [Joint Adaptation Network (JAN)](https://arxiv.org/abs/1605.06636)\n- [Adversarial Discriminative Domain Adaptation (ADDA)](https://arxiv.org/pdf/1702.05464.pdf)\n- [Conditional Domain Adversarial Network (CDAN)](https://arxiv.org/abs/1705.10667)\n- [Maximum Classifier Discrepancy (MCD)](https://arxiv.org/abs/1712.02560)\n- [Adaptive Feature Norm (AFN)](https://arxiv.org/pdf/1811.07456v2.pdf)\n- [Batch Spectral Penalization (BSP)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/batch-spectral-penalization-icml19.pdf)\n- [Margin Disparity Discrepancy (MDD)](https://arxiv.org/abs/1904.05801)\n- [Minimum Class Confusion (MCC)](https://arxiv.org/abs/1912.03699)\n- [FixMatch](https://arxiv.org/abs/2001.07685)\n\n## Usage\n\nThe shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to\ntrain DANN on Office31, use the following script\n\n```shell script\n# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.\n# Assume you have put the datasets under the path `data/office-31`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W\n```\n\nNote that ``-s`` specifies the source domain, ``-t`` specifies the target domain, and ``--log`` specifies where to store\nresults.\n\nAfter running the above command, it will download ``Office-31`` datasets from the Internet if it's the first time you\nrun the code. 
Datasets are stored under\n``examples/domain_adaptation/image_classification/data/<dataset name>``.\n\nIf everything works fine, you will see results in the following format::\n\n    Epoch: [1][ 900/1000]\tTime  0.60 ( 0.69)\tData  0.22 ( 0.31)\tLoss   0.74 (  0.85)\tCls Acc 96.9 (95.1)\tDomain Acc 64.1 (62.6)\n\nYou can also find these results in the log file ``logs/dann/Office31_A2W/log.txt``.\n\nAfter training, you can test your algorithm's performance by passing in ``--phase test``.\n\n```\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W --phase test\n```\n\n## Experiment and Results\n\n**Notations**\n\n- ``Origin`` means the accuracy reported by the original paper.\n- ``Avg`` is the accuracy reported by `TLlib`.\n- ``ERM`` refers to the model trained with data from the source domain.\n- ``Oracle`` refers to the model trained with data from the target domain.\n\nWe found that the accuracies of adversarial methods (including DANN, ADDA, CDAN, MCD, BSP and MDD) are not stable even\nafter the random seed is fixed, so we run adversarial methods on *Office-31* and *VisDA-2017*\nthree times and report their average accuracy.\n\n### Office-31 accuracy on ResNet-50\n\n| Methods | Origin | Avg  | A → W | D → W | W → D | A → D | D → A | W → A |\n|---------|--------|------|-------|-------|-------|-------|-------|-------|\n| ERM     | 76.1   | 79.5 | 75.8  | 95.5  | 99.0  | 79.3  | 63.6  | 63.8  |\n| DANN    | 82.2   | 86.1 | 91.4  | 97.9  | 100.0 | 83.6  | 73.3  | 70.4  |\n| ADDA    | /      | 87.3 | 94.6  | 97.5  | 99.7  | 90.0  | 69.6  | 72.5  |\n| BSP     | 87.7   | 87.8 | 92.7  | 97.9  | 100.0 | 88.2  | 74.1  | 73.8  |\n| DAN     | 80.4   | 83.7 | 84.2  | 98.4  | 100.0 | 87.3  | 66.9  | 65.2  |\n| JAN     | 84.3   | 87.0 | 93.7  | 98.4  | 100.0 | 89.4  | 69.2  | 71.0  |\n| CDAN    | 87.7   | 87.7 | 93.8  | 98.5  | 100.0 | 89.9  | 73.4  | 70.4  |\n| MCD     | /      | 85.4 | 90.4  | 98.5  | 100.0 | 87.3  | 68.3  | 67.6  |\n| AFN     | 85.7   | 88.6 | 94.0  | 98.9  | 100.0 | 94.4  | 72.9  | 71.1  |\n| MDD     | 88.9   | 89.6 | 95.6  | 98.6  | 100.0 | 94.4  | 76.6  | 72.2  |\n| MCC     | 89.4   | 89.6 | 94.1  | 98.4  | 99.8  | 95.6  | 75.5  | 74.2  |\n| FixMatch| /      | 86.4 | 86.4  | 98.2  | 100.0 | 95.4  | 70.0  | 68.1  |\n\n### Office-Home accuracy on ResNet-50\n\n| Methods     | Origin | Avg  | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |\n|-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|\n| ERM         | 46.1   | 58.4 | 41.1    | 65.9    | 73.7    | 53.1    | 60.1    | 63.3    | 52.2    | 36.7    | 71.8    | 64.8    | 42.6    | 75.2    |\n| DAN         | 56.3   | 61.4 | 45.6    | 67.7    | 73.9    | 57.7    | 63.8    | 66.0    | 54.9    | 40.0    | 74.5    | 66.2    | 49.1    | 77.9    |\n| DANN        | 57.6   | 65.2 | 53.8    | 62.6    | 74.0    | 55.8    | 67.3    | 67.3    | 55.8    | 55.1    | 77.9    | 71.1    | 60.7    | 81.1    |\n| ADDA        | /      | 65.6 | 52.6    | 62.9    | 74.0    | 59.7    | 68.0    | 68.8    | 61.4    | 52.5    | 77.6    | 71.1    | 58.6    | 80.2    |\n| JAN         | 58.3   | 65.9 | 50.8    | 71.9    | 76.5    | 60.6    | 68.3    | 68.7    | 60.5    | 49.6    | 76.9    | 71.0    | 55.9    | 80.5    |\n| CDAN        | 65.8   | 68.8 | 55.2    | 72.4    | 
77.6    | 62.0    | 69.7    | 70.9    | 62.4    | 54.3    | 80.5    | 75.5    | 61.0    | 83.8    |\n| MCD         | /      | 67.8 | 51.7    | 72.2    | 78.2    | 63.7    | 69.5    | 70.8    | 61.5    | 52.8    | 78.0    | 74.5    | 58.4    | 81.8    |\n| BSP         | 64.9   | 67.6 | 54.7    | 67.7    | 76.2    | 61.0    | 69.4    | 70.9    | 60.9    | 55.2    | 80.2    | 73.4    | 60.3    | 81.2    |\n| AFN         | 67.3   | 68.2 | 53.2    | 72.7    | 76.8    | 65.0    | 71.3    | 72.3    | 65.0    | 51.4    | 77.9    | 72.3    | 57.8    | 82.4    |\n| MDD         | 68.1   | 69.7 | 56.2    | 75.4    | 79.6    | 63.5    | 72.1    | 73.8    | 62.5    | 54.8    | 79.9    | 73.5    | 60.9    | 84.5    |\n| MCC         | /      | 72.4 | 58.4    | 79.6    | 83.0    | 67.5    | 77.0    | 78.5    | 66.6    | 54.8    | 81.8    | 74.4    | 61.4    | 85.6    |\n| FixMatch    | /      | 70.8 | 56.4    | 76.4    | 79.9    | 65.3    | 73.8    | 71.2    | 67.2    | 56.4    | 80.6    | 74.9    | 63.5    | 84.3    |\n\n### Office-Home accuracy on vit_base_patch16_224 (batch size 24)\n\n| Methods     | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr | Avg  |\n|-------------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|------|\n| Source Only | 52.4    | 82.1    | 86.9    | 76.8    | 84.1    | 86      | 75.1    | 51.2    | 88.1    | 78.3    | 51.5    | 87.8    | 75.0 |\n| DANN        | 60.1    | 80.8    | 87.9    | 78.1    | 82.6    | 85.9    | 78.8    | 63.2    | 90.2    | 82.3    | 64      | 89.3    | 78.6 |\n| DAN         | 56.3    | 83.6    | 87.5    | 77.7    | 84.7    | 86.7    | 75.9    | 54.5    | 88.5    | 80.2    | 56.2    | 88.2    | 76.7 |\n| JAN         | 60.1    | 86.9    | 88.6    | 79.2    | 85.4    | 86.7    | 80.4    | 59.4    | 89.6    | 82      | 60.7    | 89.9    | 79.1 |\n| CDAN        | 61.6    | 87.8    | 89.6    | 81.4    | 88.1    | 88.5    | 82.4    | 62.5    | 90.8    | 84.2    | 63.5    | 90.8    | 80.9 |\n| MCD         | 52.3    | 75.3    | 85.3    | 75.4    | 75.4    | 78.3    | 68.8    | 49.7    | 86      | 80.6    | 60      | 89      | 73.0 |\n| AFN         | 58.3    | 87.2    | 88.2    | 81.7    | 87      | 88.2    | 81      | 58.4    | 89.2    | 81.5    | 59.2    | 89.2    | 79.1 |\n| MDD         | 64      | 89.3    | 90.4    | 82.2    | 87.7    | 89.2    | 82.8    | 64.9    | 91.7    | 83.7    | 65.4    | 92      | 81.9 |\n\n### VisDA-2017 accuracy on ResNet-101\n\n| Methods     | Origin | Mean | plane | bcycl | bus  | car  | horse | knife | mcycl | person | plant | sktbrd | train | truck | Avg  |\n|-------------|--------|------|-------|-------|------|------|-------|-------|-------|--------|-------|--------|-------|-------|------|\n| ERM         | 52.4   | 51.7 | 63.6  | 35.3  | 50.6 | 78.2 | 74.6  | 18.7  | 82.1  | 16.0   | 84.2  | 35.5   | 77.4  | 4.7   | 56.9 |\n| DANN        | 57.4   | 79.5 | 93.5  | 74.3  | 83.4 | 50.7 | 87.2  | 90.2  | 89.9  | 76.1   | 88.1  | 91.4   | 89.7  | 39.8  | 74.9 |\n| ADDA        | /      | 77.5 | 95.6  | 70.8  | 84.4 | 54.0 | 87.8  | 75.8  | 88.4  | 69.3   | 84.1  | 86.2   | 85.0  | 48.0  | 74.3 |\n| BSP         | 75.9   | 80.5 | 95.7  | 75.6  | 82.8 | 54.5 | 89.2  | 96.5  | 91.3  | 72.2   | 88.9  | 88.7   | 88.0  | 43.4  | 76.2 |\n| DAN         | 61.1   | 66.4 | 89.2  | 37.2  | 77.7 | 61.8 | 81.7  | 64.3  | 90.6  | 61.4   | 79.9  | 37.7   | 88.1  | 27.4  | 67.2 |\n| JAN         | /      | 73.4 | 
96.3  | 66.0  | 82.0 | 44.1 | 86.4  | 70.3  | 87.9  | 74.6   | 83.0  | 64.6   | 84.5  | 41.3  | 70.3 |\n| CDAN        | /      | 80.1 | 94.0  | 69.2  | 78.9 | 57.0 | 89.8  | 94.9  | 91.9  | 80.3   | 86.8  | 84.9   | 85.0  | 48.5  | 76.5 |\n| MCD         | 71.9   | 77.7 | 87.8  | 75.7  | 84.2 | 78.1 | 91.6  | 95.3  | 88.1  | 78.3   | 83.4  | 64.5   | 84.8  | 20.9  | 76.7 |\n| AFN         | 76.1   | 75.0 | 95.6  | 56.2  | 81.3 | 69.8 | 93.0  | 81.0  | 93.4  | 74.1   | 91.7  | 55.0   | 90.6  | 18.1  | 74.4 |\n| MDD         | /      | 82.0 | 88.3  | 62.8  | 85.2 | 69.9 | 91.9  | 95.1  | 94.4  | 81.2   | 93.8  | 89.8   | 84.1  | 47.9  | 79.8 |\n| MCC         | 78.8   | 83.6 | 95.3  | 85.8  | 77.1 | 68.0 | 93.9  | 92.9  | 84.5  | 79.5   | 93.6  | 93.7   | 85.3  | 53.8  | 80.4 |\n| FixMatch    | /      | 79.5 | 96.5  | 76.6  | 72.6 | 84.6 | 96.3  | 92.6  | 90.5  | 81.8   | 91.9  | 74.6   | 87.3  | 8.6   | 78.4 |\n\n### DomainNet accuracy on ResNet-101\n\n| Methods   | c->p | c->r | c->s | p->c | p->r | p->s | r->c | r->p | r->s | s->c | s->p | s->r | Avg  |\n|-------------|------|------|------|------|------|------|------|------|------|------|------|------|------|\n| ERM         | 32.7 | 50.6 | 39.4 | 41.1 | 56.8 | 35.0 | 48.6 | 48.8 | 36.1 | 49.0 | 34.8 | 46.1 | 43.3 |\n| DAN         | 38.8 | 55.2 | 43.9 | 45.9 | 59.0 | 40.8 | 50.8 | 49.8 | 38.9 | 56.1 | 45.9 | 55.5 | 48.4 |\n| DANN        | 37.9 | 54.3 | 44.4 | 41.7 | 55.6 | 36.8 | 50.7 | 50.8 | 40.1 | 55.0 | 45.0 | 54.5 | 47.2 |\n| JAN         | 40.5 | 56.7 | 45.1 | 47.2 | 59.9 | 43.0 | 54.2 | 52.6 | 41.9 | 56.6 | 46.2 | 55.5 | 50.0 |\n| CDAN        | 40.4 | 56.8 | 46.1 | 45.1 | 58.4 | 40.5 | 55.6 | 53.6 | 43.0 | 57.2 | 46.4 | 55.7 | 49.9 |\n| MCD         | 37.5 | 52.9 | 44.0 | 44.6 | 54.5 | 41.6 | 52.0 | 51.5 | 39.7 | 55.5 | 44.6 | 52.0 | 47.5 |\n| MDD         | 42.9 | 59.5 | 47.5 | 48.6 | 59.4 | 42.6 | 58.3 | 53.7 | 46.2 | 58.7 | 46.5 | 57.7 | 51.8 |\n| MCC         | 37.7 | 55.7 | 42.6 | 45.4 | 59.8 | 39.9 | 54.4 | 53.1 | 37.0 | 58.1 | 46.3 | 56.2 | 48.9 |\n\n### DomainNet accuracy on ResNet-101 (Multi-Source)\n\n| Methods     | Origin | Avg  | :c   | :i   | :p   | :q   | :r   | :s   |\n|-------------|--------|------|------|------|------|------|------|------|\n| ERM         | 32.9   | 47.0 | 64.9 | 25.2 | 54.4 | 16.9 | 68.2 | 52.3 |\n| MDD         | /      | 48.8 | 68.7 | 29.7 | 58.2 | 9.7  | 69.4 | 56.9 |\n| Oracle      | 63.0   | 69.1 | 78.2 | 40.7 | 71.6 | 69.7 | 83.8 | 70.6 |\n\n### Performance on ImageNet-scale datasets\n\n|      | ResNet50, ImageNet->ImageNetR | ig_resnext101_32x8d, ImageNet->ImageNet-Sketch |\n|------|-------------------------------|------------------------------------------|\n| ERM  | 35.6                          | 54.9                                     |\n| DAN  | 39.8                          | 55.7                                     |\n| DANN | 52.7                          | 56.5                                     |\n| JAN  | 41.7                          | 55.7                                     |\n| CDAN | 53.9                          | 58.2                                     |\n| MCD  | 46.7                          | 55.0                                     |\n| AFN  | 43.0                          | 55.1                                     |\n| MDD  | 56.2                          | 62.4                                     |\n\n## Visualization\n\nAfter training `DANN`, run the following command:\n\n```\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W --phase analysis\n```\n
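\nThe analysis phase is built on the utilities in ``tllib.utils.analysis``. If you want to run the same analysis from your own code, here is a minimal sketch of how the pieces fit together (``feature_extractor``, ``train_source_loader`` and ``train_target_loader`` are placeholders that you need to construct as in the training scripts):\n\n```python\nimport torch\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# extract features from both domains with a frozen feature extractor\nsource_feature = collect_feature(train_source_loader, feature_extractor, device)\ntarget_feature = collect_feature(train_target_loader, feature_extractor, device)\n\n# plot a joint t-SNE of source and target features\ntsne.visualize(source_feature, target_feature, 'TSNE.png')\n\n# A-distance is a measure for the discrepancy between the two domains\nprint(\"A-distance =\", a_distance.calculate(source_feature, target_feature, device))\n```\n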
\nIt may take a while; afterwards you can find ``TSNE.png`` in the directory\n``logs/dann/Office31_A2W/visualize``.\n\nBelow are the t-SNE visualizations of representations from a ResNet50 trained on the source domain only and from DANN.\n\n<img src=\"./fig/resnet_A2W.png\" width=\"300\"/>\n<img src=\"./fig/dann_A2W.png\" width=\"300\"/>\n\n## TODO\n\n1. Support self-training methods\n2. Support translation methods\n3. Add results on ViT\n4. Add results on ImageNet\n\n## Citation\n\nIf you use these methods in your research, please consider citing:\n\n```\n@inproceedings{DANN,\n    author = {Ganin, Yaroslav and Lempitsky, Victor},\n    booktitle = {ICML},\n    title = {Unsupervised domain adaptation by backpropagation},\n    year = {2015}\n}\n\n@inproceedings{DAN,\n    author    = {Mingsheng Long and\n    Yue Cao and\n    Jianmin Wang and\n    Michael I. Jordan},\n    title     = {Learning Transferable Features with Deep Adaptation Networks},\n    booktitle = {ICML},\n    year      = {2015},\n}\n\n@inproceedings{JAN,\n    title={Deep transfer learning with joint adaptation networks},\n    author={Long, Mingsheng and Zhu, Han and Wang, Jianmin and Jordan, Michael I},\n    booktitle={ICML},\n    year={2017},\n}\n\n@inproceedings{ADDA,\n    title={Adversarial discriminative domain adaptation},\n    author={Tzeng, Eric and Hoffman, Judy and Saenko, Kate and Darrell, Trevor},\n    booktitle={CVPR},\n    year={2017}\n}\n\n@inproceedings{CDAN,\n    author    = {Mingsheng Long and\n                Zhangjie Cao and\n                Jianmin Wang and\n                Michael I. Jordan},\n    title     = {Conditional Adversarial Domain Adaptation},\n    booktitle = {NeurIPS},\n    year      = {2018}\n}\n\n@inproceedings{MCD,\n    title={Maximum classifier discrepancy for unsupervised domain adaptation},\n    author={Saito, Kuniaki and Watanabe, Kohei and Ushiku, Yoshitaka and Harada, Tatsuya},\n    booktitle={CVPR},\n    year={2018}\n}\n\n@inproceedings{AFN,\n    author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang},\n    title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation},\n    booktitle = {ICCV},\n    year = {2019}\n}\n\n@inproceedings{MDD,\n    title={Bridging theory and algorithm for domain adaptation},\n    author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael},\n    booktitle={ICML},\n    year={2019},\n}\n\n@inproceedings{BSP,\n    title={Transferability vs. discriminability: Batch spectral penalization for adversarial domain adaptation},\n    author={Chen, Xinyang and Wang, Sinan and Long, Mingsheng and Wang, Jianmin},\n    booktitle={ICML},\n    year={2019},\n}\n\n@inproceedings{MCC,\n    author    = {Ying Jin and\n                Ximei Wang and\n                Mingsheng Long and\n                Jianmin Wang},\n    title     = {Less Confusion More Transferable: Minimum Class Confusion for Versatile\n               Domain Adaptation},\n    year={2020},\n    booktitle={ECCV},\n}\n\n@inproceedings{FixMatch,\n    title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},\n    author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},\n    booktitle={NeurIPS},\n    year={2020}\n}\n\n```\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/adda.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\nNote: Our implementation is different from ADDA paper in several respects. We do not use separate networks for\nsource and target domain, nor fix classifier head. Besides, we do not adopt asymmetric objective loss function\nof the feature extractor.\n\"\"\"\nimport random\nimport time\nimport warnings\nimport copy\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.alignment.adda import ImageClassifier\nfrom tllib.alignment.dann import DomainAdversarialLoss\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.modules.grl import WarmStartGradientReverseLayer\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef set_requires_grad(net, requires_grad=False):\n    \"\"\"\n    Set requies_grad=Fasle for all the networks to avoid unnecessary computations\n    \"\"\"\n    for param in net.parameters():\n        param.requires_grad = requires_grad\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! 
'\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    source_classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                        pool_layer=pool_layer, finetune=not args.scratch).to(device)\n\n    if args.phase == 'train' and args.pretrain is None:\n        # first pretrain the classifier with source data\n        print(\"Pretraining the model on the source domain.\")\n        args.pretrain = logger.get_checkpoint_path('pretrain')\n        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)\n        pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr, momentum=args.momentum,\n                                 weight_decay=args.weight_decay, nesterov=True)\n        
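# inverse-decay lr schedule: the multiplicative factor shrinks polynomially\n        # with the optimization step, i.e. pretrain_lr * (1 + lr_gamma * step) ** (-lr_decay)\n        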
pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,\n                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (\n                                             -args.lr_decay))\n        # start pretraining\n        for epoch in range(args.pretrain_epochs):\n            print(\"lr:\", pretrain_lr_scheduler.get_lr())\n            # pretrain for one epoch\n            utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,\n                                              pretrain_lr_scheduler, epoch, args,\n                                              device)\n            # validate to monitor the pretraining process\n            utils.validate(val_loader, pretrain_model, args, device)\n\n        torch.save(pretrain_model.state_dict(), args.pretrain)\n        print(\"Pretraining process is done.\")\n\n    checkpoint = torch.load(args.pretrain, map_location='cpu')\n    source_classifier.load_state_dict(checkpoint)\n    target_classifier = copy.deepcopy(source_classifier)\n\n    # freeze source classifier\n    set_requires_grad(source_classifier, False)\n    source_classifier.freeze_bn()\n\n    domain_discri = DomainDiscriminator(in_feature=source_classifier.features_dim, hidden_size=1024).to(device)\n\n    # define loss function\n    grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=2., max_iters=1000, auto_step=True)\n    domain_adv = DomainAdversarialLoss(domain_discri, grl=grl).to(device)\n\n    # define optimizer and lr scheduler\n    # note that we only optimize the target feature extractor\n    optimizer = SGD(target_classifier.get_parameters(optimize_head=False) + domain_discri.get_parameters(), args.lr,\n                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        target_classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(target_classifier.backbone, target_classifier.pool_layer,\n                                          target_classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, target_classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_source_iter, train_target_iter, source_classifier, target_classifier, domain_adv,\n              optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, target_classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        
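# 'latest' is overwritten every epoch and copied to 'best' whenever val acc@1 improves\n        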
torch.save(target_classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    target_classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, target_classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          source_model: ImageClassifier, target_model: ImageClassifier, domain_adv: DomainAdversarialLoss,\n          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_transfer, domain_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    target_model.train()\n    domain_adv.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, = next(train_source_iter)[:1]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        _, f_s = source_model(x_s)\n        _, f_t = target_model(x_t)\n        loss_transfer = domain_adv(f_s, f_t)\n\n        # Compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss_transfer.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        losses_transfer.update(loss_transfer.item(), x_s.size(0))\n        domain_acc = domain_adv.domain_discriminator_accuracy\n        domain_accs.update(domain_acc.item(), x_s.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='ADDA for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--pretrain', type=str, default=None,\n                        help='pretrain checkpoint for classification model')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate of the classifier', dest='lr')\n    parser.add_argument('--pretrain-lr', default=0.001, type=float, help='initial pretrain learning rate')\n    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--pretrain-epochs', default=3, type=int, metavar='N',\n                        help='number of total epochs (pretrain) to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='adda',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/adda.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python adda.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/adda/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 
python adda.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python adda.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python adda.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/adda_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log 
logs/adda_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/SVHN2MNIST\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/afn.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.normalization.afn import AdaptiveFeatureNorm, ImageClassifier\nfrom tllib.modules.entropy import entropy\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, args.num_blocks,\n                                 bottleneck_dim=args.bottleneck_dim, dropout_p=args.dropout_p,\n                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)\n    adaptive_feature_norm = 
AdaptiveFeatureNorm(args.delta).to(device)\n\n    # define optimizer\n    # the learning rate is fixed, following the original paper\n    optimizer = SGD(classifier.get_parameters(), args.lr, weight_decay=args.weight_decay)\n\n
    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n
    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n
    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n
    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, adaptive_feature_norm, optimizer, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n
    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\n
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    norm_losses = AverageMeter('Norm Loss', ':3.2f')\n    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')\n    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n
    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, cls_losses, norm_losses, src_feature_norm, tgt_feature_norm, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n
    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n
        # compute output\n        y_s, f_s = model(x_s)\n        y_t, f_t = model(x_t)\n\n        # classification loss\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        # norm loss\n        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)\n        loss = cls_loss + norm_loss * args.trade_off_norm\n\n
        # using entropy minimization\n        if args.trade_off_entropy:\n            y_t = F.softmax(y_t, dim=1)\n            entropy_loss = entropy(y_t, reduction='mean')\n            loss += entropy_loss * args.trade_off_entropy\n\n
        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n
        # update statistics\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        cls_losses.update(cls_loss.item(), x_s.size(0))\n        norm_losses.update(norm_loss.item(), x_s.size(0))\n        src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), x_s.size(0))\n        tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n\n
        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\n
if __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='AFN for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='ran.crop')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n
    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('-n', '--num-blocks', default=1, type=int, help='Number of basic blocks for classifier')\n    parser.add_argument('--bottleneck-dim', default=1000, type=int, help='Dimension of bottleneck')\n    parser.add_argument('--dropout-p', default=0.5, type=float,\n                        help='Dropout probability')\n
    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)',\n                        dest='weight_decay')\n    parser.add_argument('--trade-off-norm', default=0.05, type=float,\n                        help='the trade-off hyper-parameter for norm loss')\n    parser.add_argument('--trade-off-entropy', default=None, type=float,\n                        help='the trade-off hyper-parameter for entropy loss')\n    parser.add_argument('-r', '--delta', default=1, type=float, help='Increment for L2 norm')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='afn',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/afn.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python afn.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 -r 0.3 -b 36 \\\n    --epochs 10 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/afn/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log 
logs/afn/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python afn.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 20 -i 2500 --seed 0 --log logs/afn/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python afn.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 20 -i 2500 --seed 0 --log logs/afn_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home 
-d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool -r 0.3 --lr 0.01 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' 
\\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool -r 0.1 --lr 0.03 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/USPS2MNIST\n
CUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool -r 0.1 --lr 0.1 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/SVHN2MNIST\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/bsp.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.alignment.dann import DomainAdversarialLoss\nfrom tllib.alignment.bsp import BatchSpectralPenalizationLoss, ImageClassifier\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n               
                  pool_layer=pool_layer, finetune=not args.scratch).to(device)\n    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)\n\n
    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),\n                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n
    # define loss function\n    domain_adv = DomainAdversarialLoss(domain_discri).to(device)\n    bsp_penalty = BatchSpectralPenalizationLoss().to(device)\n\n
    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n
    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n
    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n
    if args.pretrain is None:\n        # first pretrain the classifier with source data\n        print(\"Pretraining the model on the source domain.\")\n        args.pretrain = logger.get_checkpoint_path('pretrain')\n        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)\n        pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr, momentum=args.momentum,\n                                 weight_decay=args.weight_decay, nesterov=True)\n        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,\n                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n
        # start pretraining\n        for epoch in range(args.pretrain_epochs):\n            print(\"lr:\", pretrain_lr_scheduler.get_last_lr()[0])\n            # pretrain for one epoch\n            utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,\n                                              pretrain_lr_scheduler, epoch, args,\n                                              device)\n            # validate to track the pretraining progress\n            utils.validate(val_loader, pretrain_model, args, device)\n\n        torch.save(pretrain_model.state_dict(), args.pretrain)\n        print(\"Pretraining process is done.\")\n\n
    checkpoint = torch.load(args.pretrain, map_location='cpu')\n    classifier.load_state_dict(checkpoint)\n\n
    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(\"lr:\", lr_scheduler.get_last_lr()[0])\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, domain_adv, bsp_penalty, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n
    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\n
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: ImageClassifier, domain_adv: DomainAdversarialLoss, bsp_penalty: BatchSpectralPenalizationLoss,\n          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    losses = AverageMeter('Loss', ':6.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs, domain_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n
    # switch to train mode\n    model.train()\n    domain_adv.train()\n\n
    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n
        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n
        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = domain_adv(f_s, f_t)\n        bsp_loss = bsp_penalty(f_s, f_t)\n        domain_acc = domain_adv.domain_discriminator_accuracy\n        loss = cls_loss + transfer_loss * args.trade_off + bsp_loss * args.trade_off_bsp\n\n
        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        domain_accs.update(domain_acc.item(), x_s.size(0))\n\n
        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n
        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\n
if __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='BSP for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n
    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--pretrain', type=str, default=None,\n                        help='pretrain checkpoint for classification model')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    parser.add_argument('--trade-off-bsp', default=2e-4, type=float,\n                        help='the trade-off hyper-parameter for BSP loss')\n
    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--pretrain-lr', default=0.001, type=float, help='initial pretrain learning rate')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--pretrain-epochs', default=3, type=int, metavar='N',\n                        help='number of total epochs (pretrain) to run (default: 3)')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='bsp',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/bsp.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/bsp/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s 
p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2r\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2s\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2c\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2p\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2s\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2c\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2p\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2r\n\n
# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/ImageNet_IN2INR\n\n
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/bsp_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n
# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Cl\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Pr\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Rw\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Ar\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Pr\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Rw\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Ar\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Cl\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Rw\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Ar\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Cl\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Pr\n\n
# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Ar\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Cl\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Pr\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Rw\n\n
# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2c\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2i\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2p\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2q\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2r\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2s\n\n
# Digits\nCUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/MNIST2USPS\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/USPS2MNIST\n
CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.1 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/SVHN2MNIST --pretrain-epochs 1\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/cc_loss.py",
    "content": "\"\"\"\n@author: Ying Jin\n@contact: sherryying003@gmail.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.self_training.mcc import MinimumClassConfusionLoss, ImageClassifier\nfrom tllib.self_training.cc_loss import CCConsistency\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),\n                                                       random_horizontal_flip=not args.no_hflip,\n                                                       random_color_jitter=False, resize_size=args.resize_size,\n                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)\n    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                             random_horizontal_flip=not args.no_hflip,\n                                             random_color_jitter=False, resize_size=args.resize_size,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                               random_horizontal_flip=not args.no_hflip,\n                                               random_color_jitter=False, resize_size=args.resize_size,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std,\n                                               auto_augment=args.auto_augment)\n    train_target_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_source_transform: \", train_source_transform)\n    print(\"train_target_transform: \", train_target_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, 
val_transform,\n                          train_target_transform=train_target_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n
    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n
    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)\n\n
    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,\n                    nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n
    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n
    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n
    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n
    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(\"lr:\", lr_scheduler.get_last_lr()[0])\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n
    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\n
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: ImageClassifier, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    trans_losses = AverageMeter('Trans Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n
    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n
    # define loss function\n    mcc = MinimumClassConfusionLoss(temperature=args.temperature)\n    consistency = CCConsistency(temperature=args.temperature, thr=args.thr)\n\n
    # switch to train mode\n    model.train()\n\n
    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        (x_t, x_t_strong), labels_t = next(train_target_iter)[:2]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        x_t_strong = x_t_strong.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n
        # compute output\n        x = torch.cat((x_s, x_t, x_t_strong), dim=0)\n        y, f = model(x)\n        y_s, y_t, y_t_strong = y.chunk(3, dim=0)\n\n
        cls_loss = F.cross_entropy(y_s, labels_s)\n        mcc_loss = mcc(y_t)\n        # consistency between predictions on weakly and strongly augmented target data\n        # (selec_ratio is the ratio of target samples selected by the confidence threshold)\n        consistency_loss, selec_ratio = consistency(y_t, y_t_strong)\n        transfer_loss = mcc_loss * args.trade_off + consistency_loss * args.trade_off_consistency\n        loss = cls_loss + transfer_loss\n\n
        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        trans_losses.update(transfer_loss.item(), x_s.size(0))\n\n
        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n
        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\n
if __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='CC Loss for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', 
default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--temperature', default=2.5, type=float, help='parameter for temperature scaling')\n    parser.add_argument('--thr', default=0.95, type=float, help='confidence threshold for the consistency loss')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for the original MCC loss')\n    parser.add_argument('--trade-off-consistency', default=1., type=float,\n                        help='the trade-off hyper-parameter for the consistency loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='cc_loss',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/cc_loss.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=5 python cc_loss.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 30 --seed 0 --lr 0.002 --per-class-eval --temperature 3.0 --train-resizing cen.crop --log logs/cc_loss/VisDA2017\n\n# ResNet101, 
DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 --seed 0 --temperature 2.5 --bottleneck-dim 2048 --log logs/cc_loss/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/cc_loss_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 
--seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log 
logs/cc_loss/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/SVHN2MNIST\n\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/cdan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, 
finetune=not args.scratch).to(device)\n    classifier_feature_dim = classifier.features_dim\n\n    if args.randomized:\n        domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device)\n    else:\n        domain_discri = DomainDiscriminator(classifier_feature_dim * num_classes, hidden_size=1024).to(device)\n\n    all_parameters = classifier.get_parameters() + domain_discri.get_parameters()\n    # define optimizer and lr scheduler\n    optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    domain_adv = ConditionalDomainAdversarialLoss(\n        domain_discri, entropy_conditioning=args.entropy,\n        num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,\n        randomized_dim=args.randomized_dim\n    ).to(device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analysis the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(\"lr:\", lr_scheduler.get_last_lr()[0])\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          domain_adv: ConditionalDomainAdversarialLoss, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = 
AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    trans_losses = AverageMeter('Trans Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    domain_adv.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = domain_adv(y_s, f_s, y_t, f_t)\n        domain_acc = domain_adv.domain_discriminator_accuracy\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc, x_s.size(0))\n        domain_accs.update(domain_acc, x_s.size(0))\n        trans_losses.update(transfer_loss.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='CDAN for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    parser.add_argument('-r', '--randomized', action='store_true',\n                        help='using randomized multi-linear-map (default: False)')\n    parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,\n                        help='randomized dimension when using randomized multi-linear-map (default: 1024)')\n    parser.add_argument('--entropy', default=False, action='store_true', help='use entropy conditioning')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='cdan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/cdan.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Pr\n    \n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/cdan/VisDA2017\n\n# ResNet101, DomainNet, Single Source\n# Use randomized multi-linear-map to decrease GPU memory usage\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 
1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --log logs/cdan_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool 
--log logs/cdan_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' 
\\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/SVHN2MNIST\n\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/dan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.alignment.dan import MultipleKernelMaximumMeanDiscrepancy, ImageClassifier\nfrom tllib.modules.kernels import GaussianKernel\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, finetune=not 
args.scratch).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy(\n        kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],\n        linear=not args.non_linear\n    )\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analysis the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, mkmmd_loss, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          mkmmd_loss: MultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    trans_losses = AverageMeter('Trans Loss', ':5.4f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    mkmmd_loss.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n      
  x_t, = next(train_target_iter)[:1]\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n        y_t, f_t = model(x_t)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = mkmmd_loss(f_s, f_t)\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        trans_losses.update(transfer_loss.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='DAN for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    parser.add_argument('--non-linear', default=False, action='store_true',\n                        help='whether not use the linear version')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='dan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/dan.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_W2A\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_W2D\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python dan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 20 -i 500 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/dan/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log 
logs/dan/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python dan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python dan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --log logs/dan_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log 
logs/dan_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/SVHN2MNIST\n\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/dann.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, finetune=not 
args.scratch).to(device)\n    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),\n                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    domain_adv = DomainAdversarialLoss(domain_discri).to(device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(\"lr:\", lr_scheduler.get_last_lr()[0])\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    losses = AverageMeter('Loss', ':6.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs, domain_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    domain_adv.train()\n\n    end = 
time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = domain_adv(f_s, f_t)\n        domain_acc = domain_adv.domain_discriminator_accuracy\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        domain_accs.update(domain_acc.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='DANN for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training.
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='dann',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/dann.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/dann/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 
python dann.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python dann.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/dann_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log 
logs/dann_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/SVHN2MNIST\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/erm.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\nfrom tllib.utils.data import ForeverDataIterator\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=not args.scratch).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(\"lr:\", lr_scheduler.get_last_lr()[0])\n        # train for one epoch\n        utils.empirical_risk_minimization(train_source_iter, classifier, optimizer, lr_scheduler, epoch, args, device)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Source Only for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training.')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='src_only',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/erm.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 20 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/erm/VisDA2017\n\n# ResNet101, DomainNet, Oracle\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i -t i -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d 
DomainNet -s q -t q -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_s\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 20 -i 2500 --seed 0 --log logs/erm/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python erm.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 20 -i 2500 --seed 0 --log logs/erm_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log 
logs/erm_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' 
\\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/SVHN2MNIST\n\n\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/fixmatch.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\nclass ImageClassifier(Classifier):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):\n        bottleneck = nn.Sequential(\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\"\"\"\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        predictions = self.head(f)\n        return predictions\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.),\n                                                       random_horizontal_flip=not args.no_hflip,\n                                                       random_color_jitter=False, resize_size=args.resize_size,\n                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)\n    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                             random_horizontal_flip=not args.no_hflip,\n                                             random_color_jitter=False, resize_size=args.resize_size,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                               random_horizontal_flip=not args.no_hflip,\n                                               random_color_jitter=False, resize_size=args.resize_size,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std,\n                                               auto_augment=args.auto_augment)\n    train_target_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n\n    print(\"train_source_transform: \", train_source_transform)\n    print(\"train_target_transform: \", train_target_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform,\n                          train_target_transform=train_target_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.unlabeled_batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,\n                    nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analysis the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(\"lr:\", lr_scheduler.get_last_lr())\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: ImageClassifier, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    cls_losses = AverageMeter('Cls Loss', ':6.2f')\n    self_training_losses = AverageMeter('Self Training Loss', ':6.2f')\n    losses = AverageMeter('Loss', ':6.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')\n    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,\n         pseudo_label_ratios],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        (x_t, x_t_strong), labels_t = next(train_target_iter)[:2]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        x_t_strong = 
x_t_strong.to(device)\n        labels_s = labels_s.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        with torch.no_grad():\n            y_t = model(x_t)\n\n        # cross entropy loss\n        y_s = model(x_s)\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        cls_loss.backward()\n\n        # self-training loss\n        y_t_strong = model(x_t_strong)\n        self_training_loss, mask, pseudo_labels = self_training_criterion(y_t_strong, y_t)\n        self_training_loss = args.trade_off * self_training_loss\n        self_training_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + self_training_loss\n        losses.update(loss.item(), x_s.size(0))\n        cls_losses.update(cls_loss.item(), x_s.size(0))\n        self_training_losses.update(self_training_loss.item(), x_s.size(0))\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n\n        # ratio of pseudo labels\n        n_pseudo_labels = mask.sum()\n        ratio = n_pseudo_labels / x_t.size(0)\n        pseudo_label_ratios.update(ratio.item() * 100, x_t.size(0))\n\n        # accuracy of pseudo labels\n        if n_pseudo_labels > 0:\n            pseudo_labels = pseudo_labels * mask - (1 - mask)\n            n_correct = (pseudo_labels == labels_t).float().sum()\n            pseudo_label_acc = n_correct / n_pseudo_labels * 100\n            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='FixMatch for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('-ub', '--unlabeled-batch-size', default=32, type=int,\n                        help='mini-batch size of unlabeled data (target domain) (default: 32)')\n    parser.add_argument('--threshold', default=0.9, type=float,\n                        help='confidence threshold')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='fixmatch',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
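fixmatch.py delegates the target-domain objective to `ConfidenceBasedSelfTrainingLoss`: logits from the weakly augmented view produce pseudo labels, and only predictions whose confidence clears `--threshold` contribute a cross-entropy term against the strongly augmented view. A minimal sketch consistent with how the criterion is called in `train()` above (it returns the loss, the confidence mask, and the pseudo labels); tllib's actual implementation may differ in detail:

```python
import torch
import torch.nn.functional as F

def confidence_based_self_training_loss(y_strong, y_weak, threshold=0.9):
    """FixMatch-style loss sketch: pseudo-label with the weak view,
    train the strong view, and keep only confident predictions."""
    # Pseudo labels come from the weakly augmented view; no gradient flows here.
    probs = F.softmax(y_weak.detach(), dim=1)
    confidence, pseudo_labels = probs.max(dim=1)
    mask = (confidence >= threshold).float()
    # Per-sample CE against pseudo labels, averaged over the whole batch
    # but zeroed out for low-confidence samples.
    loss = (F.cross_entropy(y_strong, pseudo_labels, reduction='none') * mask).mean()
    return loss, mask, pseudo_labels

# Toy usage with random logits: batch of 4, 3 classes.
y_weak, y_strong = torch.randn(4, 3), torch.randn(4, 3)
loss, mask, pl = confidence_based_self_training_loss(y_strong, y_weak)
```

The float `mask` is what lets `train()` compute the pseudo-label ratio and, via `pseudo_labels * mask - (1 - mask)`, mark rejected samples with -1 so they never match the ground-truth labels.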
  {
    "path": "examples/domain_adaptation/image_classification/fixmatch.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s A -t W -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s D -t W -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s W -t D -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s A -t D -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s D -t A -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s W -t A -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 
python fixmatch.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 --train-resizing cen.crop \\\n    --lr 0.003 --threshold 0.8 --bottleneck-dim 2048 --epochs 20 -ub 64 --seed 0 --per-class-eval --log logs/fixmatch/VisDA2017\n"
  },
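The target transform in fixmatch.py wraps the weak and strong augmentations in `MultipleApply`, which is why each target sample arrives as the pair `(x_t, x_t_strong)` in the training loop. A minimal sketch of such a wrapper, under the assumption that it simply applies every transform to the same input:

```python
class MultipleApply:
    """Apply each transform in `transforms` to the same input and
    return the list of results, e.g. [weak(x), strong(x)]."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        return [t(x) for t in self.transforms]

# Hypothetical usage: two views of the same sample.
two_views = MultipleApply([lambda x: x, lambda x: x * 2])
weak, strong = two_views(3)   # -> 3, 6
```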
  {
    "path": "examples/domain_adaptation/image_classification/jan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.alignment.jan import JointMultipleKernelMaximumMeanDiscrepancy, ImageClassifier, Theta\nfrom tllib.modules.kernels import GaussianKernel\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, 
finetune=not args.scratch).to(device)\n\n    # define loss function\n    if args.adversarial:\n        thetas = [Theta(dim).to(device) for dim in (classifier.features_dim, num_classes)]\n    else:\n        thetas = None\n    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(\n        kernels=(\n            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],\n            (GaussianKernel(sigma=0.92, track_running_stats=False),)\n        ),\n        linear=args.linear, thetas=thetas\n    ).to(device)\n\n    parameters = classifier.get_parameters()\n    if thetas is not None:\n        parameters += [{\"params\": theta.parameters(), 'lr': 0.1} for theta in thetas]\n\n    # define optimizer\n    optimizer = SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analysis the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, jmmd_loss, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    trans_losses = AverageMeter('Trans Loss', ':5.4f')\n    
cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    jmmd_loss.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = jmmd_loss(\n            (f_s, F.softmax(y_s, dim=1)),\n            (f_t, F.softmax(y_t, dim=1))\n        )\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        trans_losses.update(transfer_loss.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='JAN for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    parser.add_argument('--linear', default=False, action='store_true',\n                        help='whether use the linear version')\n    parser.add_argument('--adversarial', default=False, action='store_true',\n                        help='whether use adversarial theta')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='jan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
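jan.py aligns the joint distribution of bottleneck features and softmax predictions with `JointMultipleKernelMaximumMeanDiscrepancy`. Conceptually, each layer contributes a Gaussian-kernel matrix, the joint kernel is their elementwise product, and the loss is the (biased) MMD estimate under that joint kernel. A minimal quadratic-time sketch, assuming fixed kernel bandwidths rather than tllib's adaptive `GaussianKernel` and omitting the linear-time and adversarial (`Theta`) variants:

```python
import torch

def gaussian_kernel(x, y, sigmas=(1., 2., 4.)):
    """Sum of Gaussian kernels between all pairs of rows of x and y."""
    d2 = torch.cdist(x, y) ** 2
    return sum(torch.exp(-d2 / (2 * s ** 2)) for s in sigmas)

def jmmd(source_layers, target_layers):
    """Joint MMD^2 sketch: multiply per-layer kernels elementwise, then
    take the usual mean(K_ss) + mean(K_tt) - 2 * mean(K_st) estimate."""
    k_ss = k_tt = k_st = 1.
    for zs, zt in zip(source_layers, target_layers):
        k_ss = k_ss * gaussian_kernel(zs, zs)
        k_tt = k_tt * gaussian_kernel(zt, zt)
        k_st = k_st * gaussian_kernel(zs, zt)
    return k_ss.mean() + k_tt.mean() - 2 * k_st.mean()

# Toy usage mirroring jan.py: (features, softmax outputs) per domain.
f_s, y_s = torch.randn(8, 256), torch.softmax(torch.randn(8, 31), dim=1)
f_t, y_t = torch.randn(8, 256), torch.softmax(torch.randn(8, 31), dim=1)
loss = jmmd((f_s, y_s), (f_t, y_t))
```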
  {
    "path": "examples/domain_adaptation/image_classification/jan.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_W2A\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_W2D\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python jan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 20 -i 500 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/jan/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d 
DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python jan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python jan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/jan_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Rw -a 
vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python jan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python jan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/USPS2MNIST\nCUDA_VISIBLE_DEVICES=4 python jan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/SVHN2MNIST\n\n"
  },
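All of the adaptation scripts here draw batches with `next(train_source_iter)` inside a fixed `--iters-per-epoch` loop, which requires an iterator that never exhausts. A minimal sketch of such a wrapper, assuming it simply restarts the underlying `DataLoader` on `StopIteration` (tllib's `ForeverDataIterator` may do more, e.g. device placement):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class ForeverDataIterator:
    """Endlessly iterate over a DataLoader, reshuffling each pass."""

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.iter = iter(data_loader)

    def __next__(self):
        try:
            return next(self.iter)
        except StopIteration:
            self.iter = iter(self.data_loader)  # start a fresh epoch
            return next(self.iter)

    def __len__(self):
        return len(self.data_loader)

# Toy usage: 10 samples, batches of 4, drawn 5 times without exhausting.
loader = DataLoader(TensorDataset(torch.arange(10.)), batch_size=4, shuffle=True)
it = ForeverDataIterator(loader)
batches = [next(it) for _ in range(5)]
```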
  {
    "path": "examples/domain_adaptation/image_classification/mcc.py",
    "content": "\"\"\"\n@author: Ying Jin\n@contact: sherryying003@gmail.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.self_training.mcc import MinimumClassConfusionLoss, ImageClassifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)\n\n    # define optimizer and lr scheduler\n    
optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,\n                    nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analysis the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(\"lr:\", lr_scheduler.get_last_lr()[0])\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, mcc_loss, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: ImageClassifier, mcc: MinimumClassConfusionLoss, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    trans_losses = AverageMeter('Trans Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = 
labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = mcc(y_t)\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        trans_losses.update(transfer_loss.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='MCC for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    parser.add_argument('--temperature', default=2.5, type=float, help='parameter temperature scaling')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='mcc',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
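mcc.py's transfer loss, `MinimumClassConfusionLoss`, needs only the target logits: predictions are sharpened with `--temperature`, samples are re-weighted so that low-entropy (confident) predictions count more, and the loss penalizes the off-diagonal mass of the resulting class-confusion matrix. A minimal sketch following the MCC paper's formulation (Jin et al., 2020); tllib's exact weighting may differ in detail:

```python
import torch
import torch.nn.functional as F

def minimum_class_confusion_loss(logits, temperature=2.5):
    """Penalize between-class confusion on (unlabeled) target logits."""
    B, C = logits.shape
    probs = F.softmax(logits / temperature, dim=1)            # (B, C)
    # Low-entropy (certain) samples receive larger weights.
    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1)   # (B,)
    weight = 1. + torch.exp(-entropy)
    weight = B * weight / weight.sum()                        # sums to B
    # Weighted class-confusion matrix, then per-category normalization.
    confusion = (probs * weight.unsqueeze(1)).t() @ probs     # (C, C)
    confusion = confusion / confusion.sum(dim=1, keepdim=True)
    # Average off-diagonal (cross-class) confusion.
    return (confusion.sum() - confusion.trace()) / C

# Toy usage on a batch of target logits (batch 36, 65 classes as in Office-Home).
loss = minimum_class_confusion_loss(torch.randn(36, 65))
```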
  {
    "path": "examples/domain_adaptation/image_classification/mcc.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=5 python mcc.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 30 --seed 0 --lr 0.002 --per-class-eval --temperature 3.0 --train-resizing cen.crop --log logs/mcc/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 
--temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 --seed 0 --temperature 2.5 --bottleneck-dim 2048 --log logs/mcc/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/mcc_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log 
logs/mcc_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log 
logs/mcc/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/SVHN2MNIST\n\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/mcd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\nfrom typing import Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.alignment.mcd import ImageClassifierHead, entropy, classifier_discrepancy\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    G = utils.get_model(args.arch, pretrain=not args.scratch).to(device)  # feature extractor\n    # two image classifier heads\n    pool_layer = nn.Identity() if args.no_pool else None\n    F1 = ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim, pool_layer).to(device)\n    F2 = 
ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim, pool_layer).to(device)\n\n    # define optimizer\n    # the learning rate is fixed according to the original paper\n    optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)\n    optimizer_f = SGD([\n        {\"params\": F1.parameters()},\n        {\"params\": F2.parameters()},\n    ], momentum=0.9, lr=args.lr, weight_decay=0.0005)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        G.load_state_dict(checkpoint['G'])\n        F1.load_state_dict(checkpoint['F1'])\n        F2.load_state_dict(checkpoint['F2'])\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(G, F1.pool_layer).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = validate(test_loader, G, F1, F2, args)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    best_results = None\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, G, F1, F2, optimizer_g, optimizer_f, epoch, args)\n\n        # evaluate on validation set\n        results = validate(val_loader, G, F1, F2, args)\n\n        # remember best acc@1 and save checkpoint\n        torch.save({\n            'G': G.state_dict(),\n            'F1': F1.state_dict(),\n            'F2': F2.state_dict()\n        }, logger.get_checkpoint_path('latest'))\n        if max(results) > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n            best_acc1 = max(results)\n            best_results = results\n\n    print(\"best_acc1 = {:3.1f}, results = {}\".format(best_acc1, best_results))\n\n    # evaluate on test set\n    checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n    G.load_state_dict(checkpoint['G'])\n    F1.load_state_dict(checkpoint['F1'])\n    F2.load_state_dict(checkpoint['F2'])\n    results = validate(test_loader, G, F1, F2, args)\n    print(\"test_acc1 = {:3.1f}\".format(max(results)))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          G: nn.Module, F1: ImageClassifierHead, F2: ImageClassifierHead,\n          optimizer_g: SGD, optimizer_f: SGD, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    trans_losses = AverageMeter('Trans Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    
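# A sketch of the alternating schedule implemented below (per iteration):\n    # Step A updates G, F1 and F2 on the source classification loss (plus target entropy);\n    # Step B fixes G and updates F1/F2 to maximize their discrepancy on target predictions\n    # (while still minimizing the source loss);\n    # Step C fixes F1/F2 and updates G for args.num_k steps to minimize that discrepancy.\n\n    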
# switch to train mode\n    G.train()\n    F1.train()\n    F2.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n        x = torch.cat((x_s, x_t), dim=0)\n        assert x.requires_grad is False\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # Step A train all networks to minimize loss on source domain\n        optimizer_g.zero_grad()\n        optimizer_f.zero_grad()\n\n        g = G(x)\n        y_1 = F1(g)\n        y_2 = F2(g)\n        y1_s, y1_t = y_1.chunk(2, dim=0)\n        y2_s, y2_t = y_2.chunk(2, dim=0)\n\n        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)\n        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \\\n               (entropy(y1_t) + entropy(y2_t)) * args.trade_off_entropy\n        loss.backward()\n        optimizer_g.step()\n        optimizer_f.step()\n\n        # Step B train classifiers to maximize discrepancy\n        optimizer_g.zero_grad()\n        optimizer_f.zero_grad()\n\n        g = G(x)\n        y_1 = F1(g)\n        y_2 = F2(g)\n        y1_s, y1_t = y_1.chunk(2, dim=0)\n        y2_s, y2_t = y_2.chunk(2, dim=0)\n        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)\n        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \\\n               (entropy(y1_t) + entropy(y2_t)) * args.trade_off_entropy - \\\n               classifier_discrepancy(y1_t, y2_t) * args.trade_off\n        loss.backward()\n        optimizer_f.step()\n\n        # Step C train generator to minimize discrepancy\n        for k in range(args.num_k):\n            optimizer_g.zero_grad()\n            g = G(x)\n            y_1 = F1(g)\n            y_2 = F2(g)\n            y1_s, y1_t = y_1.chunk(2, dim=0)\n            y2_s, y2_t = y_2.chunk(2, dim=0)\n            y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)\n            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off\n            mcd_loss.backward()\n            optimizer_g.step()\n\n        cls_acc = accuracy(y1_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        trans_losses.update(mcd_loss.item(), x_s.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\ndef validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,\n             F2: ImageClassifierHead, args: argparse.Namespace) -> Tuple[float, float]:\n    batch_time = AverageMeter('Time', ':6.3f')\n    top1_1 = AverageMeter('Acc_1', ':6.2f')\n    top1_2 = AverageMeter('Acc_2', ':6.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, top1_1, top1_2],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    G.eval()\n    F1.eval()\n    F2.eval()\n\n    if args.per_class_eval:\n        confmat = ConfusionMatrix(len(args.class_names))\n    else:\n        confmat = None\n\n    with torch.no_grad():\n        end = time.time()\n        for i, data in enumerate(val_loader):\n            images, target = data[:2]\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            g = G(images)\n         
   y1, y2 = F1(g), F2(g)\n\n            # measure accuracy and record loss\n            acc1, = accuracy(y1, target)\n            acc2, = accuracy(y2, target)\n            if confmat:\n                confmat.update(target, y1.argmax(1))\n            top1_1.update(acc1.item(), images.size(0))\n            top1_2.update(acc2.item(), images.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'\n              .format(top1_1=top1_1, top1_2=top1_2))\n        if confmat:\n            print(confmat.format(args.class_names))\n\n    return top1_1.avg, top1_2.avg\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='MCD for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true',\n                        help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+',\n                        default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+',\n                        default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int)\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    parser.add_argument('--trade-off-entropy', default=0.01, type=float,\n                        help='the trade-off hyper-parameter for entropy loss')\n    parser.add_argument('--num-k', type=int, default=4, metavar='K',\n                        help='how many steps to repeat the generator update')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='mcd',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/mcd.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\n# We found MCD loss is sensitive to class number,\n# thus, when the class number increase, please increase trade-off correspondingly.\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_W2A\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_W2D\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 20 --center-crop --seed 0 -i 500 --per-class-eval --train-resizing cen.crop --log logs/mcd/VisDA2017\n\n# ResNet101, DomainNet, Single 
Source\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t i -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2i\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 100.0 --log logs/mcd/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --trade-off 500.0 --log logs/mcd_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Rw -a 
vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Rw2Pr\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/SVHN2MNIST\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/mdd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport os.path as osp\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.alignment.mdd import ClassificationMarginDisparityDiscrepancy \\\n    as MarginDisparityDiscrepancy, ImageClassifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,\n                                                random_horizontal_flip=not args.no_hflip,\n                                                random_color_jitter=False, resize_size=args.resize_size,\n                                                norm_mean=args.norm_mean, norm_std=args.norm_std)\n    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,\n                                            norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 width=args.bottleneck_dim, 
pool_layer=pool_layer).to(device)\n    mdd = MarginDisparityDiscrepancy(args.margin).to(device)\n\n    # define optimizer and lr_scheduler\n    # The learning rate of the classifiers is set to 10 times that of the feature extractor by default.\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, mdd, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          classifier: ImageClassifier, mdd: MarginDisparityDiscrepancy, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    trans_losses = AverageMeter('Trans Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    classifier.train()\n    mdd.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, labels_s = next(train_source_iter)[:2]\n   
     x_t, = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        outputs, outputs_adv = classifier(x)\n        y_s, y_t = outputs.chunk(2, dim=0)\n        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)\n\n        # compute cross entropy loss on source domain\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        # compute margin disparity discrepancy between domains\n        # for the adversarial classifier, minimizing the negative MDD is equivalent to maximizing MDD\n        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)\n        loss = cls_loss + transfer_loss * args.trade_off\n        # advance the warm-start schedule of the gradient reverse layer inside the classifier\n        classifier.step()\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        trans_losses.update(transfer_loss.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='MDD for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    parser.add_argument('--resize-size', type=int, default=224,\n                        help='the image size after resizing')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.08 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean')\n    parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int)\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--margin', type=float, default=4., help=\"margin gamma\")\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.004, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0002, type=float)\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='mdd',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/mdd.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 --epochs 30 \\\n    --bottleneck-dim 1024 --seed 0 --train-resizing cen.crop --per-class-eval -b 36 --log logs/mdd/VisDA2017\n\n# ResNet101, DomainNet, Single Source\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log 
logs/mdd/DomainNet_c2p\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2r\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2s\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2c\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2r\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2s\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2c\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2p\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2s\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2c\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2p\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2r\n\n# ResNet50, ImageNet200 -> ImageNetR\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 40 -i 2500 -p 500 \\\n  --bottleneck-dim 2048 --seed 0 --lr 0.004 --train-resizing cen.crop --log logs/mdd/ImageNet_IN2INR\n\n# ig_resnext101_32x8d, ImageNet -> ImageNetSketch\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d \\\n  --epochs 40 -i 2500 -p 500 --bottleneck-dim 2048 --margin 2. 
--seed 0 --lr 0.004 --train-resizing cen.crop \\\n  --log logs/mdd_ig_resnext101_32x8d/ImageNet_IN2sketch\n\n# Vision Transformer, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Pr\n\n# ResNet50, Office-Home, Multi Source\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Ar\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Cl\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Pr\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Rw\n\n# ResNet101, DomainNet, Multi Source\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2c\nCUDA_VISIBLE_DEVICES=0 
python mdd.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2i\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2p\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2q\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2r\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2s\n\n# Digits\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/MNIST2USPS\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/USPS2MNIST\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \\\n  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/SVHN2MNIST\n"
  },
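The `mdd.sh` grid above repeats one command per ordered (source, target) pair, varying only `-s`, `-t`, and the log directory. As a reading aid, here is a minimal sketch (not a repository file) that would emit such a single-source grid for DomainNet, assuming the flag set used above:

```python
# Illustrative only: generate the single-source MDD command grid for DomainNet.
from itertools import permutations

DOMAINS = ["c", "i", "p", "q", "r", "s"]  # DomainNet domain codes
TEMPLATE = ("CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet "
            "-s {s} -t {t} -a resnet101 --epochs 40 -i 5000 -p 500 "
            "--bottleneck-dim 2048 --seed 0 --lr 0.004 "
            "--log logs/mdd/DomainNet_{s}2{t}")

for s, t in permutations(DOMAINS, 2):  # every ordered pair of distinct domains
    print(TEMPLATE.format(s=s, t=t))
```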
  {
    "path": "examples/domain_adaptation/image_classification/requirements.txt",
    "content": "timm"
  },
  {
    "path": "examples/domain_adaptation/image_classification/self_ensemble.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.self_training.pi_model import L2ConsistencyLoss\nfrom tllib.self_training.mean_teacher import EMATeacher\nfrom tllib.self_training.self_ensemble import ClassBalanceLoss, ImageClassifier\nfrom tllib.vision.transforms import ResizeImage, MultipleApply\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    # we find self ensemble is sensitive to data augmentation. The following\n    # data augmentation performs well for evaluated datasets\n    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        ResizeImage(256),\n        T.RandomCrop(224),\n        T.RandomHorizontalFlip(),\n        T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5),\n        T.RandomGrayscale(),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        ResizeImage(256),\n        T.CenterCrop(224),\n        T.ToTensor(),\n        normalize\n    ])\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target,\n                          train_transform, val_transform, MultipleApply([train_transform, val_transform]))\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if 
args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = Adam(classifier.get_parameters(), args.lr)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    if args.pretrain is None:\n        # first pretrain the classifier with source data\n        print(\"Pretraining the model on the source domain.\")\n        args.pretrain = logger.get_checkpoint_path('pretrain')\n        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)\n        pretrain_optimizer = Adam(pretrain_model.get_parameters(), args.pretrain_lr)\n        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,\n                                         lambda x: args.pretrain_lr * (1. 
+ args.lr_gamma * float(x)) ** (\n                                             -args.lr_decay))\n\n        # start pretraining\n        for epoch in range(args.pretrain_epochs):\n            # pretrain for one epoch\n            utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,\n                                              pretrain_lr_scheduler, epoch, args,\n                                              device)\n            # validate to show pretrain process\n            utils.validate(val_loader, pretrain_model, args, device)\n\n        torch.save(pretrain_model.state_dict(), args.pretrain)\n        print(\"Pretraining process is done.\")\n\n    checkpoint = torch.load(args.pretrain, map_location='cpu')\n    classifier.load_state_dict(checkpoint)\n    teacher = EMATeacher(classifier, alpha=args.alpha)\n    consistency_loss = L2ConsistencyLoss().to(device)\n    class_balance_loss = ClassBalanceLoss(num_classes).to(device)\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, teacher, consistency_loss, class_balance_loss,\n              optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          teacher: EMATeacher, consistency_loss, class_balance_loss,\n          optimizer: Adam, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    cons_losses = AverageMeter('Cons Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, cls_losses, cons_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    teacher.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        (x_t1, x_t2), = next(train_target_iter)[:1]\n\n        x_s = x_s.to(device)\n        x_t1 = x_t1.to(device)\n        x_t2 = x_t2.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, _ = model(x_s)\n        y_t, _ = model(x_t1)\n        y_t_teacher, _ = teacher(x_t2)\n\n        # classification loss\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        # compute output and mask\n        p_t = F.softmax(y_t, dim=1)\n        p_t_teacher = F.softmax(y_t_teacher, dim=1)\n        confidence, _ = 
p_t_teacher.max(dim=1)\n        mask = (confidence > args.threshold).float()\n\n        # consistency loss\n        cons_loss = consistency_loss(p_t, p_t_teacher, mask)\n        # balance loss\n        balance_loss = class_balance_loss(p_t) * mask.mean()\n\n        loss = cls_loss + args.trade_off_cons * cons_loss + args.trade_off_balance * balance_loss\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # update teacher\n        teacher.update()\n\n        # update statistics\n        cls_acc = accuracy(y_s, labels_s)[0]\n        cls_losses.update(cls_loss.item(), x_s.size(0))\n        cons_losses.update(cons_loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Self Ensemble for Unsupervised Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')\n    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--pretrain', type=str, default=None,\n                        help='pretrain checkpoint for classification model')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--pretrain-lr', '--pretrain-learning-rate', default=3e-5, type=float,\n                        help='initial pretrain learning rate', dest='pretrain_lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--alpha', default=0.99, type=float, help='ema decay rate (default: 0.99)')\n    parser.add_argument('--threshold', default=0.8, type=float, help='confidence threshold')\n    parser.add_argument('--trade-off-cons', default=3, type=float, help='trade off parameter for consistency loss')\n    
parser.add_argument('--trade-off-balance', default=0.01, type=float,\n                        help='trade off parameter for class balance loss')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--pretrain-epochs', default=0, type=int, metavar='N',\n                        help='number of total epochs (pretrain) to run')\n    parser.add_argument('--epochs', default=10, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='self_ensemble',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
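`self_ensemble.py` above rests on two ingredients: a teacher kept as an exponential moving average (EMA) of the student, and an L2 consistency loss between their predictions on target data, masked by the teacher's confidence. A minimal sketch of both, assuming (rather than quoting) the internals of tllib's `EMATeacher` and `L2ConsistencyLoss`:

```python
# Simplified illustration of the mean-teacher update and the masked
# consistency term; tllib's actual classes handle more bookkeeping.
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, alpha: float = 0.99):
    # teacher <- alpha * teacher + (1 - alpha) * student, parameter by parameter
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(alpha).add_(p_s, alpha=1 - alpha)

def masked_l2_consistency(p_student, p_teacher, mask):
    # squared difference between class-probability vectors, averaged only over
    # samples whose teacher confidence passed the threshold (mask in {0, 1})
    per_sample = ((p_student - p_teacher) ** 2).sum(dim=1)
    return (per_sample * mask).mean()
```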
  {
    "path": "examples/domain_adaptation/image_classification/self_ensemble.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, Office31, Single Source\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s A -t W -a resnet50 --seed 1 --log logs/self_ensemble/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s D -t W -a resnet50 --seed 1 --log logs/self_ensemble/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s W -t D -a resnet50 --seed 1 --log logs/self_ensemble/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s A -t D -a resnet50 --seed 1 --log logs/self_ensemble/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s D -t A -a resnet50 --seed 1 --log logs/self_ensemble/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s W -t A -a resnet50 --seed 1 --log logs/self_ensemble/Office31_W2A\n\n# ResNet50, Office-Home, Single Source\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Pr\n\n# ResNet101, VisDA-2017, Single Source\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \\\n    --epochs 20 --seed 0 --per-class-eval --log logs/self_ensemble/VisDA2017 --lr-gamma 0.0002 -b 32\n\n# Office-Home on Vision Transformer\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python 
self_ensemble.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Pr\n"
  },
  {
    "path": "examples/domain_adaptation/image_classification/utils.py",
    "content": "\"\"\"\n@author: Junguang Jiang, Baixu Chen\n@contact: JiangJunguang1123@outlook.com, cbx_99_hasta@outlook.com\n\"\"\"\nimport sys\nimport os.path as osp\nimport time\nfrom PIL import Image\n\nimport timm\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as T\nfrom timm.data.auto_augment import auto_augment_transform, rand_augment_transform\n\nsys.path.append('../../..')\nimport tllib.vision.datasets as datasets\nimport tllib.vision.models as models\nfrom tllib.vision.transforms import ResizeImage\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.vision.datasets.imagelist import MultipleDomainsDataset\n\n\ndef get_model_names():\n    return sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    ) + timm.list_models()\n\n\ndef get_model(model_name, pretrain=True):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=pretrain)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=pretrain)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n    return backbone\n\n\ndef get_dataset_names():\n    return sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    ) + ['Digits']\n\n\ndef get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None):\n    if train_target_transform is None:\n        train_target_transform = train_source_transform\n    if dataset_name == \"Digits\":\n        train_source_dataset = datasets.__dict__[source[0]](osp.join(root, source[0]), download=True,\n                                                            transform=train_source_transform)\n        train_target_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), download=True,\n                                                            transform=train_target_transform)\n        val_dataset = test_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), split='test',\n                                                                  download=True, transform=val_transform)\n        class_names = datasets.MNIST.get_classes()\n        num_classes = len(class_names)\n    elif dataset_name in datasets.__dict__:\n        # load datasets from tllib.vision.datasets\n        dataset = datasets.__dict__[dataset_name]\n\n        def concat_dataset(tasks, start_idx, **kwargs):\n            # return ConcatDataset([dataset(task=task, **kwargs) for task in tasks])\n            return MultipleDomainsDataset([dataset(task=task, **kwargs) for task in tasks], tasks,\n                                          domain_ids=list(range(start_idx, start_idx + len(tasks))))\n\n        train_source_dataset = concat_dataset(root=root, tasks=source, download=True, transform=train_source_transform,\n                                              start_idx=0)\n        train_target_dataset = concat_dataset(root=root, tasks=target, download=True, transform=train_target_transform,\n           
                                   start_idx=len(source))\n        val_dataset = concat_dataset(root=root, tasks=target, download=True, transform=val_transform,\n                                     start_idx=len(source))\n        if dataset_name == 'DomainNet':\n            test_dataset = concat_dataset(root=root, tasks=target, split='test', download=True, transform=val_transform,\n                                          start_idx=len(source))\n        else:\n            test_dataset = val_dataset\n        class_names = train_source_dataset.datasets[0].classes\n        num_classes = len(class_names)\n    else:\n        raise NotImplementedError(dataset_name)\n    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names\n\n\ndef validate(val_loader, model, args, device) -> float:\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    top1 = AverageMeter('Acc@1', ':6.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, top1],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    if args.per_class_eval:\n        confmat = ConfusionMatrix(len(args.class_names))\n    else:\n        confmat = None\n\n    with torch.no_grad():\n        end = time.time()\n        for i, data in enumerate(val_loader):\n            images, target = data[:2]\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            loss = F.cross_entropy(output, target)\n\n            # measure accuracy and record loss\n            acc1, = accuracy(output, target, topk=(1,))\n            if confmat:\n                confmat.update(target, output.argmax(1))\n            losses.update(loss.item(), images.size(0))\n            top1.update(acc1.item(), images.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))\n        if confmat:\n            print(confmat.format(args.class_names))\n\n    return top1.avg\n\n\ndef get_train_transform(resizing='default', scale=(0.08, 1.0), ratio=(3. / 4., 4. 
/ 3.), random_horizontal_flip=True,\n                        random_color_jitter=False, resize_size=224, norm_mean=(0.485, 0.456, 0.406),\n                        norm_std=(0.229, 0.224, 0.225), auto_augment=None):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 256 and take a random resized crop of size 224;\n        - cen.crop: resize the image to 256 and take the center crop of size 224;\n        - ran.crop: resize the image to 256 and take a random crop of size 224;\n        - res.: resize the image to resize_size;\n    \"\"\"\n    transformed_img_size = 224\n    if resizing == 'default':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.RandomResizedCrop(224, scale=scale, ratio=ratio)\n        ])\n    elif resizing == 'cen.crop':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224)\n        ])\n    elif resizing == 'ran.crop':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.RandomCrop(224)\n        ])\n    elif resizing == 'res.':\n        transform = ResizeImage(resize_size)\n        transformed_img_size = resize_size\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        transforms.append(T.RandomHorizontalFlip())\n    if auto_augment:\n        aa_params = dict(\n            translate_const=int(transformed_img_size * 0.45),\n            img_mean=tuple([min(255, round(255 * x)) for x in norm_mean]),\n            interpolation=Image.BILINEAR\n        )\n        if auto_augment.startswith('rand'):\n            transforms.append(rand_augment_transform(auto_augment, aa_params))\n        else:\n            transforms.append(auto_augment_transform(auto_augment, aa_params))\n    elif random_color_jitter:\n        transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=norm_mean, std=norm_std)\n    ])\n    return T.Compose(transforms)\n\n\ndef get_val_transform(resizing='default', resize_size=224,\n                      norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 256 and take the center crop of size 224;\n        - res.: resize the image to resize_size\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224),\n        ])\n    elif resizing == 'res.':\n        transform = ResizeImage(resize_size)\n    else:\n        raise NotImplementedError(resizing)\n    return T.Compose([\n        transform,\n        T.ToTensor(),\n        T.Normalize(mean=norm_mean, std=norm_std)\n    ])\n\n\ndef empirical_risk_minimization(train_source_iter, model, optimizer, lr_scheduler, epoch, args, device):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)[:2]\n        x_s = x_s.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        loss = cls_loss\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n"
  },
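Two behaviors of `utils.py` above are worth calling out: `get_model` falls back to timm (hence the `timm` entry in requirements.txt) and strips the classification head so that every backbone exposes `out_features`, and `get_train_transform` picks a resizing mode by string key ('default', 'cen.crop', 'ran.crop', 'res.'). A small usage sketch, assuming it is run from this examples directory so that `import utils` resolves:

```python
import utils

# timm backbone with the classifier removed; out_features feeds the bottleneck
backbone = utils.get_model("vit_base_patch16_224", pretrain=False)
print(backbone.out_features)

# 'res.' keeps only a plain resize, as in the digits commands (which also pass --no-hflip)
transform = utils.get_train_transform(resizing="res.", resize_size=28,
                                      random_horizontal_flip=False,
                                      norm_mean=(0.5,), norm_std=(0.5,))
print(transform)
```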
  {
    "path": "examples/domain_adaptation/image_regression/README.md",
    "content": "# Unsupervised Domain Adaptation for Image Regression Tasks\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to better reproduce the benchmark results.\n\n## Dataset\n\nThe following datasets can be downloaded automatically:\n\n- [DSprites](https://github.com/deepmind/dsprites-dataset)\n- [MPI3D](https://github.com/rr-learning/disentanglement_dataset)\n\n## Supported Methods\n\nSupported methods include:\n\n- [Disparity Discrepancy (DD)](https://arxiv.org/abs/1904.05801)\n- [Representation Subspace Distance (RSD)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Representation-Subspace-Distance-for-Domain-Adaptation-Regression-icml21.pdf)\n\n## Experiment and Results\n\nThe shell files provide the scripts to reproduce the benchmark results with the specified hyper-parameters.\nFor example, if you want to train DD on DSprites, use the following script:\n\n```shell script\n# Train DD on the DSprites C->N task using ResNet-18.\n# Assume you have put the datasets under the path `data/dSprites`,\n# or allow the datasets to be downloaded automatically from the Internet to this path.\nCUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/mdd/dSprites_C2N --wd 0.0005\n```\n\n**Notations**\n\n- ``Origin`` means the error reported by the original paper.\n- ``Avg`` is the average error reported by Transfer-Learn.\n- ``ERM`` refers to the model trained with data from the source domain.\n- ``Oracle`` refers to the model trained with data from the target domain.\n\nLabels are all normalized to [0, 1] to eliminate the effect of the diverse scales of the regression values.\n\nWe repeat each DD experiment three times and report the average error of the ``final`` epoch.\n\n\n### dSprites error on ResNet-18\n\n| Methods     | Avg   | C → N | C → S | N → C | N → S | S → C | S → N |\n|-------------|-------|-------|-------|-------|-------|-------|-------|\n| ERM | 0.157 | 0.232 | 0.271 | 0.081 | 0.22  | 0.038 | 0.092 |\n| DD          | 0.057 | 0.047 | 0.08  | 0.03  | 0.095 | 0.053 | 0.037 |\n\n### MPI3D error on ResNet-18\n\n| Methods     | Avg   | RL → RC | RL → T | RC → RL | RC → T | T → RL | T → RC |\n|-------------|-------|---------|--------|---------|--------|--------|--------|\n| ERM | 0.176 | 0.232   | 0.271  | 0.081   | 0.22   | 0.038  | 0.092  |\n| DD          | 0.03  | 0.086   | 0.029  | 0.057   | 0.189  | 0.131  | 0.087  |\n\n## Citation\nIf you use these methods in your research, please consider citing:\n\n```\n@inproceedings{MDD,\n    title={Bridging theory and algorithm for domain adaptation},\n    author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael},\n    booktitle={ICML},\n    year={2019},\n}\n\n@inproceedings{RSD,\n    title={Representation Subspace Distance for Domain Adaptation Regression},\n    author={Chen, Xinyang and Wang, Sinan and Wang, Jianmin and Long, Mingsheng},\n    booktitle={ICML},\n    year={2021}\n}\n```\n"
  },
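The README above reports errors on labels that were min-max normalized to [0, 1]; a tiny worked example of that convention (the helper and values here are hypothetical, not repository code):

```python
import torch
import torch.nn.functional as F

def normalize_factors(y, y_min, y_max):
    # map each regression factor into [0, 1]
    return (y - y_min) / (y_max - y_min)

# MAE in the normalized space is the quantity tabulated above
y_pred = torch.tensor([[0.40, 0.70]])
y_true = torch.tensor([[0.50, 0.60]])
print(F.l1_loss(y_pred, y_true).item())  # 0.10
```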
  {
    "path": "examples/domain_adaptation/image_regression/dann.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.regressor import Regressor\nfrom tllib.alignment.dann import DomainAdversarialLoss\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nimport tllib.vision.datasets.regression as datasets\nimport tllib.vision.models as models\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n\n    dataset = datasets.__dict__[args.data]\n    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,\n                                   transform=train_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,\n                                   transform=train_transform)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = models.__dict__[args.arch](pretrained=True)\n    if args.normalization == 'IN':\n        backbone = utils.convert_model(backbone)\n    num_factors = train_source_dataset.num_factors\n    bottleneck = nn.Sequential(\n        nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n        nn.Flatten(),\n        nn.Linear(backbone.out_features, 256),\n        
nn.ReLU()\n    )\n    regressor = Regressor(backbone=backbone, num_factors=num_factors, bottleneck=bottleneck, bottleneck_dim=256).to(\n        device)\n    print(regressor)\n    domain_discri = DomainDiscriminator(in_feature=regressor.features_dim, hidden_size=1024).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(regressor.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum,\n                    weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    dann = DomainAdversarialLoss(domain_discri).to(device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        regressor.load_state_dict(checkpoint)\n\n    # analysis the model\n    if args.phase == 'analysis':\n        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        # extract features from both domains\n        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n        print(mae)\n        return\n\n    # start training\n    best_mae = 100000.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        print(\"lr\", lr_scheduler.get_lr())\n        train(train_source_iter, train_target_iter, regressor, dann, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n\n        # remember best mae and save checkpoint\n        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))\n        if mae < best_mae:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_mae = min(mae, best_mae)\n        print(\"mean MAE {:6.3f} best MAE {:6.3f}\".format(mae, best_mae))\n\n    print(\"best_mae = {:6.3f}\".format(best_mae))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: Regressor, domain_adv: DomainAdversarialLoss, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    mse_losses = AverageMeter('MSE Loss', ':6.3f')\n    
dann_losses = AverageMeter('DANN Loss', ':6.3f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')\n    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, mse_losses, dann_losses, mae_losses_s, mae_losses_t, domain_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    domain_adv.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, labels_s = next(train_source_iter)\n        x_s = x_s.to(device)\n        labels_s = labels_s.to(device).float()\n        x_t, labels_t = next(train_target_iter)\n        x_t = x_t.to(device)\n        labels_t = labels_t.to(device).float()\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n        y_t, f_t = model(x_t)\n\n        mse_loss = F.mse_loss(y_s, labels_s)\n        mae_loss_s = F.l1_loss(y_s, labels_s)\n        mae_loss_t = F.l1_loss(y_t, labels_t)\n        transfer_loss = domain_adv(f_s, f_t)\n        loss = mse_loss + transfer_loss * args.trade_off\n        domain_acc = domain_adv.domain_discriminator_accuracy\n\n        mse_losses.update(mse_loss.item(), x_s.size(0))\n        dann_losses.update(transfer_loss.item(), x_s.size(0))\n        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))\n        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))\n        domain_accs.update(domain_acc.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='DANN for Regression Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',\n                        help='dataset: ' + ' | '.join(dataset_names) +\n                             ' (default: DSprites)')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-size', type=int, default=128)\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet18)')\n    parser.add_argument('--normalization', default='BN', type=str, choices=[\"BN\", \"IN\"])\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.001, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='dann',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
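`dann.py` above delegates its transfer term to tllib's `DomainAdversarialLoss`: the `DomainDiscriminator` is trained to separate source features from target features while a gradient reversal layer makes the feature extractor fool it. A simplified sketch of that objective (the reversal layer and the accuracy bookkeeping are omitted; this is an assumption-level illustration, not the library's code):

```python
import torch
import torch.nn.functional as F

def domain_adversarial_loss(discriminator, f_s, f_t):
    # train the discriminator toward 1 on source features and 0 on target
    # features; with gradient reversal (omitted) the feature extractor is
    # pushed to maximize this loss, which aligns the two domains
    d_s = discriminator(f_s)
    d_t = discriminator(f_t)
    return 0.5 * (F.binary_cross_entropy(d_s, torch.ones_like(d_s)) +
                  F.binary_cross_entropy(d_t, torch.zeros_like(d_t)))
```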
  {
    "path": "examples/domain_adaptation/image_regression/dann.sh",
    "content": "# DSprites\nCUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2N\nCUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2S\nCUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2C\nCUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2S\nCUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2C\nCUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2N\n\n# MPI3D\nCUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2RC --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2T --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2RL --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2T --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RL --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RC --resize-size 224\n"
  },
  {
    "path": "examples/domain_adaptation/image_regression/dd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\nimport torch.nn.functional as F\nimport torch.nn as nn\n\nimport utils\nfrom tllib.alignment.mdd import RegressionMarginDisparityDiscrepancy as MarginDisparityDiscrepancy, ImageRegressor\nimport tllib.vision.datasets.regression as datasets\nimport tllib.vision.models as models\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n\n    dataset = datasets.__dict__[args.data]\n    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,\n                                   transform=train_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,\n                                   transform=train_transform)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    num_factors = train_source_dataset.num_factors\n    backbone = models.__dict__[args.arch](pretrained=True)\n    bottleneck_dim = args.bottleneck_dim\n    if args.normalization == 'IN':\n        backbone = utils.convert_model(backbone)\n        bottleneck = nn.Sequential(\n            nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1),\n            nn.InstanceNorm2d(bottleneck_dim),\n  
nn.ReLU(),\n        )\n\n        def make_head():\n            # both the main head and the adversarial head use this architecture:\n            # two conv blocks, global average pooling, then a sigmoid-activated linear layer\n            head = nn.Sequential(\n                nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),\n                nn.InstanceNorm2d(bottleneck_dim),\n                nn.ReLU(),\n                nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),\n                nn.InstanceNorm2d(bottleneck_dim),\n                nn.ReLU(),\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten(),\n                nn.Linear(bottleneck_dim, num_factors),\n                nn.Sigmoid()\n            )\n            for layer in head:\n                if isinstance(layer, (nn.Conv2d, nn.Linear)):\n                    nn.init.normal_(layer.weight, 0, 0.01)\n                    nn.init.constant_(layer.bias, 0)\n            return head\n\n        head = make_head()\n        adv_head = make_head()\n        regressor = ImageRegressor(backbone, num_factors, bottleneck=bottleneck, head=head, adv_head=adv_head,\n                                   bottleneck_dim=bottleneck_dim, width=bottleneck_dim)\n    else:\n        regressor = ImageRegressor(backbone, num_factors,\n                                   bottleneck_dim=bottleneck_dim, width=bottleneck_dim)\n\n    regressor = regressor.to(device)\n    print(regressor)\n    mdd = MarginDisparityDiscrepancy(args.margin).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        regressor.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        # extract features from both domains\n        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck, regressor.head[:-2]).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n        print(mae)\n        return\n\n    # start training\n    best_mae = 100000.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        print(\"lr\", lr_scheduler.get_lr())\n        train(train_source_iter, train_target_iter, regressor, mdd, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n\n        # remember best mae and save checkpoint\n        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))\n        if mae < best_mae:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_mae = min(mae, best_mae)\n        print(\"mean MAE {:6.3f} best MAE {:6.3f}\".format(mae, best_mae))\n\n    print(\"best_mae = {:6.3f}\".format(best_mae))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model, mdd: MarginDisparityDiscrepancy, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    source_losses = AverageMeter('Source Loss', ':6.3f')\n    trans_losses = AverageMeter('Trans Loss', ':6.3f')\n    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')\n    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, source_losses, trans_losses, mae_losses_s, mae_losses_t],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    mdd.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, labels_s = next(train_source_iter)\n        x_s = x_s.to(device)\n        labels_s = labels_s.to(device).float()\n        x_t, labels_t = next(train_target_iter)\n        x_t = x_t.to(device)\n        labels_t = labels_t.to(device).float()\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat([x_s, x_t], dim=0)\n        outputs, outputs_adv = model(x)\n        y_s, y_t = outputs.chunk(2, dim=0)\n        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)\n\n        # compute mean squared error loss on the source domain\n        mse_loss = F.mse_loss(y_s, labels_s)\n\n        # compute margin disparity discrepancy between domains\n        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)\n        # for the adversarial head, minimizing the negative MDD is equivalent to maximizing MDD\n        loss = mse_loss - transfer_loss * args.trade_off\n        # increase the coefficient of the gradient reverse layer by one step\n        model.step()\n\n        mae_loss_s = F.l1_loss(y_s, labels_s)\n        mae_loss_t = F.l1_loss(y_t, labels_t)\n\n        source_losses.update(mse_loss.item(), x_s.size(0))\n        trans_losses.update(transfer_loss.item(), x_s.size(0))\n        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))\n        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='DD for Regression Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',\n                        help='dataset: ' + ' | '.join(dataset_names) +\n                             ' (default: DSprites)')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-size', type=int, default=128)\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet18)')\n    parser.add_argument('--bottleneck-dim', default=512, type=int, help='bottleneck dimension (default: 512)')\n    parser.add_argument('--normalization', default='BN', type=str, choices=['BN', 'IN'],\n                        help='normalization layers in the backbone: BatchNorm (BN) or InstanceNorm (IN)')\n    parser.add_argument('--margin', type=float, default=1., help=\"margin gamma\")\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='dd',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_regression/dd.sh",
    "content": "# DSprites\nCUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2N --wd 0.0005\nCUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2S --wd 0.0005\nCUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2C --wd 0.0005\nCUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2S --wd 0.0005\nCUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2C --wd 0.0005\nCUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2N --wd 0.0005\n\n# MPI3D\nCUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2RC --normalization IN --resize-size 224 --weight-decay 0.001\nCUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2T --normalization IN --resize-size 224 --weight-decay 0.001\nCUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2RL --normalization IN --resize-size 224 --weight-decay 0.001\nCUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2T --normalization IN --resize-size 224 --weight-decay 0.001\nCUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_T2RL --normalization IN --resize-size 224 --weight-decay 0.001\nCUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_T2RC --normalization IN --resize-size 224 --weight-decay 0.001\n"
  },
  {
    "path": "examples/domain_adaptation/image_regression/erm.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.regressor import Regressor\nimport tllib.vision.datasets.regression as datasets\nimport tllib.vision.models as models\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n\n    dataset = datasets.__dict__[args.data]\n    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,\n                                   transform=train_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,\n                                   transform=train_transform)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = models.__dict__[args.arch](pretrained=True)\n    if args.normalization == 'IN':\n        backbone = utils.convert_model(backbone)\n    num_factors = train_source_dataset.num_factors\n    regressor = Regressor(backbone=backbone, num_factors=num_factors).to(device)\n    print(regressor)\n    # define optimizer and lr scheduler\n    optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = 
LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        regressor.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        # extract features from both domains\n        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n        print(mae)\n        return\n\n    # start training\n    best_mae = 100000.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        print(\"lr\", lr_scheduler.get_lr())\n        train(train_source_iter, train_target_iter, regressor, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n\n        # remember best mae and save checkpoint\n        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))\n        if mae < best_mae:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_mae = min(mae, best_mae)\n        print(\"mean MAE {:6.3f} best MAE {:6.3f}\".format(mae, best_mae))\n\n    print(\"best_mae = {:6.3f}\".format(best_mae))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: Regressor, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    mse_losses = AverageMeter('MSE Loss', ':6.3f')\n    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')\n    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, mse_losses, mae_losses_s, mae_losses_t],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, labels_s = next(train_source_iter)\n        x_s = x_s.to(device)\n        labels_s = labels_s.to(device).float()\n        x_t, labels_t = next(train_target_iter)\n        x_t = x_t.to(device)\n        labels_t = labels_t.to(device).float()\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, _ = model(x_s)\n        y_t, _ = model(x_t)\n\n        mse_loss = F.mse_loss(y_s, labels_s)\n        mae_loss_s = F.l1_loss(y_s, labels_s)\n        mae_loss_t = F.l1_loss(y_t, labels_t)\n        # ERM baseline: optimize the source-domain MSE only; target MAE is tracked for reference\n        loss = mse_loss\n\n        mse_losses.update(mse_loss.item(), x_s.size(0))\n        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))\n        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='Source Only for Regression Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',\n                        help='dataset: ' + ' | '.join(dataset_names) +\n                             ' (default: DSprites)')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-size', type=int, default=128)\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet18)')\n    parser.add_argument('--normalization', default='BN', type=str, choices=[\"IN\", \"BN\"],\n                        help='normalization layers in the backbone: BatchNorm (BN) or InstanceNorm (IN)')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='src_only',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_regression/erm.sh",
    "content": "# DSprites\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2N\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2S\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2C\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2S\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2C\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2N\n\n# MPI3D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RL2RC\nCUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RL2T\nCUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2RL\nCUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2T\nCUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RL\nCUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RC\n"
  },
  {
    "path": "examples/domain_adaptation/image_regression/rsd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.regressor import Regressor\nfrom tllib.alignment.rsd import RepresentationSubspaceDistance\nimport tllib.vision.datasets.regression as datasets\nimport tllib.vision.models as models\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        T.Resize(args.resize_size),\n        T.ToTensor(),\n        normalize\n    ])\n\n    dataset = datasets.__dict__[args.data]\n    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,\n                                   transform=train_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,\n                                   transform=train_transform)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = models.__dict__[args.arch](pretrained=True)\n    if args.normalization == 'IN':\n        backbone = utils.convert_model(backbone)\n    num_factors = train_source_dataset.num_factors\n    bottleneck = nn.Sequential(\n        nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n        nn.Flatten(),\n        nn.Linear(backbone.out_features, 256),\n        nn.ReLU()\n    )\n    regressor = Regressor(backbone=backbone, 
num_factors=num_factors, bottleneck=bottleneck,\n                          bottleneck_dim=256).to(device)\n    print(regressor)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    rsd = RepresentationSubspaceDistance(args.trade_off_bmp)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        regressor.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                         shuffle=True, num_workers=args.workers, drop_last=True)\n        # extract features from both domains\n        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n        print(mae)\n        return\n\n    # start training\n    best_mae = 100000.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        print(\"lr\", lr_scheduler.get_lr())\n        train(train_source_iter, train_target_iter, regressor, rsd, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)\n\n        # remember best mae and save checkpoint\n        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))\n        if mae < best_mae:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_mae = min(mae, best_mae)\n        print(\"mean MAE {:6.3f} best MAE {:6.3f}\".format(mae, best_mae))\n\n    print(\"best_mae = {:6.3f}\".format(best_mae))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: Regressor, rsd, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    mse_losses = AverageMeter('MSE Loss', ':6.3f')\n    rsd_losses = AverageMeter('RSD Loss', ':6.3f')\n    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')\n    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, mse_losses, rsd_losses, mae_losses_s, mae_losses_t],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, labels_s = next(train_source_iter)\n        x_s = x_s.to(device)\n        labels_s = labels_s.to(device).float()\n        x_t, labels_t = next(train_target_iter)\n        x_t = x_t.to(device)\n        labels_t = labels_t.to(device).float()\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n        y_t, f_t = model(x_t)\n\n        mse_loss = F.mse_loss(y_s, labels_s)\n        mae_loss_s = F.l1_loss(y_s, labels_s)\n        mae_loss_t = F.l1_loss(y_t, labels_t)\n        # representation subspace distance between source and target features\n        rsd_loss = rsd(f_s, f_t)\n        loss = mse_loss + rsd_loss * args.trade_off\n\n        mse_losses.update(mse_loss.item(), x_s.size(0))\n        rsd_losses.update(rsd_loss.item(), x_s.size(0))\n        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))\n        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='RSD for Regression Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',\n                        help='dataset: ' + ' | '.join(dataset_names) +\n                             ' (default: DSprites)')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-size', type=int, default=128)\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet18)')\n    parser.add_argument('--normalization', default='BN', type=str, choices=[\"BN\", \"IN\"],\n                        help='normalization layers in the backbone: BatchNorm (BN) or InstanceNorm (IN)')\n    parser.add_argument('--trade-off', default=0.001, type=float,\n                        help='the trade-off hyper-parameter for the RSD loss')\n    parser.add_argument('--trade-off-bmp', default=0.1, type=float,\n                        help='the trade-off hyper-parameter for the bases mismatch penalization term of RSD')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.001, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='rsd',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/image_regression/rsd.sh",
    "content": "# DSprites\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2N\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2S\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2C\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2S\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2C\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2N\n\n# MPI3D\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RL2RC --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RL2T --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2RL --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2T --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RL --resize-size 224\nCUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RC --resize-size 224\n"
  },
  {
    "path": "examples/domain_adaptation/image_regression/utils.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport sys\nimport time\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d\nfrom torch.nn.modules.instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d\n\nsys.path.append('../../..')\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\ndef convert_model(module):\n    \"\"\"convert BatchNorms in the `module` into InstanceNorms\"\"\"\n    source_modules = (BatchNorm1d, BatchNorm2d, BatchNorm3d)\n    target_modules = (InstanceNorm1d, InstanceNorm2d, InstanceNorm3d)\n    for src_module, tgt_module in zip(source_modules, target_modules):\n        if isinstance(module, src_module):\n            mod = tgt_module(module.num_features, module.eps, module.momentum, module.affine)\n            module = mod\n\n    for name, child in module.named_children():\n        module.add_module(name, convert_model(child))\n\n    return module\n\n\ndef validate(val_loader, model, args, factors, device):\n    batch_time = AverageMeter('Time', ':6.3f')\n    mae_losses = [AverageMeter('mae {}'.format(factor), ':6.3f') for factor in factors]\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time] + mae_losses,\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            for j in range(len(factors)):\n                mae_loss = F.l1_loss(output[:, j], target[:, j])\n                mae_losses[j].update(mae_loss.item(), images.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        for i, factor in enumerate(factors):\n            print(\"{} MAE {mae.avg:6.3f}\".format(factor, mae=mae_losses[i]))\n        mean_mae = sum(l.avg for l in mae_losses) / len(factors)\n    return mean_mae\n"
  },
  {
    "path": "examples/domain_adaptation/keypoint_detection/README.md",
    "content": "# Unsupervised Domain Adaptation for Keypoint Detection\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to better reproduce the benchmark results.\n\n## Dataset\nFollowing datasets can be downloaded automatically:\n\n- [Rendered Handpose Dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html)\n- [Hand-3d-Studio Dataset](https://www.yangangwang.com/papers/ZHAO-H3S-2020-02.html)\n- [FreiHAND Dataset](https://lmb.informatik.uni-freiburg.de/projects/freihand/)\n- [Surreal Dataset](https://www.di.ens.fr/willow/research/surreal/data/)\n- [LSP Dataset](http://sam.johnson.io/research/lsp.html)\n\nYou need to prepare following datasets manually if you want to use them:\n- [Human3.6M Dataset](http://vision.imar.ro/human3.6m/description.php)\n\nand prepare them following [Documentations for Human3.6M Dataset](/common/vision/datasets/keypoint_detection/human36m.py).\n\n## Supported Methods\n\nSupported methods include:\n\n- [Regressive Domain Adaptation for Unsupervised Keypoint Detection (RegDA, CVPR 2021)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/regressive-domain-adaptation-cvpr21.pdf)\n\n## Experiment and Results\n\nThe shell files give the script to reproduce the results with specified hyper-parameters.\nFor example, if you want to train RegDA on RHD->H3D, use the following script\n\n```shell script\n# Train a RegDA on RHD -> H3D task using PoseResNet.\n# Assume you have put the datasets under the path `data/RHD` and  `data/H3D_crop`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python regda.py data/RHD data/H3D_crop \\\n    -s RenderedHandPose -t Hand3DStudio --finetune --seed 0 --debug --log logs/regda/rhd2h3d\n```\n\n### RHD->H3D accuracy on ResNet-101\n\n| Methods     | MCP  | PIP  | DIP  | Fingertip | Avg  |\n|-------------|------|------|------|-----------|------|\n| ERM | 67.4 | 64.2 | 63.3 | 54.8      | 61.8 |\n| RegDA       | 79.6 | 74.4 | 71.2 | 62.9      | 72.5 |\n| Oracle      | 97.7 | 97.2 | 95.7 | 92.5      | 95.8 |\n\n### Surreal->Human3.6M accuracy on ResNet-101\n\n| Methods     | Shoulder | Elbow | Wrist | Hip  | Knee | Ankle | Avg  |\n|-------------|----------|-------|-------|------|------|-------|------|\n| ERM | 69.4     | 75.4  | 66.4  | 37.9 | 77.3 | 77.7  | 67.3 |\n| RegDA       | 73.3     | 86.4  | 72.8  | 54.8 | 82.0 | 84.4  | 75.6 |\n| Oracle      | 95.3     | 91.8  | 86.9  | 95.6 | 94.1 | 93.6  | 92.9 |\n\n### Surreal->LSP accuracy on ResNet-101\n\n| Methods     | Shoulder | Elbow | Wrist | Hip  | Knee | Ankle | Avg  |\n|-------------|----------|-------|-------|------|------|-------|------|\n| ERM | 51.5     | 65.0  | 62.9  | 68.0 | 68.7 | 67.4  | 63.9 |\n| RegDA       | 62.7     | 76.7  | 71.1  | 81.0 | 80.3 | 75.3  | 74.6 |\n| Oracle      | 95.3     | 91.8  | 86.9  | 95.6 | 94.1 | 93.6  | 92.9 |\n\n## Visualization\nIf you want to visualize the keypoint detection results during training, you should set --debug.\n\n```\nCUDA_VISIBLE_DEVICES=0 python erm.py data/RHD data/H3D_crop -s RenderedHandPose -t Hand3DStudio --log logs/erm/rhd2h3d --debug --seed 0\n```\n\nThen you can find visualization images in directory ``logs/erm/rhd2h3d/visualize/``.\n\n<img src=\"./fig/keypoint_detection.jpg\" width=\"300\"/>\n\n\n## TODO\nSupport methods:  CycleGAN\n\n\n## Citation\nIf you use these methods in your research, please consider citing.\n\n```\n@InProceedings{RegDA,\n    author    = {Junguang Jiang and\n                Yifei 
Ji and\n                Ximei Wang and\n                Yufeng Liu and\n                Jianmin Wang and\n                Mingsheng Long},\n    title     = {Regressive Domain Adaptation for Unsupervised Keypoint Detection},\n    booktitle = {CVPR},\n    year = {2021}\n}\n\n```\n"
  },
  {
    "path": "examples/domain_adaptation/keypoint_detection/erm.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom torch.utils.data import DataLoader\nfrom torchvision.transforms import Compose, ToPILImage\n\nsys.path.append('../../..')\nimport tllib.vision.models.keypoint_detection as models\nfrom tllib.vision.models.keypoint_detection.loss import JointsMSELoss\nimport tllib.vision.datasets.keypoint_detection as datasets\nimport tllib.vision.transforms.keypoint_detection as T\nfrom tllib.vision.transforms import Denormalize\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict\nfrom tllib.utils.metric.keypoint_detection import accuracy\nfrom tllib.utils.logger import CompleteLogger\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        T.RandomRotation(args.rotation),\n        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),\n        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),\n        T.GaussianBlur(),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        T.Resize(args.image_size),\n        T.ToTensor(),\n        normalize\n    ])\n    image_size = (args.image_size, args.image_size)\n    heatmap_size = (args.heatmap_size, args.heatmap_size)\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform,\n                                          image_size=image_size, heatmap_size=heatmap_size)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform,\n                                        image_size=image_size, heatmap_size=heatmap_size)\n    val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)\n\n    target_dataset = datasets.__dict__[args.target]\n    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform,\n                                          image_size=image_size, heatmap_size=heatmap_size)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_target_dataset = target_dataset(root=args.target_root, 
split='test', transforms=val_transform,\n                                        image_size=image_size, heatmap_size=heatmap_size)\n    val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)\n\n    print(\"Source train:\", len(train_source_loader))\n    print(\"Target train:\", len(train_target_loader))\n    print(\"Source test:\", len(val_source_loader))\n    print(\"Target test:\", len(val_target_loader))\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    model = models.__dict__[args.arch](num_keypoints=train_source_dataset.num_keypoints).to(device)\n    criterion = JointsMSELoss()\n\n    # define optimizer and lr scheduler\n    optimizer = Adam(model.get_parameters(lr=args.lr))\n    lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)\n\n    # optionally resume from a checkpoint\n    start_epoch = 0\n    if args.resume:\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n        start_epoch = checkpoint['epoch'] + 1\n\n    # define visualization function\n    tensor_to_image = Compose([\n        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n        ToPILImage()\n    ])\n\n    def visualize(image, keypoint2d, name):\n        \"\"\"\n        Args:\n            image (tensor): image in shape 3 x H x W\n            keypoint2d (tensor): keypoints in shape K x 2\n            name: name of the saving image\n        \"\"\"\n        train_source_dataset.visualize(tensor_to_image(image),\n                                       keypoint2d, logger.get_image_path(\"{}.jpg\".format(name)))\n\n    if args.phase == 'test':\n        # evaluate on validation set\n        source_val_acc = validate(val_source_loader, model, criterion, None, args)\n        target_val_acc = validate(val_target_loader, model, criterion, visualize, args)\n        print(\"Source: {:4.3f} Target: {:4.3f}\".format(source_val_acc['all'], target_val_acc['all']))\n        for name, acc in target_val_acc.items():\n            print(\"{}: {:4.3f}\".format(name, acc))\n        return\n\n    # start training\n    best_acc = 0\n    for epoch in range(start_epoch, args.epochs):\n        logger.set_epoch(epoch)\n        lr_scheduler.step()\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, model, criterion, optimizer, epoch,\n              visualize if args.debug else None, args)\n\n        # evaluate on validation set\n        source_val_acc = validate(val_source_loader, model, criterion, None, args)\n        target_val_acc = validate(val_target_loader, model, criterion, visualize if args.debug else None, args)\n\n        # remember best acc and save checkpoint\n        torch.save(\n            {\n                'model': model.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'lr_scheduler': lr_scheduler.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n        if target_val_acc['all'] > best_acc:\n            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))\n            best_acc = target_val_acc['all']\n        print(\"Source: {:4.3f} Target: {:4.3f} Target(best): 
{:4.3f}\".format(source_val_acc['all'], target_val_acc['all'], best_acc))\n        for name, acc in target_val_acc.items():\n            print(\"{}: {:4.3f}\".format(name, acc))\n\n    logger.close()\n\n\ndef train(train_source_iter, train_target_iter, model, criterion,\n          optimizer, epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', \":.2e\")\n    acc_s = AverageMeter(\"Acc (s)\", \":3.2f\")\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s, acc_s],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, label_s, weight_s, meta_s = next(train_source_iter)\n\n        x_s = x_s.to(device)\n        label_s = label_s.to(device)\n        weight_s = weight_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s = model(x_s)\n        loss_s = criterion(y_s, label_s, weight_s)\n\n        # compute gradient and do SGD step\n        loss_s.backward()\n        optimizer.step()\n\n        # measure accuracy and record loss\n        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),\n                                               label_s.detach().cpu().numpy())\n        acc_s.update(avg_acc_s, cnt_s)\n        losses_s.update(loss_s, cnt_s)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n            if visualize is not None:\n                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, \"source_{}_pred.jpg\".format(i))\n                visualize(x_s[0], meta_s['keypoint2d'][0], \"source_{}_label.jpg\".format(i))\n\n\ndef validate(val_loader, model, criterion, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.2e')\n    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), \":3.2f\")\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, acc['all']],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (x, label, weight, meta) in enumerate(val_loader):\n            x = x.to(device)\n            label = label.to(device)\n            weight = weight.to(device)\n\n            # compute output\n            y = model(x)\n            loss = criterion(y, label, weight)\n\n            # measure accuracy and record loss\n            losses.update(loss.item(), x.size(0))\n            acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(),\n                                                          label.cpu().numpy())\n\n            group_acc = val_loader.dataset.group_accuracy(acc_per_points)\n            acc.update(group_acc, x.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n                if visualize is not None:\n                    visualize(x[0], pred[0] * args.image_size / args.heatmap_size, \"val_{}_pred.jpg\".format(i))\n             
       visualize(x[0], meta['keypoint2d'][0], \"val_{}_label.jpg\".format(i))\n\n    return acc.average()\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='Source Only for Keypoint Detection Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.6, 1.3),\n                        help='scale range for the RandomResizeCrop augmentation')\n    parser.add_argument('--rotation', type=int, default=180,\n                        help='rotation range of the RandomRotation augmentation')\n    parser.add_argument('--image-size', type=int, default=256,\n                        help='input image size')\n    parser.add_argument('--heatmap-size', type=int, default=64,\n                        help='output heatmap size')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='pose_resnet101',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: pose_resnet101)')\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"where restore model parameters from.\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-step', default=[45, 60], type=tuple, help='parameter for lr scheduler')\n    parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=70, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--log\", type=str, default='src_only',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--debug', action=\"store_true\",\n                        help='In the debug mode, save images and predictions')\n    args = parser.parse_args()\n    main(args)\n\n"
  },
  {
    "path": "examples/domain_adaptation/keypoint_detection/erm.sh",
    "content": "# Source Only\n# Hands Dataset\nCUDA_VISIBLE_DEVICES=0 python erm.py data/RHD data/H3D_crop \\\n    -s RenderedHandPose -t Hand3DStudio --log logs/erm/rhd2h3d --debug --seed 0\nCUDA_VISIBLE_DEVICES=0 python erm.py data/FreiHand data/RHD \\\n    -s FreiHand -t RenderedHandPose --log logs/erm/freihand2rhd --debug --seed 0\n\n# Body Dataset\nCUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/Human36M \\\n    -s SURREAL -t Human36M --log logs/erm/surreal2human36m --debug --seed 0 --rotation 30\nCUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/lsp \\\n    -s SURREAL -t LSP --log logs/erm/surreal2lsp --debug --seed 0 --rotation 30\n\n# Oracle Results\nCUDA_VISIBLE_DEVICES=0 python erm.py data/H3D_crop data/H3D_crop \\\n    -s Hand3DStudio -t Hand3DStudio --log logs/oracle/h3d --debug --seed 0\nCUDA_VISIBLE_DEVICES=0 python erm.py data/Human36M data/Human36M \\\n    -s Human36M -t Human36M --log logs/oracle/human36m --debug --seed 0 --rotation 30\n"
  },
  {
    "path": "examples/domain_adaptation/keypoint_detection/regda.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR, MultiStepLR\nfrom torch.utils.data import DataLoader\nfrom torchvision.transforms import Compose, ToPILImage\n\nsys.path.append('../../..')\nfrom tllib.alignment.regda import PoseResNet2d as RegDAPoseResNet, \\\n    PseudoLabelGenerator2d, RegressionDisparity\nimport tllib.vision.models as models\nfrom tllib.vision.models.keypoint_detection.pose_resnet import Upsampling, PoseResNet\nfrom tllib.vision.models.keypoint_detection.loss import JointsKLLoss\nimport tllib.vision.datasets.keypoint_detection as datasets\nimport tllib.vision.transforms.keypoint_detection as T\nfrom tllib.vision.transforms import Denormalize\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict\nfrom tllib.utils.metric.keypoint_detection import accuracy\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        T.RandomRotation(args.rotation),\n        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),\n        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),\n        T.GaussianBlur(),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        T.Resize(args.image_size),\n        T.ToTensor(),\n        normalize\n    ])\n    image_size = (args.image_size, args.image_size)\n    heatmap_size = (args.heatmap_size, args.heatmap_size)\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform,\n                                          image_size=image_size, heatmap_size=heatmap_size)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform,\n                                        image_size=image_size, heatmap_size=heatmap_size)\n    val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)\n\n    target_dataset = datasets.__dict__[args.target]\n    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform,\n                                          image_size=image_size, heatmap_size=heatmap_size)\n    train_target_loader = DataLoader(train_target_dataset, 
batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform,\n                                        image_size=image_size, heatmap_size=heatmap_size)\n    val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)\n\n    print(\"Source train:\", len(train_source_loader))\n    print(\"Target train:\", len(train_target_loader))\n    print(\"Source test:\", len(val_source_loader))\n    print(\"Target test:\", len(val_target_loader))\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    backbone = models.__dict__[args.arch](pretrained=True)\n    upsampling = Upsampling(backbone.out_features)\n    num_keypoints = train_source_dataset.num_keypoints\n    model = RegDAPoseResNet(backbone, upsampling, 256, num_keypoints, num_head_layers=args.num_head_layers, finetune=True).to(device)\n    # define loss function\n    criterion = JointsKLLoss()\n    pseudo_label_generator = PseudoLabelGenerator2d(num_keypoints, args.heatmap_size, args.heatmap_size)\n    regression_disparity = RegressionDisparity(pseudo_label_generator, JointsKLLoss(epsilon=1e-7))\n\n    # define optimizer and lr scheduler\n    optimizer_f = SGD([\n        {'params': backbone.parameters(), 'lr': 0.1},\n        {'params': upsampling.parameters(), 'lr': 0.1},\n    ], lr=0.1, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    optimizer_h = SGD(model.head.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    optimizer_h_adv = SGD(model.head_adv.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_decay_function = lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)\n    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)\n    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)\n    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)\n    start_epoch = 0\n\n    if args.resume is None:\n        if args.pretrain is None:\n            # first pretrain the backbone and upsampling\n            print(\"Pretraining the model on source domain.\")\n            args.pretrain = logger.get_checkpoint_path('pretrain')\n            pretrained_model = PoseResNet(backbone, upsampling, 256, num_keypoints, True).to(device)\n            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)\n            best_acc = 0\n            for epoch in range(args.pretrain_epochs):\n                lr_scheduler.step()\n                print(lr_scheduler.get_lr())\n\n                pretrain(train_source_iter, pretrained_model, criterion, optimizer, epoch, args)\n                source_val_acc = validate(val_source_loader, pretrained_model, criterion, None, args)\n\n                # remember best acc and save checkpoint\n                if source_val_acc['all'] > best_acc:\n                    best_acc = source_val_acc['all']\n                    torch.save(\n                        {\n                            'model': pretrained_model.state_dict()\n                        }, args.pretrain\n                    )\n                print(\"Source: {} best: {}\".format(source_val_acc['all'], best_acc))\n\n        # load from the pretrained checkpoint\n        pretrained_dict = torch.load(args.pretrain, map_location='cpu')['model']\n        model_dict = model.state_dict()\n        # remove keys from pretrained dict that doesn't appear in model dict\n        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n        model.load_state_dict(pretrained_dict, strict=False)\n    else:\n        # optionally resume from a checkpoint\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n        optimizer_f.load_state_dict(checkpoint['optimizer_f'])\n        optimizer_h.load_state_dict(checkpoint['optimizer_h'])\n        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])\n        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])\n        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])\n        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])\n        start_epoch = checkpoint['epoch'] + 1\n\n    # define visualization function\n    tensor_to_image = Compose([\n        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n        ToPILImage()\n    ])\n\n    def visualize(image, keypoint2d, name, heatmaps=None):\n        \"\"\"\n        Args:\n            image (tensor): image in shape 3 x H x W\n            keypoint2d (tensor): keypoints in shape K x 2\n            name: name of the saving image\n        \"\"\"\n        train_source_dataset.visualize(tensor_to_image(image),\n                                       keypoint2d, logger.get_image_path(\"{}.jpg\".format(name)))\n\n    if args.phase == 'test':\n        # evaluate on validation set\n        source_val_acc = validate(val_source_loader, model, criterion, None, args)\n        target_val_acc = validate(val_target_loader, model, criterion, visualize, args)\n        
print(\"Source: {:4.3f} Target: {:4.3f}\".format(source_val_acc['all'], target_val_acc['all']))\n        for name, acc in target_val_acc.items():\n            print(\"{}: {:4.3f}\".format(name, acc))\n        return\n\n    # start training\n    best_acc = 0\n    print(\"Start regression domain adaptation.\")\n    for epoch in range(start_epoch, args.epochs):\n        logger.set_epoch(epoch)\n        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(), lr_scheduler_h_adv.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, model, criterion, regression_disparity,\n              optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,\n              epoch, visualize if args.debug else None, args)\n\n        # evaluate on validation set\n        source_val_acc = validate(val_source_loader, model, criterion, None, args)\n        target_val_acc = validate(val_target_loader, model, criterion, visualize if args.debug else None, args)\n\n        # remember best acc and save checkpoint\n        torch.save(\n            {\n                'model': model.state_dict(),\n                'optimizer_f': optimizer_f.state_dict(),\n                'optimizer_h': optimizer_h.state_dict(),\n                'optimizer_h_adv': optimizer_h_adv.state_dict(),\n                'lr_scheduler_f': lr_scheduler_f.state_dict(),\n                'lr_scheduler_h': lr_scheduler_h.state_dict(),\n                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n        if target_val_acc['all'] > best_acc:\n            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))\n            best_acc = target_val_acc['all']\n        print(\"Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}\".format(source_val_acc['all'], target_val_acc['all'], best_acc))\n        for name, acc in target_val_acc.items():\n            print(\"{}: {:4.3f}\".format(name, acc))\n\n    logger.close()\n\n\ndef pretrain(train_source_iter, model, criterion, optimizer,\n             epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', \":.2e\")\n    acc_s = AverageMeter(\"Acc (s)\", \":3.2f\")\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s, acc_s],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, label_s, weight_s, meta_s = next(train_source_iter)\n\n        x_s = x_s.to(device)\n        label_s = label_s.to(device)\n        weight_s = weight_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s = model(x_s)\n        loss_s = criterion(y_s, label_s, weight_s)\n\n        # compute gradient and do SGD step\n        loss_s.backward()\n        optimizer.step()\n\n        # measure accuracy and record loss\n        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),\n                                               label_s.detach().cpu().numpy())\n        acc_s.update(avg_acc_s, cnt_s)\n        losses_s.update(loss_s, cnt_s)\n\n        # measure elapsed time\n        
batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\ndef train(train_source_iter, train_target_iter, model, criterion,regression_disparity,\n          optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,\n          epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', \":.2e\")\n    losses_gf = AverageMeter('Loss (t, false)', \":.2e\")\n    losses_gt = AverageMeter('Loss (t, truth)', \":.2e\")\n    acc_s = AverageMeter(\"Acc (s)\", \":3.2f\")\n    acc_t = AverageMeter(\"Acc (t)\", \":3.2f\")\n    acc_s_adv = AverageMeter(\"Acc (s, adv)\", \":3.2f\")\n    acc_t_adv = AverageMeter(\"Acc (t, adv)\", \":3.2f\")\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t, acc_s_adv, acc_t_adv],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, label_s, weight_s, meta_s = next(train_source_iter)\n        x_t, label_t, weight_t, meta_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        label_s = label_s.to(device)\n        weight_s = weight_s.to(device)\n\n        x_t = x_t.to(device)\n        label_t = label_t.to(device)\n        weight_t = weight_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # Step A train all networks to minimize loss on source domain\n        optimizer_f.zero_grad()\n        optimizer_h.zero_grad()\n        optimizer_h_adv.zero_grad()\n\n        y_s, y_s_adv = model(x_s)\n        loss_s = criterion(y_s, label_s, weight_s) + \\\n                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')\n        loss_s.backward()\n        optimizer_f.step()\n        optimizer_h.step()\n        optimizer_h_adv.step()\n\n        # Step B train adv regressor to maximize regression disparity\n        optimizer_h_adv.zero_grad()\n        y_t, y_t_adv = model(x_t)\n        loss_ground_false = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max')\n        loss_ground_false.backward()\n        optimizer_h_adv.step()\n\n        # Step C train feature extractor to minimize regression disparity\n        optimizer_f.zero_grad()\n        y_t, y_t_adv = model(x_t)\n        loss_ground_truth = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='min')\n        loss_ground_truth.backward()\n        optimizer_f.step()\n\n        # do update step\n        model.step()\n        lr_scheduler_f.step()\n        lr_scheduler_h.step()\n        lr_scheduler_h_adv.step()\n\n        # measure accuracy and record loss\n        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),\n                                               label_s.detach().cpu().numpy())\n        acc_s.update(avg_acc_s, cnt_s)\n        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),\n                                               label_t.detach().cpu().numpy())\n        acc_t.update(avg_acc_t, cnt_t)\n        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(y_s_adv.detach().cpu().numpy(),\n                                               label_s.detach().cpu().numpy())\n        
acc_s_adv.update(avg_acc_s_adv, cnt_s)\n        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(y_t_adv.detach().cpu().numpy(),\n                                               label_t.detach().cpu().numpy())\n        acc_t_adv.update(avg_acc_t_adv, cnt_t)\n        losses_s.update(loss_s, cnt_s)\n        losses_gf.update(loss_ground_false, cnt_s)\n        losses_gt.update(loss_ground_truth, cnt_s)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n            if visualize is not None:\n                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, \"source_{}_pred\".format(i))\n                visualize(x_s[0], meta_s['keypoint2d'][0], \"source_{}_label\".format(i))\n                visualize(x_t[0], pred_t[0] * args.image_size / args.heatmap_size, \"target_{}_pred\".format(i))\n                visualize(x_t[0], meta_t['keypoint2d'][0], \"target_{}_label\".format(i))\n                visualize(x_s[0], pred_s_adv[0] * args.image_size / args.heatmap_size, \"source_adv_{}_pred\".format(i))\n                visualize(x_t[0], pred_t_adv[0] * args.image_size / args.heatmap_size, \"target_adv_{}_pred\".format(i))\n\n\ndef validate(val_loader, model, criterion, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.2e')\n    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), \":3.2f\")\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, acc['all']],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (x, label, weight, meta) in enumerate(val_loader):\n            x = x.to(device)\n            label = label.to(device)\n            weight = weight.to(device)\n\n            # compute output\n            y = model(x)\n            loss = criterion(y, label, weight)\n\n            # measure accuracy and record loss\n            losses.update(loss.item(), x.size(0))\n            acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(),\n                                                          label.cpu().numpy())\n\n            group_acc = val_loader.dataset.group_accuracy(acc_per_points)\n            acc.update(group_acc, x.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n                if visualize is not None:\n                    # visualize() appends the \".jpg\" suffix itself\n                    visualize(x[0], pred[0] * args.image_size / args.heatmap_size, \"val_{}_pred\".format(i))\n                    visualize(x[0], meta['keypoint2d'][0], \"val_{}_label\".format(i))\n\n    return acc.average()\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='RegDA for Keypoint Detection Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the 
target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.6, 1.3),\n                        help='scale range for the RandomResizedCrop augmentation')\n    parser.add_argument('--rotation', type=int, default=180,\n                        help='rotation range of the RandomRotation augmentation')\n    parser.add_argument('--image-size', type=int, default=256,\n                        help='input image size')\n    parser.add_argument('--heatmap-size', type=int, default=64,\n                        help='output heatmap size')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet101',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet101)')\n    parser.add_argument(\"--pretrain\", type=str, default=None,\n                        help=\"Where to restore pretrained model parameters from.\")\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--num-head-layers', type=int, default=2)\n    parser.add_argument('--margin', type=float, default=4., help=\"margin gamma\")\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0001, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--lr-gamma', default=0.0001, type=float)\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-step', default=[45, 60], type=int, nargs='+', help='epochs at which lr decays during pretraining')\n    parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--pretrain_epochs', default=70, type=int, metavar='N',\n                        help='number of pretraining epochs to run')\n    parser.add_argument('--epochs', default=30, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--log\", type=str, default='regda',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--debug', action=\"store_true\",\n                        help='In the debug mode, save images and predictions')\n    args = parser.parse_args()\n    main(args)\n\n"
  },
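The training loop in `regda.py` above cycles through three updates per iteration. As a quick reference, here is a minimal, self-contained sketch of that A/B/C pattern on plain tensors; `f`, `h`, `h_adv` and `disparity` are toy stand-ins for the backbone plus upsampling, the two regression heads, and tllib's `RegressionDisparity`, not the library implementation.

```python
# Schematic sketch of the three-step adversarial scheme in regda.py,
# reduced to toy linear modules. Stand-ins, not the tllib classes.
import torch
import torch.nn as nn

f = nn.Linear(8, 8)        # feature extractor (stand-in)
h = nn.Linear(8, 2)        # main regression head (stand-in)
h_adv = nn.Linear(8, 2)    # adversarial regression head (stand-in)
opt_f = torch.optim.SGD(f.parameters(), lr=0.1)
opt_h = torch.optim.SGD(h.parameters(), lr=1.0)
opt_h_adv = torch.optim.SGD(h_adv.parameters(), lr=1.0)

def disparity(y, y_adv, mode):
    d = (y - y_adv).pow(2).mean()        # toy disparity between the two heads
    return d if mode == 'min' else -d    # 'max' maximizes by negating

x_s, y_s_true = torch.randn(4, 8), torch.randn(4, 2)  # labeled source batch
x_t = torch.randn(4, 8)                               # unlabeled target batch

# Step A: all networks minimize supervised loss (+ disparity margin) on source
opt_f.zero_grad(); opt_h.zero_grad(); opt_h_adv.zero_grad()
z = f(x_s)
((h(z) - y_s_true).pow(2).mean() + disparity(h(z), h_adv(z), 'min')).backward()
opt_f.step(); opt_h.step(); opt_h_adv.step()

# Step B: only the adversarial head maximizes disparity on the target
opt_h_adv.zero_grad()
z = f(x_t)
disparity(h(z), h_adv(z), 'max').backward()
opt_h_adv.step()

# Step C: only the feature extractor minimizes disparity on the target
opt_f.zero_grad()
z = f(x_t)
disparity(h(z), h_adv(z), 'min').backward()
opt_f.step()
```

Keeping three separate optimizers mirrors the script: Step B leaves stale gradients on `f` and `h`, which is harmless because Step A zeroes them at the start of the next iteration.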
  {
    "path": "examples/domain_adaptation/keypoint_detection/regda.sh",
    "content": "# Hands Dataset\nCUDA_VISIBLE_DEVICES=0 python regda.py data/RHD data/H3D_crop \\\n    -s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda/rhd2h3d\nCUDA_VISIBLE_DEVICES=0 python regda.py data/FreiHand data/RHD \\\n    -s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda/freihand2rhd\n\n# Body Dataset\nCUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/Human36M \\\n    -s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda/surreal2human36m\nCUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/lsp \\\n    -s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda/surreal2lsp\n"
  },
  {
    "path": "examples/domain_adaptation/keypoint_detection/regda_fast.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR, MultiStepLR\nfrom torch.utils.data import DataLoader\nfrom torchvision.transforms import Compose, ToPILImage\n\nsys.path.append('../../..')\nfrom tllib.alignment.regda import PoseResNet2d as RegDAPoseResNet, \\\n    FastPseudoLabelGenerator2d, RegressionDisparity\nimport tllib.vision.models as models\nfrom tllib.vision.models.keypoint_detection.pose_resnet import Upsampling, PoseResNet\nfrom tllib.vision.models.keypoint_detection.loss import JointsKLLoss\nimport tllib.vision.datasets.keypoint_detection as datasets\nimport tllib.vision.transforms.keypoint_detection as T\nfrom tllib.vision.transforms import Denormalize\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict\nfrom tllib.utils.metric.keypoint_detection import accuracy\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n    train_transform = T.Compose([\n        T.RandomRotation(args.rotation),\n        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),\n        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),\n        T.GaussianBlur(),\n        T.ToTensor(),\n        normalize\n    ])\n    val_transform = T.Compose([\n        T.Resize(args.image_size),\n        T.ToTensor(),\n        normalize\n    ])\n    image_size = (args.image_size, args.image_size)\n    heatmap_size = (args.heatmap_size, args.heatmap_size)\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform,\n                                          image_size=image_size, heatmap_size=heatmap_size)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform,\n                                        image_size=image_size, heatmap_size=heatmap_size)\n    val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)\n\n    target_dataset = datasets.__dict__[args.target]\n    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform,\n                                          image_size=image_size, heatmap_size=heatmap_size)\n    train_target_loader = DataLoader(train_target_dataset, 
batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform,\n                                        image_size=image_size, heatmap_size=heatmap_size)\n    val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)\n\n    print(\"Source train:\", len(train_source_loader))\n    print(\"Target train:\", len(train_target_loader))\n    print(\"Source test:\", len(val_source_loader))\n    print(\"Target test:\", len(val_target_loader))\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    backbone = models.__dict__[args.arch](pretrained=True)\n    upsampling = Upsampling(backbone.out_features)\n    num_keypoints = train_source_dataset.num_keypoints\n    model = RegDAPoseResNet(backbone, upsampling, 256, num_keypoints, num_head_layers=args.num_head_layers, finetune=True).to(device)\n    # define loss function\n    criterion = JointsKLLoss()\n    pseudo_label_generator = FastPseudoLabelGenerator2d()\n    regression_disparity = RegressionDisparity(pseudo_label_generator, JointsKLLoss(epsilon=1e-7))\n\n    # define optimizer and lr scheduler\n    optimizer_f = SGD([\n        {'params': backbone.parameters(), 'lr': 0.1},\n        {'params': upsampling.parameters(), 'lr': 0.1},\n    ], lr=0.1, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    optimizer_h = SGD(model.head.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    optimizer_h_adv = SGD(model.head_adv.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_decay_function = lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)\n    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)\n    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)\n    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)\n    start_epoch = 0\n\n    if args.resume is None:\n        if args.pretrain is None:\n            # first pretrain the backbone and upsampling\n            print(\"Pretraining the model on source domain.\")\n            args.pretrain = logger.get_checkpoint_path('pretrain')\n            pretrained_model = PoseResNet(backbone, upsampling, 256, num_keypoints, True).to(device)\n            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)\n            best_acc = 0\n            for epoch in range(args.pretrain_epochs):\n                lr_scheduler.step()\n                print(lr_scheduler.get_lr())\n\n                pretrain(train_source_iter, pretrained_model, criterion, optimizer, epoch, args)\n                source_val_acc = validate(val_source_loader, pretrained_model, criterion, None, args)\n\n                # remember best acc and save checkpoint\n                if source_val_acc['all'] > best_acc:\n                    best_acc = source_val_acc['all']\n                    torch.save(\n                        {\n                            'model': pretrained_model.state_dict()\n                        }, args.pretrain\n                    )\n                print(\"Source: {} best: {}\".format(source_val_acc['all'], best_acc))\n\n        # load from the pretrained checkpoint\n        pretrained_dict = torch.load(args.pretrain, map_location='cpu')['model']\n        model_dict = model.state_dict()\n        # remove keys from pretrained dict that doesn't appear in model dict\n        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n        model.load_state_dict(pretrained_dict, strict=False)\n    else:\n        # optionally resume from a checkpoint\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n        optimizer_f.load_state_dict(checkpoint['optimizer_f'])\n        optimizer_h.load_state_dict(checkpoint['optimizer_h'])\n        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])\n        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])\n        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])\n        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])\n        start_epoch = checkpoint['epoch'] + 1\n\n    # define visualization function\n    tensor_to_image = Compose([\n        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n        ToPILImage()\n    ])\n\n    def visualize(image, keypoint2d, name, heatmaps=None):\n        \"\"\"\n        Args:\n            image (tensor): image in shape 3 x H x W\n            keypoint2d (tensor): keypoints in shape K x 2\n            name: name of the saving image\n        \"\"\"\n        train_source_dataset.visualize(tensor_to_image(image),\n                                       keypoint2d, logger.get_image_path(\"{}.jpg\".format(name)))\n\n    if args.phase == 'test':\n        # evaluate on validation set\n        source_val_acc = validate(val_source_loader, model, criterion, None, args)\n        target_val_acc = validate(val_target_loader, model, criterion, visualize, args)\n        
print(\"Source: {:4.3f} Target: {:4.3f}\".format(source_val_acc['all'], target_val_acc['all']))\n        for name, acc in target_val_acc.items():\n            print(\"{}: {:4.3f}\".format(name, acc))\n        return\n\n    # start training\n    best_acc = 0\n    print(\"Start regression domain adaptation.\")\n    for epoch in range(start_epoch, args.epochs):\n        logger.set_epoch(epoch)\n        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(), lr_scheduler_h_adv.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, model, criterion, regression_disparity,\n              optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,\n              epoch, visualize if args.debug else None, args)\n\n        # evaluate on validation set\n        source_val_acc = validate(val_source_loader, model, criterion, None, args)\n        target_val_acc = validate(val_target_loader, model, criterion, visualize if args.debug else None, args)\n\n        # remember best acc and save checkpoint\n        torch.save(\n            {\n                'model': model.state_dict(),\n                'optimizer_f': optimizer_f.state_dict(),\n                'optimizer_h': optimizer_h.state_dict(),\n                'optimizer_h_adv': optimizer_h_adv.state_dict(),\n                'lr_scheduler_f': lr_scheduler_f.state_dict(),\n                'lr_scheduler_h': lr_scheduler_h.state_dict(),\n                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n        if target_val_acc['all'] > best_acc:\n            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))\n            best_acc = target_val_acc['all']\n        print(\"Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}\".format(source_val_acc['all'], target_val_acc['all'], best_acc))\n        for name, acc in target_val_acc.items():\n            print(\"{}: {:4.3f}\".format(name, acc))\n\n    logger.close()\n\n\ndef pretrain(train_source_iter, model, criterion, optimizer,\n             epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', \":.2e\")\n    acc_s = AverageMeter(\"Acc (s)\", \":3.2f\")\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s, acc_s],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, label_s, weight_s, meta_s = next(train_source_iter)\n\n        x_s = x_s.to(device)\n        label_s = label_s.to(device)\n        weight_s = weight_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s = model(x_s)\n        loss_s = criterion(y_s, label_s, weight_s)\n\n        # compute gradient and do SGD step\n        loss_s.backward()\n        optimizer.step()\n\n        # measure accuracy and record loss\n        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),\n                                               label_s.detach().cpu().numpy())\n        acc_s.update(avg_acc_s, cnt_s)\n        losses_s.update(loss_s, cnt_s)\n\n        # measure elapsed time\n        
batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\ndef train(train_source_iter, train_target_iter, model, criterion,regression_disparity,\n          optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,\n          epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', \":.2e\")\n    losses_gf = AverageMeter('Loss (t, false)', \":.2e\")\n    losses_gt = AverageMeter('Loss (t, truth)', \":.2e\")\n    acc_s = AverageMeter(\"Acc (s)\", \":3.2f\")\n    acc_t = AverageMeter(\"Acc (t)\", \":3.2f\")\n    acc_s_adv = AverageMeter(\"Acc (s, adv)\", \":3.2f\")\n    acc_t_adv = AverageMeter(\"Acc (t, adv)\", \":3.2f\")\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t, acc_s_adv, acc_t_adv],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, label_s, weight_s, meta_s = next(train_source_iter)\n        x_t, label_t, weight_t, meta_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        label_s = label_s.to(device)\n        weight_s = weight_s.to(device)\n\n        x_t = x_t.to(device)\n        label_t = label_t.to(device)\n        weight_t = weight_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # Step A train all networks to minimize loss on source domain\n        optimizer_f.zero_grad()\n        optimizer_h.zero_grad()\n        optimizer_h_adv.zero_grad()\n\n        y_s, y_s_adv = model(x_s)\n        loss_s = criterion(y_s, label_s, weight_s) + \\\n                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')\n        loss_s.backward()\n        optimizer_f.step()\n        optimizer_h.step()\n        optimizer_h_adv.step()\n\n        # Step B train adv regressor to maximize regression disparity\n        optimizer_h_adv.zero_grad()\n        y_t, y_t_adv = model(x_t)\n        loss_ground_false = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max')\n        loss_ground_false.backward()\n        optimizer_h_adv.step()\n\n        # Step C train feature extractor to minimize regression disparity\n        optimizer_f.zero_grad()\n        y_t, y_t_adv = model(x_t)\n        loss_ground_truth = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='min')\n        loss_ground_truth.backward()\n        optimizer_f.step()\n\n        # do update step\n        model.step()\n        lr_scheduler_f.step()\n        lr_scheduler_h.step()\n        lr_scheduler_h_adv.step()\n\n        # measure accuracy and record loss\n        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),\n                                               label_s.detach().cpu().numpy())\n        acc_s.update(avg_acc_s, cnt_s)\n        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),\n                                               label_t.detach().cpu().numpy())\n        acc_t.update(avg_acc_t, cnt_t)\n        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(y_s_adv.detach().cpu().numpy(),\n                                               label_s.detach().cpu().numpy())\n        
acc_s_adv.update(avg_acc_s_adv, cnt_s)\n        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(y_t_adv.detach().cpu().numpy(),\n                                               label_t.detach().cpu().numpy())\n        acc_t_adv.update(avg_acc_t_adv, cnt_t)\n        losses_s.update(loss_s, cnt_s)\n        losses_gf.update(loss_ground_false, cnt_s)\n        losses_gt.update(loss_ground_truth, cnt_s)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n            if visualize is not None:\n                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, \"source_{}_pred\".format(i))\n                visualize(x_s[0], meta_s['keypoint2d'][0], \"source_{}_label\".format(i))\n                visualize(x_t[0], pred_t[0] * args.image_size / args.heatmap_size, \"target_{}_pred\".format(i))\n                visualize(x_t[0], meta_t['keypoint2d'][0], \"target_{}_label\".format(i))\n                visualize(x_s[0], pred_s_adv[0] * args.image_size / args.heatmap_size, \"source_adv_{}_pred\".format(i))\n                visualize(x_t[0], pred_t_adv[0] * args.image_size / args.heatmap_size, \"target_adv_{}_pred\".format(i))\n\n\ndef validate(val_loader, model, criterion, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.2e')\n    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), \":3.2f\")\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, acc['all']],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (x, label, weight, meta) in enumerate(val_loader):\n            x = x.to(device)\n            label = label.to(device)\n            weight = weight.to(device)\n\n            # compute output\n            y = model(x)\n            loss = criterion(y, label, weight)\n\n            # measure accuracy and record loss\n            losses.update(loss.item(), x.size(0))\n            acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(),\n                                                          label.cpu().numpy())\n\n            group_acc = val_loader.dataset.group_accuracy(acc_per_points)\n            acc.update(group_acc, x.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n                if visualize is not None:\n                    # visualize() appends the \".jpg\" suffix itself\n                    visualize(x[0], pred[0] * args.image_size / args.heatmap_size, \"val_{}_pred\".format(i))\n                    visualize(x[0], meta['keypoint2d'][0], \"val_{}_label\".format(i))\n\n    return acc.average()\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='RegDA (fast) for Keypoint Detection Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path 
of the target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.6, 1.3),\n                        help='scale range for the RandomResizedCrop augmentation')\n    parser.add_argument('--rotation', type=int, default=180,\n                        help='rotation range of the RandomRotation augmentation')\n    parser.add_argument('--image-size', type=int, default=256,\n                        help='input image size')\n    parser.add_argument('--heatmap-size', type=int, default=64,\n                        help='output heatmap size')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet101',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet101)')\n    parser.add_argument(\"--pretrain\", type=str, default=None,\n                        help=\"Where to restore pretrained model parameters from.\")\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--num-head-layers', type=int, default=2)\n    parser.add_argument('--margin', type=float, default=4., help=\"margin gamma\")\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0001, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--lr-gamma', default=0.0001, type=float)\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-step', default=[45, 60], type=int, nargs='+', help='epochs at which lr decays during pretraining')\n    parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--pretrain_epochs', default=70, type=int, metavar='N',\n                        help='number of pretraining epochs to run')\n    parser.add_argument('--epochs', default=30, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--log\", type=str, default='regda_fast',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--debug', action=\"store_true\",\n                        help='In the debug mode, save images and predictions')\n    args = parser.parse_args()\n    main(args)\n\n"
  },
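`regda_fast.py` differs from `regda.py` essentially in one line: it builds `FastPseudoLabelGenerator2d()` instead of `PseudoLabelGenerator2d(...)`. To make the role of such a generator concrete, here is a rough, illustrative sketch of turning the main head's heatmap into "ground truth" and "ground false" targets for the regression disparity; the `make_pseudo_labels` helper is hypothetical, and tllib's actual generators differ in detail.

```python
# Illustrative sketch only: derive "ground truth"/"ground false" targets from
# a predicted heatmap. Not tllib's implementation.
import torch

def make_pseudo_labels(heatmap: torch.Tensor):
    """heatmap: (B, K, H, W) predictions from the main regression head."""
    b, k, h, w = heatmap.shape
    flat = heatmap.flatten(2)                            # (B, K, H*W)
    peak = flat.argmax(dim=-1, keepdim=True)             # index of each keypoint's peak
    ground_truth = torch.zeros_like(flat).scatter_(-1, peak, 1.0)  # one-hot at the peak
    ground_false = (1.0 - ground_truth) / (h * w - 1)    # uniform mass off the peak
    return ground_truth.view(b, k, h, w), ground_false.view(b, k, h, w)

gt, gf = make_pseudo_labels(torch.rand(2, 21, 64, 64))   # e.g. 21 hand keypoints
```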
  {
    "path": "examples/domain_adaptation/keypoint_detection/regda_fast.sh",
    "content": "# regda_fast is provided by https://github.com/YouJiacheng?tab=repositories\n# On single V100(16G), overall adversarial training time is reduced by about 40%.\n# yet the PCK might drop 1% for each dataset.\n# Hands Dataset\nCUDA_VISIBLE_DEVICES=0 python regda_fast.py data/RHD data/H3D_crop \\\n    -s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda_fast/rhd2h3d\nCUDA_VISIBLE_DEVICES=0 python regda_fast.py data/FreiHand data/RHD \\\n    -s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda_fast/freihand2rhd\n\n# Body Dataset\nCUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/Human36M \\\n    -s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda_fast/surreal2human36m\nCUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/lsp \\\n    -s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda_fast/surreal2lsp\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/README.md",
    "content": "# Unsupervised Domain Adaptation for Object Detection\n\n## Updates\n- *04/2022*: Provide CycleGAN translated datasets.\n\n\n## Installation\nOur code is based on [Detectron latest(v0.6)](https://detectron2.readthedocs.io/en/latest/tutorials/install.html), please install it before usage.\n\nThe following is an example based on PyTorch 1.9.0 with CUDA 11.1. For other versions, please refer to \nthe official website of [PyTorch](https://pytorch.org/) and \n[Detectron](https://detectron2.readthedocs.io/en/latest/tutorials/install.html).\n```shell\n# create environment\nconda create -n detection python=3.8.3\n# activate environment\nconda activate detection\n# install pytorch \npip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html\n# install detectron\npython -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html\n# install other requirements\npip install -r requirements.txt\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n- [PASCAL_VOC 07+12](http://host.robots.ox.ac.uk/pascal/VOC/)\n- Clipart\n- WaterColor\n- Comic\n\nYou need to prepare following datasets manually if you want to use them:\n\n#### Cityscapes, Foggy Cityscapes\n  - Download Cityscapes and Foggy Cityscapes dataset from the [link](https://www.cityscapes-dataset.com/downloads/). Particularly, we use *leftImg8bit_trainvaltest.zip* for Cityscapes and *leftImg8bit_trainvaltest_foggy.zip* for Foggy Cityscapes.\n  - Unzip them under the directory like\n\n```\nobject_detction/datasets/cityscapes\n├── gtFine\n├── leftImg8bit\n├── leftImg8bit_foggy\n└── ...\n```\nThen run \n```\npython prepare_cityscapes_to_voc.py \n```\nThis will automatically generate dataset in `VOC` format.\n```\nobject_detction/datasets/cityscapes_in_voc\n├── Annotations\n├── ImageSets\n└── JPEGImages\nobject_detction/datasets/foggy_cityscapes_in_voc\n├── Annotations\n├── ImageSets\n└── JPEGImages\n```\n\n#### Sim10k\n  - Download Sim10k dataset from the following links: [Sim10k](https://fcav.engin.umich.edu/projects/driving-in-the-matrix). Particularly, we use *repro_10k_images.tgz* , *repro_image_sets.tgz* and *repro_10k_annotations.tgz* for Sim10k.\n  - Extract the training set from *repro_10k_images.tgz*, *repro_image_sets.tgz* and *repro_10k_annotations.tgz*, then rename directory `VOC2012/` to `sim10k/`.\n  \nAfter preparation, there should exist following files:\n```\nobject_detction/datasets/\n├── VOC2007\n│   ├── Annotations\n│   ├──ImageSets\n│   └──JPEGImages\n├── VOC2012\n│   ├── Annotations\n│   ├── ImageSets\n│   └── JPEGImages\n├── clipart\n│   ├── Annotations\n│   ├── ImageSets\n│   └── JPEGImages\n├── watercolor\n│   ├── Annotations\n│   ├── ImageSets\n│   └── JPEGImages\n├── comic\n│   ├── Annotations\n│   ├── ImageSets\n│   └── JPEGImages\n├── cityscapes_in_voc\n│   ├── Annotations\n│   ├── ImageSets\n│   └── JPEGImages\n├── foggy_cityscapes_in_voc\n│   ├── Annotations\n│   ├── ImageSets\n│   └── JPEGImages\n└── sim10k\n    ├── Annotations\n    ├── ImageSets\n    └── JPEGImages\n```\n\n**Note**: The above is a tutorial for using standard datasets. 
To use your own datasets, \nyou need to convert them into the corresponding format.\n\n#### CycleGAN-translated datasets\n\nThe following command uses CycleGAN to translate VOC (with directories `datasets/VOC2007` and `datasets/VOC2012`) to Clipart (with directories `datasets/VOC2007_to_clipart` and `datasets/VOC2012_to_clipart`).\n```\nmkdir datasets/VOC2007_to_clipart\ncp -r datasets/VOC2007/* datasets/VOC2007_to_clipart\nmkdir datasets/VOC2012_to_clipart\ncp -r datasets/VOC2012/* datasets/VOC2012_to_clipart\n\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \\\n  --translated-source datasets/VOC2007_to_clipart datasets/VOC2012_to_clipart \\\n  --log logs/cyclegan_resnet9/translation/voc2clipart --netG resnet_9\n```\n\nYou can also download and use the datasets that we have already translated:\n\n- PASCAL_VOC to Clipart [[07]](https://cloud.tsinghua.edu.cn/f/1b6b060d202145aea416/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/818dbd8e41a043fab7c3/?dl=1) (with directories `datasets/VOC2007_to_clipart` and `datasets/VOC2012_to_clipart`)\n- PASCAL_VOC to Comic [[07]](https://cloud.tsinghua.edu.cn/f/89382bba64514210a9f8/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/f90289137fd5465f806d/?dl=1) (with directories `datasets/VOC2007_to_comic` and `datasets/VOC2012_to_comic`)\n- PASCAL_VOC to WaterColor [[07]](https://cloud.tsinghua.edu.cn/f/8e982e9f21294b38be8a/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/b8235034cb4247ce809f/?dl=1) (with directories `datasets/VOC2007_to_watercolor` and `datasets/VOC2012_to_watercolor`)\n- Cityscapes to Foggy Cityscapes [[Part1]](https://cloud.tsinghua.edu.cn/f/09ceeb25a476481bae29/?dl=1) [[Part2]](https://cloud.tsinghua.edu.cn/f/51fb05d3ee614e7d87a0/?dl=1) [[Part3]](https://cloud.tsinghua.edu.cn/f/646415daf6b344c3a9e3/?dl=1) [[Part4]](https://cloud.tsinghua.edu.cn/f/008d5d3c54344f83b101/?dl=1) (with directory `datasets/cityscapes_to_foggy_cityscapes`). 
Note that you need to use ``cat`` to merge the downloaded files.\n- Sim10k to Cityscapes (Car) [[Download]](https://cloud.tsinghua.edu.cn/f/33ac656fcde34f758dcd/?dl=1) (with directory `datasets/sim10k2cityscapes_car`).\n\n\n## Supported Methods\n\nSupported methods include:\n\n- [Cycle-Consistent Adversarial Networks (CycleGAN)](https://arxiv.org/pdf/1703.10593.pdf)\n- [Decoupled Adaptation for Cross-Domain Object Detection (D-adapt)](https://arxiv.org/abs/2110.02578)\n\n## Experiment and Results\n\nThe shell files give the scripts to reproduce the [benchmarks](/docs/dalib/benchmarks/object_detection.rst) with specified hyper-parameters.\nThe basic training pipeline is as follows.\n\nThe following command trains a Faster-RCNN detector on the task VOC->Clipart, with only source (VOC) data.\n```\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \\\n  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2clipart\n```\nExplanation of some arguments:\n- `--config-file`: path to the config file that specifies training hyper-parameters.\n- `-s`: a list that specifies the source datasets. For each dataset you should pass in a `(name, path)` pair; in the\n    above command, there are two source datasets, **VOC2007** and **VOC2012**.\n- `-t`: a list that specifies the target datasets, in the same format as above.\n- `--test`: a list that specifies the test datasets, in the same format as above.\n\n### VOC->Clipart\n\n|                         |          | AP   | AP50 | AP75 | aeroplane | bicycle | bird | boat | bottle | bus  | car  | cat  | chair | cow  | diningtable | dog  | horse | motorbike | person | pottedplant | sheep | sofa | train | tvmonitor |\n|-------------------------|----------|------|------|------|-----------|---------|------|------|--------|------|------|------|-------|------|-------------|------|-------|-----------|--------|-------------|-------|------|-------|-----------|\n| Faster RCNN (ResNet101) | Source   | 14.9 | 29.3 | 12.6 | 29.6      | 38.0    | 24.7 | 21.7 | 31.9   | 48.0 | 30.8 | 15.9 | 32.0  | 19.2 | 18.2        | 12.1 | 28.2  | 48.8      | 38.3   | 34.6        | 3.8   | 22.5 | 43.7  | 44.0      |\n|                         | CycleGAN | 20.0 | 37.7 | 18.3 | 37.1      | 41.9    | 29.9 | 26.5 | 40.9   | 65.1 | 37.8 | 23.8 | 40.7  | 48.9 | 12.7        | 14.4 | 27.8  | 63.0      | 55.1   | 40.1        | 8.0   | 30.7 | 54.1  | 55.7      |\n|                         | D-adapt  | 24.8 | 49.0 | 21.5 | 56.4      | 63.2    | 42.3 | 40.9 | 45.3   | 77.0 | 48.7 | 25.4 | 44.3  | 58.4 | 31.4        | 24.5 | 47.1  | 75.3      | 69.3   | 43.5        | 27.9  | 34.1 | 60.7  | 64.0      |\n|                         |          |      |      |      |           |         |      |      |        |      |      |      |       |      |             |      |       |           |        |             |       |      |       |           |\n| RetinaNet               | Source   | 18.3 | 32.2 | 17.6 | 34.2      | 42.4    | 27.0 | 21.6 | 36.8   | 48.4 | 35.9 | 16.4 | 38.9  | 22.6 | 27.0        | 15.1 | 27.1  | 46.7      | 42.1   | 36.2        | 8.3   | 29.5 | 42.1  | 46.2      |\n|                         | D-adapt  | 25.1 | 46.3 | 23.9 | 47.4      | 65.0    | 33.1 | 37.5 | 56.8   | 61.2 | 55.1 | 27.3 | 45.5  | 51.8 | 29.1        | 29.6 | 38.0  | 74.5      | 66.7   | 46.0        | 24.2  | 29.3 | 54.2  | 53.8      |\n\n### 
VOC->WaterColor\n\n|                         | AP   | AP50 | AP75 | bicycle | bird | car  | cat  | dog  | person |\n|-------------------------|------|------|------|---------|------|------|------|------|--------|\n| Faster RCNN (ResNet101) | 23.0 | 45.9 | 18.5 | 71.1    | 48.3 | 48.6 | 23.7 | 23.3 | 60.3   |\n| CycleGAN                | 24.9 | 50.8 | 22.4 | 75.8    | 52.1 | 49.8 | 30.1 | 33.4 | 63.6   |\n| D-adapt                 | 28.5 | 57.5 | 23.6 | 77.4    | 54.0 | 52.8 | 43.9 | 48.1 | 68.9   |\n| Target                  | 23.8 | 51.3 | 17.4 | 48.5    | 54.7 | 41.3 | 36.2 | 52.6 | 74.6   |\n\n### VOC->Comic\n\n|                         |  AP  | AP50 | AP75 | bicycle | bird |  car |  cat |  dog | person |\n|:-----------------------:|:----:|:----:|:----:|:-------:|:----:|:----:|:----:|:----:|:------:|\n| Faster RCNN (ResNet101) | 13.0 | 25.5 | 11.4 |   33.0  | 15.8 | 28.9 | 16.8 | 19.6 |  39.0  |\n|         CycleGAN        | 16.9 | 34.6 | 14.2 |   28.1  | 25.7 | 37.7 | 28.0 | 33.8 |  54.1  |\n|         D-adapt         | 20.8 | 41.1 | 18.5 |   49.4  | 25.7 | 43.3 | 36.9 | 32.7 |  58.5  |\n|          Target         | 21.9 | 44.6 | 16.0 |   40.7  | 32.3 | 38.3 | 43.9 | 41.3 |  71.0  |\n\n\n### Cityscapes->Foggy Cityscapes\n|                         |          |  AP  | AP50 | AP75 | bicycle |  bus |  car | motorcycle | person | rider | train | truck |\n|:-----------------------:|:--------:|:----:|:----:|:----:|:-------:|:----:|:----:|:----------:|:------:|:-----:|:-----:|:-----:|\n|   Faster RCNN (VGG16)   |  Source  | 14.3 | 25.9 | 13.2 |   33.6  | 27.0 | 40.0 |    22.3    |  31.3  |  38.5 |  2.3  |  12.2 |\n|                         | CycleGAN | 22.5 | 41.6 | 20.7 |   46.5  | 41.5 | 62.0 |    33.8    |  45.0  |  54.5 |  21.7 |  27.7 |\n|                         |  D-adapt | 19.4 | 38.1 | 17.5 |   42.0  | 36.8 | 58.1 |    32.2    |  43.1  |  51.8 |  14.6 |  26.3 |\n|                         |  Target  | 24.0 | 45.3 | 21.3 |   45.9  | 47.4 | 67.3 |    39.7    |  49.0  |  53.2 |  30.0 |  29.6 |\n|                         |          |      |      |      |         |      |      |            |        |       |       |       |\n| Faster RCNN (ResNet101) |  Source  | 18.8 | 33.3 | 19.0 |   36.1  | 34.5 | 43.8 |    24.0    |  36.3  |  39.9 |  29.1 |  22.8 |\n|                         | CycleGAN | 22.9 | 41.8 | 21.9 |   42.0  | 44.5 | 57.6 |    36.3    |  40.9  |  48.0 |  30.8 |  34.3 |\n|                         |  D-adapt | 22.7 | 42.4 | 21.6 |   41.8  | 44.4 | 56.6 |    31.4    |  41.8  |  48.6 |  42.3 |  32.4 |\n|                         |  Target  | 25.5 | 45.3 | 24.3 |   41.9  | 53.2 | 63.4 |    36.1    |  42.6  |  47.9 |  42.4 |  35.3 |\n\n### Sim10k->Cityscapes Car\n\n|                         |          |  AP  | AP50 | AP75 |\n|:-----------------------:|:--------:|:----:|:----:|:----:|\n|   Faster RCNN (VGG16)   |  Source  | 24.8 | 43.4 | 23.6 |\n|                         | CycleGAN | 29.3 | 51.9 | 28.6 |\n|                         |  D-adapt | 23.6 | 48.5 | 18.7 |\n|                         |  Target  | 24.8 | 43.4 | 23.6 |\n|                         |          |      |      |      |\n| Faster RCNN (ResNet101) |  Source  | 24.6 | 44.4 | 23.0 |\n|                         | CycleGAN | 26.5 | 47.4 | 24.0 |\n|                         |  D-adapt | 27.4 | 51.9 | 25.7 |\n|                         |  Target  | 24.6 | 44.4 | 23.0 |\n\n### Visualization\nWe provide code for visualization in `visualize.py`. 
For example, suppose you have trained the source-only model\nfor the task VOC->Clipart using the provided scripts. The following command visualizes the predictions of the\ndetector on Clipart.\n```shell\nCUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  --test Clipart datasets/clipart --save-path visualizations/source_only/voc2clipart \\\n  MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth\n```\nExplanation of some arguments:\n- `--test`: a list that specifies the test datasets for visualization.\n- `--save-path`: where to save the visualization results.\n- `MODEL.WEIGHTS`: path to the model checkpoint.\n\n## TODO\nSupport more methods: SWDA, Global/Local Alignment\n\n## Citation\nIf you use these methods in your research, please consider citing:\n\n```\n@inproceedings{jiang2021decoupled,\n  title     = {Decoupled Adaptation for Cross-Domain Object Detection},\n  author    = {Junguang Jiang and Baixu Chen and Jianmin Wang and Mingsheng Long},\n  booktitle = {ICLR},\n  year      = {2022}\n}\n\n@inproceedings{CycleGAN,\n    title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},\n    author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},\n    booktitle={ICCV},\n    year={2017}\n}\n```\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_cityscapes.yaml",
    "content": "MODEL:\n  META_ARCHITECTURE: \"TLGeneralizedRCNN\"\n  WEIGHTS: \"detectron2://ImageNetPretrained/MSRA/R-101.pkl\"\n  MASK_ON: False\n  RESNETS:\n    DEPTH: 101\n  ROI_HEADS:\n    NAME: \"TLRes5ROIHeads\"\n    NUM_CLASSES: 8\n    BATCH_SIZE_PER_IMAGE: 512\n  ANCHOR_GENERATOR:\n    SIZES: [ [ 64, 128, 256, 512 ] ]\n  RPN:\n    PRE_NMS_TOPK_TEST: 6000\n    POST_NMS_TOPK_TEST: 1000\n    BATCH_SIZE_PER_IMAGE: 256\n  PROPOSAL_GENERATOR:\n    NAME: \"TLRPN\"\nINPUT:\n  MIN_SIZE_TRAIN: (512, 544, 576, 608, 640, 672, 704,)\n  MIN_SIZE_TEST: 608\n  MAX_SIZE_TRAIN: 1166\nDATASETS:\n  TRAIN: (\"cityscapes_trainval\",)\n  TEST: (\"cityscapes_test\",)\nSOLVER:\n  STEPS: (12000,)\n  MAX_ITER: 16000  # 16 epochs\n  WARMUP_ITERS: 100\n  CHECKPOINT_PERIOD: 2000\n  IMS_PER_BATCH: 2\n  BASE_LR: 0.005\nTEST:\n  EVAL_PERIOD: 2000\nVIS_PERIOD: 500\nVERSION: 2\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_voc.yaml",
    "content": "MODEL:\n  META_ARCHITECTURE: \"TLGeneralizedRCNN\"\n  WEIGHTS: \"detectron2://ImageNetPretrained/MSRA/R-101.pkl\"\n  MASK_ON: False\n  RESNETS:\n    DEPTH: 101\n  ROI_HEADS:\n    NAME: \"TLRes5ROIHeads\"\n    NUM_CLASSES: 20\n    BATCH_SIZE_PER_IMAGE: 256\n  ANCHOR_GENERATOR:\n    SIZES: [ [ 64, 128, 256, 512 ] ]\n  RPN:\n    PRE_NMS_TOPK_TEST: 6000\n    POST_NMS_TOPK_TEST: 1000\n    BATCH_SIZE_PER_IMAGE: 128\n  PROPOSAL_GENERATOR:\n    NAME: \"TLRPN\"\nINPUT:\n  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704,)\n  MIN_SIZE_TEST: 608\n  MAX_SIZE_TRAIN: 1166\nDATASETS:\n  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')\n  TEST: ('voc_2007_test',)\nSOLVER:\n  STEPS: (12000, )\n  MAX_ITER: 16000  # 16 epochs\n  WARMUP_ITERS: 100\n  CHECKPOINT_PERIOD: 2000\n  IMS_PER_BATCH: 4\n  BASE_LR: 0.005\nTEST:\n  EVAL_PERIOD: 2000\nVIS_PERIOD: 500\nVERSION: 2"
  },
  {
    "path": "examples/domain_adaptation/object_detection/config/faster_rcnn_vgg_16_cityscapes.yaml",
    "content": "MODEL:\n  META_ARCHITECTURE: \"TLGeneralizedRCNN\"\n  WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth'\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\n  PIXEL_STD: [58.395, 57.120, 57.375]\n  MASK_ON: False\n  BACKBONE:\n    NAME: \"build_vgg_fpn_backbone\"\n  ROI_HEADS:\n    IN_FEATURES: [\"p3\", \"p4\", \"p5\", \"p6\"]\n    NAME: \"TLStandardROIHeads\"\n    NUM_CLASSES: 8\n  ROI_BOX_HEAD:\n    NAME: \"FastRCNNConvFCHead\"\n    NUM_FC: 2\n    POOLER_RESOLUTION: 7\n  ANCHOR_GENERATOR:\n    SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ]  # One size for each in feature map\n    ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ]  # Three aspect ratios (same for all in feature maps)\n  RPN:\n    IN_FEATURES: [\"p3\", \"p4\", \"p5\", \"p6\", \"p7\"]\n    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level\n    PRE_NMS_TOPK_TEST: 1000  # Per FPN level\n    # Detectron1 uses 2000 proposals per-batch,\n    # (See \"modeling/rpn/rpn_outputs.py\" for details of this legacy issue)\n    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.\n    POST_NMS_TOPK_TRAIN: 1000\n    POST_NMS_TOPK_TEST: 1000\n  PROPOSAL_GENERATOR:\n    NAME: \"TLRPN\"\nINPUT:\n  FORMAT: \"RGB\"\n  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)\n  MIN_SIZE_TEST: 800\n  MAX_SIZE_TEST: 1280\n  MAX_SIZE_TRAIN: 1280\nDATASETS:\n  TRAIN: (\"cityscapes_trainval\",)\n  TEST: (\"cityscapes_test\",)\nSOLVER:\n  STEPS: (12000,)\n  MAX_ITER: 16000  # 16 epochs\n  WARMUP_ITERS: 100\n  CHECKPOINT_PERIOD: 2000\n  IMS_PER_BATCH: 8\n  BASE_LR: 0.01\nTEST:\n  EVAL_PERIOD: 2000\nVIS_PERIOD: 500\nVERSION: 2"
  },
  {
    "path": "examples/domain_adaptation/object_detection/config/retinanet_R_101_FPN_voc.yaml",
    "content": "MODEL:\n  META_ARCHITECTURE: \"TLRetinaNet\"\n  WEIGHTS: \"detectron2://ImageNetPretrained/MSRA/R-101.pkl\"\n  BACKBONE:\n    NAME: \"build_retinanet_resnet_fpn_backbone\"\n  MASK_ON: False\n  RESNETS:\n    DEPTH: 101\n    OUT_FEATURES: [ \"res4\", \"res5\" ]\n  ANCHOR_GENERATOR:\n    SIZES: !!python/object/apply:eval [ \"[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [64, 128, 256, 512 ]]\" ]\n  RETINANET:\n    NUM_CLASSES: 20\n    IN_FEATURES: [\"p4\", \"p5\", \"p6\", \"p7\"]\n    IOU_THRESHOLDS: [ 0.4, 0.5 ]\n    IOU_LABELS: [ 0, -1, 1 ]\n    SMOOTH_L1_LOSS_BETA: 0.0\n  FPN:\n    IN_FEATURES: [\"res4\", \"res5\"]\nINPUT:\n  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, )\n  MIN_SIZE_TEST: 608\n  MAX_SIZE_TRAIN: 1166\nDATASETS:\n  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')\n  TEST: ('voc_2007_test',)\nSOLVER:\n  STEPS: (12000, )\n  MAX_ITER: 16000  # 16 epochs\n  WARMUP_ITERS: 100\n  CHECKPOINT_PERIOD: 2000\n  IMS_PER_BATCH: 8\n  BASE_LR: 0.005\nTEST:\n  EVAL_PERIOD: 2000\nVIS_PERIOD: 500\nVERSION: 2"
  },
  {
    "path": "examples/domain_adaptation/object_detection/cycle_gan.py",
    "content": "\"\"\"\nCycleGAN for VOC-format Object Detection Dataset\nYou need to modify the function build_dataset if you want to use your own dataset.\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport itertools\nimport os\nimport tqdm\nfrom typing import Optional, Callable, Tuple, Any, List\nfrom PIL import Image\n\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader, ConcatDataset\nfrom torchvision.transforms import ToPILImage, Compose\nimport torchvision.datasets as datasets\nfrom torchvision.datasets.folder import default_loader\nimport torchvision.transforms as T\n\n\nsys.path.append('../../..')\nimport tllib.translation.cyclegan as cyclegan\nfrom tllib.translation.cyclegan.util import ImagePool, set_requires_grad\nfrom tllib.vision.transforms import Denormalize\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef make_power_2(img, base, method=Image.BICUBIC):\n    \"\"\"Resize img so that both sides are multiples of base.\"\"\"\n    ow, oh = img.size\n    h = int(max(round(oh / base), 1) * base)\n    w = int(max(round(ow / base), 1) * base)\n    if h == oh and w == ow:\n        return img\n    return img.resize((w, h), method)\n\n\nclass VOCImageFolder(datasets.VisionDataset):\n    \"\"\"A VOC-format Dataset class for image translation\n    \"\"\"\n\n    def __init__(self, root: str, phase='trainval',\n                 transform: Optional[Callable] = None, extension='.jpg'):\n        super().__init__(root, transform=transform)\n        data_list_file = os.path.join(root, \"ImageSets/Main/{}.txt\".format(phase))\n        self.samples = self.parse_data_file(data_list_file, extension)\n        self.loader = default_loader\n        self.data_list_file = data_list_file\n\n    def __getitem__(self, index: int) -> Tuple[Any, str]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        Returns:\n            tuple: (image, path) where path is the file path of the image\n        \"\"\"\n        path = self.samples[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        return img, path\n\n    def __len__(self) -> int:\n        return len(self.samples)\n\n    def parse_data_file(self, file_name: str, extension: str) -> List[str]:\n        \"\"\"Parse a data list file into a list of image paths\n\n        Args:\n            file_name (str): The path of the data list file\n\n        Returns:\n            list: image file paths\n        \"\"\"\n        with open(file_name, \"r\") as f:\n            data_list = []\n            for line in f.readlines():\n                line = line.strip()\n                if extension is None:\n                    path = line\n                else:\n                    path = line + extension\n                if not os.path.isabs(path):\n                    path = os.path.join(self.root, \"JPEGImages\", path)\n                data_list.append(path)\n        return data_list\n\n    def translate(self, transform: Callable, target_root: str, image_base=4):\n        \"\"\"Translate all images in the dataset and save them into a specified directory\n\n        Args:\n            transform (callable): a transform function that maps an image from one 
domain to another domain\n            target_root (str): the root directory to save images and labels\n\n        \"\"\"\n        os.makedirs(target_root, exist_ok=True)\n        for path in tqdm.tqdm(self.samples):\n            image = Image.open(path).convert('RGB')\n            translated_path = path.replace(self.root, target_root)\n            ow, oh = image.size\n            image = make_power_2(image, image_base)\n            translated_image = transform(image)\n            translated_image = translated_image.resize((ow, oh))\n            os.makedirs(os.path.dirname(translated_path), exist_ok=True)\n            translated_image.save(translated_path)\n\n\ndef main(args):\n    logger = CompleteLogger(args.log, args.phase)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = T.Compose([\n        T.RandomRotation(args.rotation),\n        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=args.resize_scale),\n        T.RandomHorizontalFlip(),\n        T.ToTensor(),\n        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n    ])\n    train_source_dataset = build_dataset(args.source[::2], args.source[1::2], train_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    train_target_dataset = build_dataset(args.target[::2], args.target[1::2], train_transform)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # define networks (both generators and discriminators)\n    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n\n    # create image buffer to store previously generated images\n    fake_S_pool = ImagePool(args.pool_size)\n    fake_T_pool = ImagePool(args.pool_size)\n\n    # define optimizer and lr scheduler\n    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr, betas=(args.beta1, 0.999))\n    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))\n    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)\n    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)\n    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        
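# restore networks, optimizers and lr schedulers to resume from the saved epoch\n        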
print(\"Resume from\", args.resume)\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        netG_S2T.load_state_dict(checkpoint['netG_S2T'])\n        netG_T2S.load_state_dict(checkpoint['netG_T2S'])\n        netD_S.load_state_dict(checkpoint['netD_S'])\n        netD_T.load_state_dict(checkpoint['netD_T'])\n        optimizer_G.load_state_dict(checkpoint['optimizer_G'])\n        optimizer_D.load_state_dict(checkpoint['optimizer_D'])\n        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])\n        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    if args.phase == 'train':\n        # define loss function\n        criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()\n        criterion_cycle = nn.L1Loss()\n        criterion_identity = nn.L1Loss()\n\n        # define visualization function\n        tensor_to_image = Compose([\n            Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n            ToPILImage()\n        ])\n\n        def visualize(image, name):\n            \"\"\"\n            Args:\n                image (tensor): image in shape 3 x H x W\n                name: name of the saving image\n            \"\"\"\n            tensor_to_image(image).save(logger.get_image_path(\"{}.png\".format(name)))\n\n        # start training\n        for epoch in range(args.start_epoch, args.epochs+args.epochs_decay):\n            logger.set_epoch(epoch)\n            print(lr_scheduler_G.get_lr())\n\n            # train for one epoch\n            train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,\n                  criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,\n                  fake_S_pool, fake_T_pool, epoch, visualize, args)\n\n            # update learning rates\n            lr_scheduler_G.step()\n            lr_scheduler_D.step()\n\n            # save checkpoint\n            torch.save(\n                {\n                    'netG_S2T': netG_S2T.state_dict(),\n                    'netG_T2S': netG_T2S.state_dict(),\n                    'netD_S': netD_S.state_dict(),\n                    'netD_T': netD_T.state_dict(),\n                    'optimizer_G': optimizer_G.state_dict(),\n                    'optimizer_D': optimizer_D.state_dict(),\n                    'lr_scheduler_G': lr_scheduler_G.state_dict(),\n                    'lr_scheduler_D': lr_scheduler_D.state_dict(),\n                    'epoch': epoch,\n                    'args': args\n                }, logger.get_checkpoint_path('latest')\n            )\n\n    if args.translated_source is not None:\n        transform = cyclegan.transform.Translation(netG_S2T, device)\n        for dataset, translated_source in zip(train_source_dataset.datasets, args.translated_source):\n            dataset.translate(transform, translated_source, image_base=args.image_base)\n\n    if args.translated_target is not None:\n        transform = cyclegan.transform.Translation(netG_T2S, device)\n        for dataset, translated_target in zip(train_target_dataset.datasets, args.translated_target):\n            dataset.translate(transform, translated_target, image_base=args.image_base)\n\n    logger.close()\n\n\ndef train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,\n          criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,\n          fake_S_pool, fake_T_pool, epoch: int, visualize, args: argparse.Namespace):\n    batch_time = 
AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')\n    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')\n    losses_D_S = AverageMeter('D_S', ':3.2f')\n    losses_D_T = AverageMeter('D_T', ':3.2f')\n    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')\n    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')\n    losses_identity_S = AverageMeter('idt_S', ':3.2f')\n    losses_identity_T = AverageMeter('idt_T', ':3.2f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,\n         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        real_S, _ = next(train_source_iter)\n        real_T, _ = next(train_target_iter)\n\n        real_S = real_S.to(device)\n        real_T = real_T.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # Compute fake images and reconstruction images.\n        fake_T = netG_S2T(real_S)\n        rec_S = netG_T2S(fake_T)\n        fake_S = netG_T2S(real_T)\n        rec_T = netG_S2T(fake_S)\n\n        # Optimize generators\n        # discriminators require no gradients\n        set_requires_grad(netD_S, False)\n        set_requires_grad(netD_T, False)\n\n        optimizer_G.zero_grad()\n        # GAN loss D_T(G_S2T(S))\n        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)\n        # GAN loss D_S(G_T2S(T))\n        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)\n        # Cycle loss || G_T2S(G_S2T(S)) - S||\n        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle\n        # Cycle loss || G_S2T(G_T2S(T)) - T||\n        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle\n        # Identity loss\n        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||\n        identity_T = netG_S2T(real_T)\n        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity\n        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||\n        identity_S = netG_T2S(real_S)\n        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity\n        # combine losses and calculate gradients\n        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T\n        loss_G.backward()\n        optimizer_G.step()\n\n        # Optimize discriminators\n        set_requires_grad(netD_S, True)\n        set_requires_grad(netD_T, True)\n        optimizer_D.zero_grad()\n        # Calculate GAN loss for discriminator D_S\n        fake_S_ = fake_S_pool.query(fake_S.detach())\n        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))\n        loss_D_S.backward()\n        # Calculate GAN loss for discriminator D_T\n        fake_T_ = fake_T_pool.query(fake_T.detach())\n        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))\n        loss_D_T.backward()\n        optimizer_D.step()\n\n        # update statistics\n        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))\n        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))\n        losses_D_S.update(loss_D_S.item(), real_S.size(0))\n        losses_D_T.update(loss_D_T.item(), real_S.size(0))\n        
losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))\n        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))\n        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))\n        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T],\n                                    [\"real_S\", \"real_T\", \"fake_S\", \"fake_T\", \"rec_S\",\n                                     \"rec_T\", \"identity_S\", \"identity_T\"]):\n                visualize(tensor[0], \"{}_{}\".format(i, name))\n\n\ndef build_dataset(dataset_names, dataset_roots, transform):\n    \"\"\"\n    Given a sequence of dataset names and a sequence of dataset root directories,\n    return a ConcatDataset of the built datasets\n    \"\"\"\n    dataset_lists = []\n    for dataset_name, root in zip(dataset_names, dataset_roots):\n        if dataset_name in [\"WaterColor\", \"Comic\"]:\n            dataset = VOCImageFolder(root, phase='train', transform=transform)\n        elif dataset_name in [\"Cityscapes\", \"FoggyCityscapes\"]:\n            dataset = VOCImageFolder(root, phase=\"trainval\", transform=transform, extension=\".png\")\n        elif dataset_name in [\"Sim10k\"]:\n            dataset = VOCImageFolder(root, phase=\"trainval10k\", transform=transform)\n        else:\n            dataset = VOCImageFolder(root, phase=\"trainval\", transform=transform)\n        dataset_lists.append(dataset)\n    return ConcatDataset(dataset_lists)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='CycleGAN for Object Detection')\n    # dataset parameters\n    parser.add_argument('-s', '--source', nargs='+', help='source domain(s)')\n    parser.add_argument('-t', '--target', nargs='+', help='target domain(s)')\n    parser.add_argument('--rotation', type=int, default=0,\n                        help='rotation range of the RandomRotation augmentation')\n    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(0.5, 1.0),\n                        help='the resize ratio for the random resize crop')\n    parser.add_argument('--resize-scale', nargs='+', type=float, default=(3./4., 4./3.),\n                        help='the resize scale for the random resize crop')\n    parser.add_argument('--train-size', nargs='+', type=int, default=(512, 512),\n                        help='the input and output image size during training')\n    # model parameters\n    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n    parser.add_argument('--netD', type=str, default='patch',\n                        help='specify discriminator architecture [patch | pixel]. 
The basic model is a 70x70 PatchGAN.')\n    parser.add_argument('--netG', type=str, default='unet_256',\n                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')\n    parser.add_argument('--norm', type=str, default='instance',\n                        help='instance normalization or batch normalization [instance | batch | none]')\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')\n    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=1, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 1)')\n    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--epochs-decay', type=int, default=20,\n                        help='number of epochs to linearly decay learning rate to zero')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('--pool-size', type=int, default=50,\n                        help='the size of image buffer that stores previously generated images')\n    parser.add_argument('-p', '--print-freq', default=500, type=int,\n                        metavar='N', help='print frequency (default: 500)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='cyclegan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    # test parameters\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--test-input-size', nargs='+', type=int, default=(512, 512),\n                        help='the input image size during test')\n    parser.add_argument('--translated-source', type=str, default=None, nargs='+',\n                        help=\"The root to put the translated source dataset\")\n    parser.add_argument('--translated-target', type=str, default=None, nargs='+',\n                        help=\"The root to put the translated target dataset\")\n    parser.add_argument('--image-base', default=4, type=int,\n                        help='the input image will be resized to a multiple of image-base before translation')\n    args = parser.parse_args()\n    print(args)\n    main(args)\n
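\n\n# A typical two-step usage (cf. cycle_gan.sh): first train the translation model and write\n# translated copies of the source datasets, then train a detector on the union of the original\n# and translated data. For example:\n#   CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \\\n#     -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \\\n#     --translated-source datasets/VOC2007_to_clipart datasets/VOC2012_to_clipart \\\n#     --log logs/cyclegan_resnet9/translation/voc2clipart --netG resnet_9\n"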
  },
  {
    "path": "examples/domain_adaptation/object_detection/cycle_gan.sh",
    "content": "# VOC to Clipart\nmkdir datasets/VOC2007_to_clipart\ncp -r datasets/VOC2007/* datasets/VOC2007_to_clipart\nmkdir datasets/VOC2012_to_clipart\ncp -r datasets/VOC2012/* datasets/VOC2012_to_clipart\n\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \\\n  --translated-source datasets/VOC2007_to_clipart datasets/VOC2012_to_clipart \\\n  --log logs/cyclegan_resnet9/translation/voc2clipart --netG resnet_9\n\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 VOC2007 datasets/VOC2007_to_clipart VOC2012 datasets/VOC2012_to_clipart \\\n  -t Clipart datasets/clipart \\\n  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \\\n  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2clipart\n\n# VOC to Comic\nmkdir datasets/VOC2007_to_comic\ncp -r datasets/VOC2007/* datasets/VOC2007_to_comic\nmkdir datasets/VOC2012_to_comic\ncp -r datasets/VOC2012/* datasets/VOC2012_to_comic\n\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Comic datasets/comic \\\n  --translated-source datasets/VOC2007_to_comic datasets/VOC2012_to_comic \\\n  --log logs/cyclegan_resnet9/translation/voc2comic --netG resnet_9\n\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 VOC2007Partial datasets/VOC2007_to_comic VOC2012Partial datasets/VOC2012_to_comic \\\n  -t Comic datasets/comic \\\n  --test VOC2007Test datasets/VOC2007 ComicTest datasets/comic --finetune \\\n  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2comic MODEL.ROI_HEADS.NUM_CLASSES 6\n\n# VOC to WaterColor\nmkdir datasets/VOC2007_to_watercolor\ncp -r datasets/VOC2007/* datasets/VOC2007_to_watercolor\nmkdir datasets/VOC2012_to_watercolor\ncp -r datasets/VOC2012/* datasets/VOC2012_to_watercolor\n\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t WaterColor datasets/watercolor \\\n  --translated-source datasets/VOC2007_to_watercolor datasets/VOC2012_to_watercolor \\\n  --log logs/cyclegan_resnet9/translation/voc2watercolor --netG resnet_9\n\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 VOC2007Partial datasets/VOC2007_to_watercolor VOC2012Partial datasets/VOC2012_to_watercolor \\\n  -t WaterColor datasets/watercolor \\\n  --test VOC2007Test datasets/VOC2007 WaterColorTest datasets/watercolor --finetune \\\n  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2watercolor MODEL.ROI_HEADS.NUM_CLASSES 6\n\n# Cityscapes to Foggy Cityscapes\nmkdir datasets/cityscapes_to_foggy_cityscapes\ncp -r datasets/cityscapes_in_voc/* datasets/cityscapes_to_foggy_cityscapes\n\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s Cityscapes datasets/cityscapes_in_voc \\\n  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \\\n  --translated-source datasets/cityscapes_to_foggy_cityscapes \\\n  --log logs/cyclegan/translation/cityscapes2foggy\n\n# ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s Cityscapes datasets/cityscapes_in_voc/ Cityscapes 
datasets/cityscapes_to_foggy_cityscapes/ \\\n  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \\\n  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \\\n  OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/cityscapes2foggy\n\n# VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Cityscapes datasets/cityscapes_in_voc/ Cityscapes datasets/cityscapes_to_foggy_cityscapes/ \\\n  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \\\n  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \\\n  OUTPUT_DIR logs/cyclegan/faster_rcnn_vgg_16/cityscapes2foggy\n\n\n# Sim10k to Cityscapes Car\nmkdir datasets/sim10k_to_cityscapes_car\ncp -r datasets/sim10k/* datasets/sim10k_to_cityscapes_car\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s Sim10k datasets/sim10k -t Cityscapes datasets/cityscapes_in_voc \\\n    --log logs/cyclegan/translation/sim10k2cityscapes_car --translated-source datasets/sim10k_to_cityscapes_car --image-base 256\n\n# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s Sim10kCar datasets/sim10k Sim10kCar datasets/sim10k_to_cityscapes_car -t CityscapesCar datasets/cityscapes_in_voc/ \\\n  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \\\n  OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1\n\n# VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Sim10kCar datasets/sim10k Sim10kCar datasets/sim10k_to_cityscapes_car -t CityscapesCar datasets/cityscapes_in_voc/  \\\n  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \\\n  OUTPUT_DIR logs/cyclegan/faster_rcnn_vgg_16/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1\n\n# GTA5 to Cityscapes\nmkdir datasets/gta5_to_cityscapes\ncp -r datasets/synscapes_detection/* datasets/gta5_to_cityscapes\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s GTA5 datasets/synscapes_detection -t Cityscapes datasets/cityscapes_in_voc \\\n    --log logs/cyclegan/translation/gta52cityscapes --translated-source datasets/gta5_to_cityscapes --image-base 256\n\n# ResNet101 Based Faster RCNN: GTA5 -> Cityscapes\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s GTA5 datasets/synscapes_detection GTA5 datasets/gta5_to_cityscapes -t Cityscapes datasets/cityscapes_in_voc \\\n  --test CityscapesTest datasets/cityscapes_in_voc/ --finetune \\\n  OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/gta52cityscapes\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/README.md",
    "content": "# Decoupled Adaptation for Cross-Domain Object Detection\n\n## Installation\nOur code is based on\n- [Detectron2 (latest, v0.6)](https://detectron2.readthedocs.io/en/latest/tutorials/install.html)\n- [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models)\n\nPlease install them before usage.\n\n## Method\nCompared with previous cross-domain object detection methods, D-adapt decouples the adversarial adaptation from the training of the detector.\n<div align=\"center\">\n\t<img src=\"./fig/comparison.png\" alt=\"Editor\" width=\"800\">\n</div>\n\nThe whole pipeline is as follows:\n<div align=\"center\">\n\t<img src=\"./fig/d_adapt_pipeline.png\" alt=\"Editor\" width=\"500\">\n</div>\n\nFirst, you need to run ``source_only.py`` to obtain pre-trained models (see ``source_only.sh`` for scripts).\nThen you need to run ``d_adapt.py`` to obtain adapted models (see ``d_adapt.sh`` for scripts).\nWhen the domain discrepancy is large, you need to run ``d_adapt.py`` multiple times; a schematic sketch of one adaptation round is given below.\n\nFor better readability, we implement the training of the category adaptor in ``category_adaptation.py``,\nthe training of the bounding box adaptor in ``bbox_adaptation.py``,\nand the training of the detector, which connects the above components, in ``d_adapt.py``.\nThis makes it easy for you to modify or replace the adaptors.\n\nWe provide independent training arguments for the detector, the category adaptor and the bounding box adaptor.\nThe arguments of the latter two end with ``-c`` and ``-b``, respectively.\n\n
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/bbox_adaptation.py",
    "content": "\"\"\"\nTraining a bounding box adaptor\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport os.path as osp\nimport argparse\nfrom collections import deque\nimport tqdm\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD, Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\nimport torch.nn.functional as F\nfrom detectron2.modeling.box_regression import Box2BoxTransform\n\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.modules.regressor import Regressor\nfrom tllib.alignment.mdd import ImageRegressor, RegressionMarginDisparityDiscrepancy\nfrom tllib.alignment.d_adapt.proposal import ProposalDataset, PersistentProposalList, flatten, ExpandCrop\n\nimport utils\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\nclass BoxTransform(nn.Module):\n    def __init__(self):\n        super(BoxTransform, self).__init__()\n        BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)\n        self.box_transform = Box2BoxTransform(weights=BBOX_REG_WEIGHTS)\n\n    def forward(self, pred_delta, gt_classes, proposal_boxes):\n        \"\"\"\n        Args:\n            - pred_delta: predicted bounding box offsets for each class\n            - gt_classes: ground truth classes\n            - proposal_boxes: referenced bounding boxes\n\n        Returns:\n            the predicted bounding box offsets for the ground truth classes\n            and the predicted bounding boxes\n        \"\"\"\n        gt_class_cols = 4 * gt_classes[:, None] + torch.arange(4, device=device)\n        pred_delta = torch.gather(pred_delta, dim=1, index=gt_class_cols)\n        pred_box = self.box_transform.apply_deltas(pred_delta, proposal_boxes)\n        return pred_delta, pred_box\n\n\ndef iou_between(\n    boxes1: torch.Tensor,\n    boxes2: torch.Tensor,\n    eps: float = 1e-7,\n    reduction: str = \"none\"\n):\n    \"\"\"Element-wise Intersection over Union between two aligned sets of boxes\"\"\"\n    x1, y1, x2, y2 = boxes1.unbind(dim=-1)\n    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)\n\n    assert (x2 >= x1).all(), \"bad box: x1 larger than x2\"\n    assert (y2 >= y1).all(), \"bad box: y1 larger than y2\"\n\n    # Intersection keypoints\n    xkis1 = torch.max(x1, x1g)\n    ykis1 = torch.max(y1, y1g)\n    xkis2 = torch.min(x2, x2g)\n    ykis2 = torch.min(y2, y2g)\n\n    intsctk = torch.zeros_like(x1)\n    mask = (ykis2 > ykis1) & (xkis2 > xkis1)\n    intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])\n    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk\n    iouk = intsctk / (unionk + eps)\n\n    if reduction == 'mean':\n        return iouk.mean()\n    elif reduction == 'sum':\n        return iouk.sum()\n    else:\n        return iouk\n\n\ndef clamp_single(box, w, h):\n    x1, y1, x2, y2 = box\n    x1 = x1.clamp(min=0, max=w)\n    x2 = x2.clamp(min=0, max=w)\n    y1 = y1.clamp(min=0, max=h)\n    y2 = y2.clamp(min=0, max=h)\n    return torch.tensor((x1, y1, x2, y2))\n\n\ndef clamp(boxes, widths, heights):\n    \"\"\"Clamp (limit) the values in boxes within the widths and heights of the image.\"\"\"\n    clamped_boxes = []\n    for box, w, h in zip(boxes, widths, heights):\n        clamped_boxes.append(clamp_single(box, w, h))\n    return torch.stack(clamped_boxes, 
dim=0)\n\n\nclass BoundingBoxAdaptor:\n    def __init__(self, class_names, log, args):\n        self.class_names = class_names\n        for k, v in args._get_kwargs():\n            setattr(args, k.replace(\"_b\", \"\"), v)\n        self.args = args\n        print(self.args)\n        self.logger = CompleteLogger(log)\n        # create model\n        print(\"=> using pre-trained model '{}'\".format(args.arch))\n        backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n        num_classes = len(class_names)\n        bottleneck_dim = args.bottleneck_dim\n        bottleneck = nn.Sequential(\n            nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(bottleneck_dim),\n            nn.ReLU(),\n        )\n        head = nn.Sequential(\n            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(bottleneck_dim),\n            nn.ReLU(),\n            nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            nn.Flatten(),\n            nn.Linear(bottleneck_dim, num_classes * 4),\n        )\n        for layer in head:\n            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n                nn.init.normal_(layer.weight, 0, 0.01)\n                nn.init.constant_(layer.bias, 0)\n        adv_head = nn.Sequential(\n            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),\n            nn.BatchNorm2d(bottleneck_dim),\n            nn.ReLU(),\n            nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            nn.Flatten(),\n            nn.Linear(bottleneck_dim, num_classes * 4),\n        )\n        for layer in adv_head:\n            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n                nn.init.normal_(layer.weight, 0, 0.01)\n                nn.init.constant_(layer.bias, 0)\n        self.model = ImageRegressor(\n            backbone, num_classes * 4, bottleneck=bottleneck,\n            head=head, adv_head=adv_head\n        ).to(device)\n        self.box_transform = BoxTransform()\n\n    def load_checkpoint(self, path=None):\n        if path is None:\n            path = self.logger.get_checkpoint_path('latest')\n        if osp.exists(path):\n            checkpoint = torch.load(path, map_location='cpu')\n            self.model.load_state_dict(checkpoint)\n            return True\n        else:\n            return False\n\n    def prepare_training_data(self, proposal_list: PersistentProposalList, labeled=True):\n        if not labeled:\n            # remove (predicted) background proposals\n            filtered_proposals_list = []\n            for proposals in proposal_list:\n                keep_indices = (0 <= proposals.pred_classes) & (proposals.pred_classes < len(self.class_names))\n                filtered_proposals_list.append(proposals[keep_indices])\n        else:\n            # remove proposals with low IoU\n            filtered_proposals_list = []\n            for proposals in proposal_list:\n                keep_indices = proposals.gt_ious > 0.3\n                filtered_proposals_list.append(proposals[keep_indices])\n\n        filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_train)\n\n        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n        transform = T.Compose([\n            T.Resize((self.args.resize_size, self.args.resize_size)),\n            # T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),\n            
# T.RandomGrayscale(),\n            T.ToTensor(),\n            normalize\n        ])\n\n        dataset = ProposalDataset(filtered_proposals_list, transform, crop_func=ExpandCrop(self.args.expand))\n        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,\n                                shuffle=True, num_workers=self.args.workers, drop_last=True)\n        return dataloader\n\n    def prepare_validation_data(self, proposal_list: PersistentProposalList):\n        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n        transform = T.Compose([\n            T.Resize((self.args.resize_size, self.args.resize_size)),\n            T.ToTensor(),\n            normalize\n        ])\n\n        # remove (predicted) background proposals\n        filtered_proposals_list = []\n        for proposals in proposal_list:\n            # keep_indices = (0 <= proposals.gt_classes) & (proposals.gt_classes < len(self.class_names))\n            keep_indices = (0 <= proposals.pred_classes) & (proposals.pred_classes < len(self.class_names))\n            filtered_proposals_list.append(proposals[keep_indices])\n\n        filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_val)\n        dataset = ProposalDataset(filtered_proposals_list, transform, crop_func=ExpandCrop(self.args.expand))\n        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,\n                                shuffle=False, num_workers=self.args.workers, drop_last=False)\n        return dataloader\n\n    def prepare_test_data(self, proposal_list: PersistentProposalList):\n        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n        transform = T.Compose([\n            T.Resize((self.args.resize_size, self.args.resize_size)),\n            T.ToTensor(),\n            normalize\n        ])\n\n        dataset = ProposalDataset(proposal_list, transform, crop_func=ExpandCrop(self.args.expand))\n        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,\n                                shuffle=False, num_workers=self.args.workers, drop_last=False)\n        return dataloader\n\n    def predict(self, data_loader):\n        # switch to evaluate mode\n        self.model.eval()\n        predictions = deque()\n\n        with torch.no_grad():\n            for images, labels in tqdm.tqdm(data_loader):\n                images = images.to(device)\n                pred_classes = labels['pred_classes'].to(device)\n                pred_boxes = labels['pred_boxes'].to(device).float()\n                # compute output\n                pred_deltas = self.model(images)\n                _, pred_boxes = self.box_transform(pred_deltas, pred_classes, pred_boxes)\n                pred_boxes = clamp(pred_boxes.cpu(), labels['width'], labels['height'])\n                pred_boxes = pred_boxes.numpy().tolist()\n                for p in pred_boxes:\n                    predictions.append(p)\n        return predictions\n\n    def validate_baseline(self, val_loader):\n        \"\"\"call this function if you have labeled data for validation\"\"\"\n        ious = AverageMeter(\"IoU\", \":.4e\")\n        print(\"Calculate baseline IoU:\")\n        for _, labels in tqdm.tqdm(val_loader):\n            gt_boxes = labels['gt_boxes']\n            pred_boxes = labels['pred_boxes']\n            ious.update(iou_between(pred_boxes, gt_boxes).mean().item(), gt_boxes.size(0))\n\n        print(' * Baseline IoU {:.3f}'.format(ious.avg))\n        return ious.avg\n\n    
@staticmethod\n    def validate(val_loader, model, box_transform, args) -> float:\n        \"\"\"call this function if you have labeled data for validation\"\"\"\n        batch_time = AverageMeter('Time', ':6.3f')\n        ious = AverageMeter(\"IoU\", \":.4e\")\n        progress = ProgressMeter(\n            len(val_loader),\n            [batch_time, ious],\n            prefix='Test: ')\n\n        # switch to evaluate mode\n        model.eval()\n\n        with torch.no_grad():\n            end = time.time()\n            for i, (images, labels) in enumerate(val_loader):\n                images = images.to(device)\n                pred_classes = labels['pred_classes'].to(device)\n                gt_boxes = labels['gt_boxes'].to(device).float()\n                pred_boxes = labels['pred_boxes'].to(device).float()\n\n                # compute output\n                pred_deltas = model(images)\n                _, pred_boxes = box_transform(pred_deltas, pred_classes, pred_boxes)\n                pred_boxes = clamp(pred_boxes.cpu(), labels['width'], labels['height'])\n                ious.update(iou_between(pred_boxes, gt_boxes.cpu()).mean().item(), images.size(0))\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                if i % args.print_freq == 0:\n                    progress.display(i)\n\n            print(' * IoU {:.3f}'.format(ious.avg))\n\n        return ious.avg\n\n    def fit(self, data_loader_source, data_loader_target, data_loader_validation=None):\n        \"\"\"When no labels exist on the target domain, set data_loader_validation to None\"\"\"\n        args = self.args\n        print(args)\n        if args.seed is not None:\n            random.seed(args.seed)\n            torch.manual_seed(args.seed)\n            cudnn.deterministic = True\n            warnings.warn('You have chosen to seed training. '\n                          'This will turn on the CUDNN deterministic setting, '\n                          'which can slow down your training considerably! '\n                          'You may see unexpected behavior when restarting '\n                          'from checkpoints.')\n\n        cudnn.benchmark = True\n\n        iter_source = ForeverDataIterator(data_loader_source)\n        iter_target = ForeverDataIterator(data_loader_target)\n\n        best_iou = 0.\n        box_transform = self.box_transform\n\n        # first pre-train on the source domain\n        model = Regressor(\n            self.model.backbone, len(self.class_names) * 4,\n            bottleneck=nn.Sequential(\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            ),\n            head=nn.Linear(self.model.backbone.out_features, len(self.class_names) * 4),\n            bottleneck_dim=self.model.backbone.out_features\n        ).to(device)\n        optimizer = Adam(model.get_parameters(), args.pretrain_lr, weight_decay=args.pretrain_weight_decay)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.pretrain_lr * (1. 
+ args.pretrain_lr_gamma * float(x)) ** (-args.pretrain_lr_decay))\n\n        for epoch in range(args.pretrain_epochs):\n            print(\"lr:\", lr_scheduler.get_last_lr()[0])\n            batch_time = AverageMeter('Time', ':3.1f')\n            data_time = AverageMeter('Data', ':3.1f')\n            losses = AverageMeter('Loss', ':3.2f')\n            ious = AverageMeter(\"IoU\", \":.4e\")\n            progress = ProgressMeter(\n                args.iters_per_epoch,\n                [batch_time, data_time, losses, ious],\n                prefix=\"Epoch: [{}]\".format(epoch))\n\n            # switch to train mode\n            model.train()\n\n            end = time.time()\n            for i in range(args.iters_per_epoch):\n                x_s, labels_s = next(iter_source)\n                x_s = x_s.to(device)\n                # bounding box offsets\n                delta_s = box_transform.box_transform.get_deltas(labels_s['pred_boxes'], labels_s['gt_boxes']).to(device).float()\n                pred_boxes_s = labels_s['pred_boxes'].to(device).float()\n                gt_classes_s = labels_s['gt_fg_classes'].to(device)\n                gt_boxes_s = labels_s['gt_boxes'].to(device).float()\n\n                # measure data loading time\n                data_time.update(time.time() - end)\n\n                # compute output\n                pred_delta_s, _ = model(x_s)\n                pred_delta_s, pred_boxes_s = box_transform(pred_delta_s, gt_classes_s, pred_boxes_s)\n                reg_loss = F.smooth_l1_loss(pred_delta_s, delta_s)\n                loss = reg_loss\n\n                losses.update(loss.item(), x_s.size(0))\n                ious.update(iou_between(pred_boxes_s.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))\n\n                # compute gradient and do SGD step\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n                lr_scheduler.step()\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                if i % args.print_freq == 0:\n                    progress.display(i)\n\n            # evaluate on validation set\n            if data_loader_validation is not None:\n                iou = self.validate(data_loader_validation, model, box_transform, args)\n                best_iou = max(iou, best_iou)\n\n        # training on both domains\n        model = self.model\n        optimizer = SGD(model.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n        for epoch in range(args.epochs):\n            print(\"lr:\", lr_scheduler.get_last_lr()[0])\n            # train for one epoch\n            batch_time = AverageMeter('Time', ':3.1f')\n            data_time = AverageMeter('Data', ':3.1f')\n            losses = AverageMeter('Loss', ':3.2f')\n            ious = AverageMeter(\"IoU\", \":.4e\")\n            ious_t = AverageMeter(\"IoU (t)\", \":.4e\")\n            ious_s_adv = AverageMeter(\"IoU (s, adv)\", \":.4e\")\n            ious_t_adv = AverageMeter(\"IoU (t, adv)\", \":.4e\")\n            trans_losses = AverageMeter('Trans Loss', ':3.2f')\n            progress = ProgressMeter(\n                args.iters_per_epoch,\n                [batch_time, data_time, losses, trans_losses, ious, ious_t, ious_s_adv, ious_t_adv],\n                prefix=\"Epoch: [{}]\".format(epoch))\n            # switch to train mode\n            model.train()\n            mdd = RegressionMarginDisparityDiscrepancy(args.margin).to(device)\n\n            end = time.time()\n            for i in range(args.iters_per_epoch):\n                x_s, labels_s = next(iter_source)\n                x_t, labels_t = next(iter_target)\n                x_s = x_s.to(device)\n                x_t = x_t.to(device)\n                # bounding box offsets\n                delta_s = box_transform.box_transform.get_deltas(labels_s['pred_boxes'], labels_s['gt_boxes']).to(device).float()\n                pred_boxes_s = labels_s['pred_boxes'].to(device).float()\n                gt_classes_s = labels_s['gt_fg_classes'].to(device)\n                gt_boxes_s = labels_s['gt_boxes'].to(device).float()\n                pred_boxes_t = labels_t['pred_boxes'].to(device).float()\n                gt_classes_t = labels_t['pred_classes'].to(device)\n                gt_boxes_t = labels_t['gt_boxes'].to(device).float()\n\n                # measure data loading time\n                data_time.update(time.time() - end)\n\n                # compute output\n                x = torch.cat([x_s, x_t], dim=0)\n                outputs, outputs_adv = model(x)\n                pred_delta_s, pred_delta_t = outputs.chunk(2, dim=0)\n                pred_delta_s_adv, pred_delta_t_adv = outputs_adv.chunk(2, dim=0)\n                pred_delta_s, pred_boxes_s = box_transform(pred_delta_s, gt_classes_s, pred_boxes_s)\n                pred_delta_t, pred_boxes_t = box_transform(pred_delta_t, gt_classes_t, pred_boxes_t)\n                pred_delta_s_adv, pred_boxes_s_adv = box_transform(pred_delta_s_adv, gt_classes_s, pred_boxes_s)\n                pred_delta_t_adv, pred_boxes_t_adv = box_transform(pred_delta_t_adv, gt_classes_t, pred_boxes_t)\n\n                reg_loss = F.smooth_l1_loss(pred_delta_s, delta_s)\n                # compute margin disparity discrepancy between domains\n                transfer_loss = mdd(pred_delta_s, pred_delta_s_adv, pred_delta_t, pred_delta_t_adv)\n                # for adversarial classifier, minimize negative mdd is equal to maximize mdd\n                loss = reg_loss - transfer_loss * args.trade_off\n                model.step()\n\n                losses.update(loss.item(), x_s.size(0))\n                ious.update(iou_between(pred_boxes_s.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))\n                ious_t.update(iou_between(pred_boxes_t.cpu(), gt_boxes_t.cpu()).mean().item(), x_s.size(0))\n                ious_s_adv.update(iou_between(pred_boxes_s_adv.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))\n                
ious_t_adv.update(iou_between(pred_boxes_t_adv.cpu(), gt_boxes_t.cpu()).mean().item(), x_s.size(0))\n                trans_losses.update(transfer_loss.item(), x_s.size(0))\n\n                # compute gradient and do SGD step\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n                lr_scheduler.step()\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                if i % args.print_freq == 0:\n                    progress.display(i)\n\n            # evaluate on validation set\n            if data_loader_validation is not None:\n                iou = self.validate(data_loader_validation, model, box_transform, args)\n                best_iou = max(iou, best_iou)\n\n            # save checkpoint\n            torch.save(model.state_dict(), self.logger.get_checkpoint_path('latest'))\n\n        print(\"best_iou = {:3.1f}\".format(best_iou))\n\n        self.logger.logger.flush()\n\n    @staticmethod\n    def get_parser() -> argparse.ArgumentParser:\n        parser = argparse.ArgumentParser(add_help=False)\n        # dataset parameters\n        parser.add_argument('--resize-size-b', type=int, default=224,\n                            help='the image size after resizing')\n        parser.add_argument('--max-train-b', type=int, default=10)\n        parser.add_argument('--max-val-b', type=int, default=10)\n        parser.add_argument('--expand-b', type=float, default=2.,\n                            help='The expanding ratio between the input of the bounding box adaptor '\n                                 '(the crops of objects) and the original predicted box.')\n        # model parameters\n        parser.add_argument('--arch-b', metavar='ARCH', default='resnet101',\n                            choices=utils.get_model_names(),\n                            help='backbone architecture: ' +\n                                 ' | '.join(utils.get_model_names()) +\n                                 ' (default: resnet101)')\n        parser.add_argument('--bottleneck-dim-b', default=1024, type=int,\n                            help='Dimension of bottleneck')\n        parser.add_argument('--no-pool-b', action='store_true',\n                            help='no pool layer after the feature extractor.')\n        parser.add_argument('--scratch-b', action='store_true', help='whether train from scratch.')\n        parser.add_argument('--margin', type=float, default=4., help=\"margin hyper-parameter\")\n        parser.add_argument('--trade-off', default=0.1, type=float,\n                            help='the trade-off hyper-parameter for transfer loss')\n        # training parameters\n        parser.add_argument('--batch-size-b', default=32, type=int,\n                            metavar='N',\n                            help='mini-batch size (default: 32)')\n        parser.add_argument('--lr-b', default=0.004, type=float,\n                            metavar='LR', help='initial learning rate')\n        parser.add_argument('--lr-gamma-b', default=0.0002, type=float, help='parameter for lr scheduler')\n        parser.add_argument('--lr-decay-b', default=0.75, type=float, help='parameter for lr scheduler')\n        parser.add_argument('--weight-decay-b', default=5e-4, type=float,\n                            metavar='W', help='weight decay (default: 5e-4)')\n        parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')\n        
parser.add_argument('--workers-b', default=4, type=int, metavar='N',\n                            help='number of data loading workers (default: 4)')\n        parser.add_argument('--epochs-b', default=2, type=int, metavar='N',\n                            help='number of total epochs to run')\n        parser.add_argument('--pretrain-lr-b', default=0.001, type=float,\n                            metavar='LR', help='initial learning rate for pretraining')\n        parser.add_argument('--pretrain-lr-gamma-b', default=0.0002, type=float, help='parameter for lr scheduler')\n        parser.add_argument('--pretrain-lr-decay-b', default=0.75, type=float, help='parameter for lr scheduler')\n        parser.add_argument('--pretrain-weight-decay-b', default=1e-3, type=float,\n                            metavar='W', help='weight decay (default: 1e-3)')\n        parser.add_argument('--pretrain-epochs-b', default=10, type=int, metavar='N',\n                            help='number of total pretraining epochs to run')\n        parser.add_argument('--iters-per-epoch-b', default=1000, type=int,\n                            help='Number of iterations per epoch')\n        parser.add_argument('--print-freq-b', default=100, type=int,\n                            metavar='N', help='print frequency (default: 100)')\n        parser.add_argument('--seed-b', default=None, type=int,\n                            help='seed for initializing training.')\n        parser.add_argument(\"--log-b\", type=str, default='box',\n                            help=\"Where to save logs, checkpoints and debugging images.\")\n        return parser\n\n\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/category_adaptation.py",
    "content": "\"\"\"\nTraining a category adaptor\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport os.path as osp\nfrom collections import deque\nimport tqdm\nfrom typing import List\n\nimport torch\nfrom torch import Tensor\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\nimport torch.nn.functional as F\n\nsys.path.append('../../../..')\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier\nfrom tllib.alignment.d_adapt.proposal import ProposalDataset, flatten, Proposal\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.vision.transforms import ResizeImage\n\nimport utils\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\nclass ConfidenceBasedDataSelector:\n    \"\"\"Select data points based on their confidence scores\"\"\"\n    def __init__(self, confidence_ratio=0.1, category_names=()):\n        self.confidence_ratio = confidence_ratio\n        self.categories = []\n        self.scores = []\n        self.category_names = category_names\n        self.per_category_thresholds = None\n\n    def extend(self, categories, scores):\n        self.categories.extend(categories)\n        self.scores.extend(scores)\n\n    def calculate(self):\n        per_category_scores = {c: [] for c in self.category_names}\n        for c, s in zip(self.categories, self.scores):\n            per_category_scores[c].append(s)\n\n        per_category_thresholds = {}\n        print(per_category_scores.keys())\n        for c, s in per_category_scores.items():\n            s.sort(reverse=True)\n            print(c, len(s), int(self.confidence_ratio * len(s)))\n            per_category_thresholds[c] = s[int(self.confidence_ratio * len(s))] if len(s) else 1.\n\n        print('----------------------------------------------------')\n        print(\"confidence threshold for each category:\")\n        for c in self.category_names:\n            print('\\t', c, round(per_category_thresholds[c], 3))\n        print('----------------------------------------------------')\n\n        self.per_category_thresholds = per_category_thresholds\n\n    def whether_select(self, categories, scores):\n        assert self.per_category_thresholds is not None, \"please call calculate() before selection!\"\n        return [s > self.per_category_thresholds[c] for c, s in zip(categories, scores)]\n\n\nclass RobustCrossEntropyLoss(nn.CrossEntropyLoss):\n    \"\"\"Cross-entropy that's robust to label noise\"\"\"\n    def __init__(self, *args, offset=0.1, **kwargs):\n        self.offset = offset\n        super(RobustCrossEntropyLoss, self).__init__(*args, **kwargs)\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        return F.cross_entropy(torch.clamp(input + self.offset, max=1.), target, weight=self.weight,\n                               ignore_index=self.ignore_index, reduction='sum') / input.shape[0]\n\n\nclass CategoryAdaptor:\n    def __init__(self, class_names, log, args):\n        self.class_names = class_names\n        for k, v in args._get_kwargs():\n            
# mirror every '*_c' argument to an un-suffixed attribute (e.g. args.arch_c -> args.arch)\n            setattr(args, k[:-2] if k.endswith(\"_c\") else k, v)\n        self.args = args\n        print(self.args)\n        self.logger = CompleteLogger(log)\n        self.selector = ConfidenceBasedDataSelector(self.args.confidence_ratio, range(len(self.class_names) + 1))\n\n        # create model\n        print(\"=> using model '{}'\".format(args.arch))\n        backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n        pool_layer = nn.Identity() if args.no_pool else None\n        num_classes = len(self.class_names) + 1\n        self.model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,\n                                     pool_layer=pool_layer, finetune=not args.scratch).to(device)\n\n    def load_checkpoint(self):\n        if osp.exists(self.logger.get_checkpoint_path('latest')):\n            checkpoint = torch.load(self.logger.get_checkpoint_path('latest'), map_location='cpu')\n            self.model.load_state_dict(checkpoint)\n            return True\n        else:\n            return False\n\n    def prepare_training_data(self, proposal_list: List[Proposal], labeled=True):\n        if not labeled:\n            # remove proposals with confidence score between (ignored_scores[0], ignored_scores[1])\n            filtered_proposals_list = []\n            assert len(self.args.ignored_scores) == 2 and self.args.ignored_scores[0] <= self.args.ignored_scores[1], \\\n                \"Please provide a range for ignored_scores!\"\n            for proposals in proposal_list:\n                keep_indices = ~((self.args.ignored_scores[0] < proposals.pred_scores)\n                                 & (proposals.pred_scores < self.args.ignored_scores[1]))\n                filtered_proposals_list.append(proposals[keep_indices])\n\n            # calculate confidence threshold for each category on the target domain\n            for proposals in filtered_proposals_list:\n                self.selector.extend(proposals.pred_classes.tolist(), proposals.pred_scores.tolist())\n            self.selector.calculate()\n        else:\n            # remove proposals with ignored classes or ious between (ignored_ious[0], ignored_ious[1])\n            filtered_proposals_list = []\n            for proposals in proposal_list:\n                keep_indices = (proposals.gt_classes != -1) & \\\n                               ~((self.args.ignored_ious[0] < proposals.gt_ious) &\n                                 (proposals.gt_ious < self.args.ignored_ious[1]))\n                filtered_proposals_list.append(proposals[keep_indices])\n\n        filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_train)\n\n        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n        transform = T.Compose([\n            ResizeImage(self.args.resize_size),\n            T.RandomHorizontalFlip(),\n            T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5),\n            T.RandomGrayscale(),\n            T.ToTensor(),\n            normalize\n        ])\n\n        dataset = ProposalDataset(filtered_proposals_list, transform)\n        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,\n                                shuffle=True, num_workers=self.args.workers, drop_last=True)\n        return dataloader\n\n    def prepare_validation_data(self, proposal_list: List[Proposal]):\n        \"\"\"call this function if you have labeled data for validation\"\"\"\n        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n  
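      # validation-time preprocessing is deterministic: no flips or color jitter, unlike the training transform above\n  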
      transform = T.Compose([\n            ResizeImage(self.args.resize_size),\n            T.ToTensor(),\n            normalize\n        ])\n\n        # remove proposals with ignored classes\n        filtered_proposals_list = []\n        for proposals in proposal_list:\n            keep_indices = proposals.gt_classes != -1\n            filtered_proposals_list.append(proposals[keep_indices])\n\n        filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_val)\n        dataset = ProposalDataset(filtered_proposals_list, transform)\n        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,\n                                shuffle=False, num_workers=self.args.workers, drop_last=False)\n        return dataloader\n\n    def prepare_test_data(self, proposal_list: List[Proposal]):\n        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n        transform = T.Compose([\n            ResizeImage(self.args.resize_size),\n            T.ToTensor(),\n            normalize\n        ])\n\n        dataset = ProposalDataset(proposal_list, transform)\n        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,\n                                shuffle=False, num_workers=self.args.workers, drop_last=False)\n        return dataloader\n\n    def fit(self, data_loader_source, data_loader_target, data_loader_validation=None):\n        \"\"\"When no labels exist on the target domain, please set data_loader_validation=None\"\"\"\n        args = self.args\n        if args.seed is not None:\n            random.seed(args.seed)\n            torch.manual_seed(args.seed)\n            cudnn.deterministic = True\n            warnings.warn('You have chosen to seed training. '\n                          'This will turn on the CUDNN deterministic setting, '\n                          'which can slow down your training considerably! '\n                          'You may see unexpected behavior when restarting '\n                          'from checkpoints.')\n\n        cudnn.benchmark = True\n\n        iter_source = ForeverDataIterator(data_loader_source)\n        iter_target = ForeverDataIterator(data_loader_target)\n\n        model = self.model\n        feature_dim = model.features_dim\n        num_classes = len(self.class_names) + 1\n\n        if args.randomized:\n            domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device)\n        else:\n            domain_discri = DomainDiscriminator(feature_dim * num_classes, hidden_size=1024).to(device)\n\n        all_parameters = model.get_parameters() + domain_discri.get_parameters()\n        # define optimizer and lr scheduler\n        optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n        # define loss function\n        domain_adv = ConditionalDomainAdversarialLoss(\n            domain_discri, entropy_conditioning=args.entropy,\n            num_classes=num_classes, features_dim=feature_dim, randomized=args.randomized,\n            randomized_dim=args.randomized_dim\n        ).to(device)\n\n        # start training\n        best_acc1 = 0.\n        for epoch in range(args.epochs):\n            print(\"lr:\", lr_scheduler.get_last_lr()[0])\n            # train for one epoch\n            batch_time = AverageMeter('Time', ':3.1f')\n            data_time = AverageMeter('Data', ':3.1f')\n            losses = AverageMeter('Loss', ':3.2f')\n            losses_t = AverageMeter('Loss(t)', ':3.2f')\n            trans_losses = AverageMeter('Trans Loss', ':3.2f')\n            cls_accs = AverageMeter('Cls Acc', ':3.1f')\n            domain_accs = AverageMeter('Domain Acc', ':3.1f')\n            progress = ProgressMeter(\n                args.iters_per_epoch,\n                [batch_time, data_time, losses, losses_t, trans_losses, cls_accs, domain_accs],\n                prefix=\"Epoch: [{}]\".format(epoch))\n\n            # switch to train mode\n            model.train()\n            domain_adv.train()\n\n            end = time.time()\n            for i in range(args.iters_per_epoch):\n                x_s, labels_s = next(iter_source)\n                x_t, labels_t = next(iter_target)\n\n                # assign pseudo labels for target-domain proposals with extremely high confidence\n                selected = torch.tensor(\n                    self.selector.whether_select(\n                        labels_t['pred_classes'].numpy().tolist(),\n                        labels_t['pred_scores'].numpy().tolist()\n                    )\n                )\n                pseudo_classes_t = selected * labels_t['pred_classes'] + (~selected) * -1\n                pseudo_classes_t = pseudo_classes_t.to(device)\n\n                x_s = x_s.to(device)\n                x_t = x_t.to(device)\n                gt_classes_s = labels_s['gt_classes'].to(device)\n\n                # measure data loading time\n                data_time.update(time.time() - end)\n\n                # compute output\n                x = torch.cat((x_s, x_t), dim=0)\n                y, f = model(x)\n                y_s, y_t = y.chunk(2, dim=0)\n                f_s, f_t = f.chunk(2, dim=0)\n\n                cls_loss = F.cross_entropy(y_s, gt_classes_s, ignore_index=-1)\n                cls_loss_t = RobustCrossEntropyLoss(ignore_index=-1, offset=args.epsilon)(y_t, pseudo_classes_t)\n                transfer_loss = domain_adv(y_s, f_s, y_t, f_t)\n                domain_acc = domain_adv.domain_discriminator_accuracy\n                loss = cls_loss + transfer_loss * args.trade_off + cls_loss_t\n\n                cls_acc = accuracy(y_s, gt_classes_s)[0]\n\n                losses.update(loss.item(), x_s.size(0))\n                cls_accs.update(cls_acc, x_s.size(0))\n                domain_accs.update(domain_acc, x_s.size(0))\n                trans_losses.update(transfer_loss.item(), x_s.size(0))\n                losses_t.update(cls_loss_t.item(), x_s.size(0))\n\n                # compute gradient and do SGD step\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n                lr_scheduler.step()\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = 
time.time()\n\n                if i % args.print_freq == 0:\n                    progress.display(i)\n\n            # evaluate on validation set\n            if data_loader_validation is not None:\n                acc1 = self.validate(data_loader_validation, model, self.class_names, args)\n                best_acc1 = max(acc1, best_acc1)\n\n            # save checkpoint\n            torch.save(model.state_dict(), self.logger.get_checkpoint_path('latest'))\n\n        print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n        domain_adv.to(torch.device(\"cpu\"))\n        self.logger.logger.flush()\n\n    def predict(self, data_loader):\n        # switch to evaluate mode\n        self.model.eval()\n        predictions = deque()\n\n        with torch.no_grad():\n            for images, _ in tqdm.tqdm(data_loader):\n                images = images.to(device)\n\n                # compute output\n                output = self.model(images)\n                prediction = output.argmax(-1).cpu().numpy().tolist()\n                for p in prediction:\n                    predictions.append(p)\n        return predictions\n\n    @staticmethod\n    def validate(val_loader, model, class_names, args) -> float:\n        batch_time = AverageMeter('Time', ':6.3f')\n        losses = AverageMeter('Loss', ':.4e')\n        top1 = AverageMeter('Acc@1', ':6.2f')\n        progress = ProgressMeter(\n            len(val_loader),\n            [batch_time, losses, top1],\n            prefix='Test: ')\n\n        # switch to evaluate mode\n        model.eval()\n        confmat = ConfusionMatrix(len(class_names)+1)\n\n        with torch.no_grad():\n            end = time.time()\n            for i, (images, labels) in enumerate(val_loader):\n                images = images.to(device)\n                gt_classes = labels['gt_classes'].to(device)\n\n                # compute output\n                output = model(images)\n                loss = F.cross_entropy(output, gt_classes)\n\n                # measure accuracy and record loss\n                acc1, = accuracy(output, gt_classes, topk=(1,))\n                confmat.update(gt_classes, output.argmax(1))\n                losses.update(loss.item(), images.size(0))\n                top1.update(acc1.item(), images.size(0))\n\n                # measure elapsed time\n                batch_time.update(time.time() - end)\n                end = time.time()\n\n                if i % args.print_freq == 0:\n                    progress.display(i)\n\n            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))\n            print(confmat.format(class_names+[\"bg\"]))\n\n        return top1.avg\n\n    @staticmethod\n    def get_parser() -> argparse.ArgumentParser:\n        parser = argparse.ArgumentParser(add_help=False)\n        # all flags carry a '-c' suffix to avoid clashes with the detector's and the bbox adaptor's arguments\n        # dataset parameters\n        parser.add_argument('--resize-size-c', type=int, default=112,\n                            help='the image size after resizing')\n        parser.add_argument('--ignored-scores-c', type=float, nargs='+', default=[0.05, 0.3],\n                            help='target proposals whose confidence score falls inside this range are ignored')\n        parser.add_argument('--max-train-c', type=int, default=10)\n        parser.add_argument('--max-val-c', type=int, default=2)\n        parser.add_argument('--ignored-ious-c', type=float, nargs='+', default=(0.4, 0.5),\n                            help='labeled proposals whose IoU with the ground truth falls inside this range are ignored')\n        # model parameters\n        parser.add_argument('--arch-c', metavar='ARCH', default='resnet101',\n                            choices=utils.get_model_names(),\n                            help='backbone 
architecture: ' +\n                                 ' | '.join(utils.get_model_names()) +\n                                 ' (default: resnet101)')\n        parser.add_argument('--bottleneck-dim-c', default=1024, type=int,\n                            help='Dimension of bottleneck')\n        parser.add_argument('--no-pool-c', action='store_true',\n                            help='no pool layer after the feature extractor.')\n        parser.add_argument('--scratch-c', action='store_true', help='whether to train from scratch.')\n        parser.add_argument('--randomized-c', action='store_true',\n                            help='using randomized multi-linear-map (default: False)')\n        parser.add_argument('--randomized-dim-c', default=1024, type=int,\n                            help='randomized dimension when using randomized multi-linear-map (default: 1024)')\n        parser.add_argument('--entropy-c', default=False, action='store_true', help='use entropy conditioning')\n        parser.add_argument('--trade-off-c', default=1., type=float,\n                            help='the trade-off hyper-parameter for transfer loss')\n        parser.add_argument('--confidence-ratio-c', default=0.0, type=float)\n        parser.add_argument('--epsilon-c', default=0.01, type=float,\n                            help='epsilon hyper-parameter in Robust Cross Entropy')\n        # training parameters\n        parser.add_argument('--batch-size-c', default=64, type=int,\n                            metavar='N',\n                            help='mini-batch size (default: 64)')\n        parser.add_argument('--learning-rate-c', default=0.01, type=float,\n                            metavar='LR', help='initial learning rate', dest='lr')\n        parser.add_argument('--lr-gamma-c', default=0.001, type=float, help='parameter for lr scheduler')\n        parser.add_argument('--lr-decay-c', default=0.75, type=float, help='parameter for lr scheduler')\n        parser.add_argument('--momentum-c', default=0.9, type=float, metavar='M', help='momentum')\n        parser.add_argument('--weight-decay-c', default=1e-3, type=float,\n                            metavar='W', help='weight decay (default: 1e-3)',\n                            dest='weight_decay')\n        parser.add_argument('--workers-c', default=2, type=int, metavar='N',\n                            help='number of data loading workers (default: 2)')\n        parser.add_argument('--epochs-c', default=10, type=int, metavar='N',\n                            help='number of total epochs to run')\n        parser.add_argument('--iters-per-epoch-c', default=1000, type=int,\n                            help='Number of iterations per epoch')\n        parser.add_argument('--print-freq-c', default=100, type=int,\n                            metavar='N', help='print frequency (default: 100)')\n        parser.add_argument('--seed-c', default=None, type=int,\n                            help='seed for initializing training.')\n        parser.add_argument(\"--log-c\", type=str, default='cdan',\n                            help=\"Where to save logs, checkpoints and debugging images.\")\n        return parser\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_R_101_C4_cityscapes.yaml",
    "content": "MODEL:\n  META_ARCHITECTURE: \"DecoupledGeneralizedRCNN\"\n  WEIGHTS: \"detectron2://ImageNetPretrained/MSRA/R-101.pkl\"\n  MASK_ON: False\n  RESNETS:\n    DEPTH: 101\n  ROI_HEADS:\n    NAME: \"DecoupledRes5ROIHeads\"\n    NUM_CLASSES: 8\n    BATCH_SIZE_PER_IMAGE: 512\n  ANCHOR_GENERATOR:\n    SIZES: [ [ 64, 128, 256, 512 ] ]\n  RPN:\n    PRE_NMS_TOPK_TEST: 6000\n    POST_NMS_TOPK_TEST: 1000\n    BATCH_SIZE_PER_IMAGE: 256\n  PROPOSAL_GENERATOR:\n    NAME: \"TLRPN\"\nINPUT:\n  MIN_SIZE_TRAIN: (512, 544, 576, 608, 640, 672, 704,)\n  MIN_SIZE_TEST: 800\n  MAX_SIZE_TRAIN: 1166\nDATASETS:\n  TRAIN: (\"cityscapes_trainval\",)\n  TEST: (\"cityscapes_test\",)\nSOLVER:\n  STEPS: (3999, )\n  MAX_ITER: 4000  # 4 epochs\n  WARMUP_ITERS: 100\n  CHECKPOINT_PERIOD: 1000\n  IMS_PER_BATCH: 2\n  BASE_LR: 0.005\n  LR_SCHEDULER_NAME: \"ExponentialLR\"\n  GAMMA: 0.1\nTEST:\n  EVAL_PERIOD: 500\nVIS_PERIOD: 20\nVERSION: 2\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_R_101_C4_voc.yaml",
    "content": "MODEL:\n  META_ARCHITECTURE: \"DecoupledGeneralizedRCNN\"\n  WEIGHTS: \"detectron2://ImageNetPretrained/MSRA/R-101.pkl\"\n  MASK_ON: False\n  RESNETS:\n    DEPTH: 101\n  ROI_HEADS:\n    NAME: \"DecoupledRes5ROIHeads\"\n    NUM_CLASSES: 20\n    BATCH_SIZE_PER_IMAGE: 256\n  ANCHOR_GENERATOR:\n    SIZES: [ [ 64, 128, 256, 512 ] ]\n  RPN:\n    PRE_NMS_TOPK_TEST: 6000\n    POST_NMS_TOPK_TEST: 1000\n    BATCH_SIZE_PER_IMAGE: 128\n  PROPOSAL_GENERATOR:\n    NAME: \"TLRPN\"\nINPUT:\n  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704,)\n  MIN_SIZE_TEST: 608\n  MAX_SIZE_TRAIN: 1166\nDATASETS:\n  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')\n  TEST: ('voc_2007_test',)\nSOLVER:\n  STEPS: (3999, )\n  MAX_ITER: 4000  # 16 epochs\n  WARMUP_ITERS: 100\n  CHECKPOINT_PERIOD: 1000\n  IMS_PER_BATCH: 4\n  BASE_LR: 0.00025\n  LR_SCHEDULER_NAME: \"ExponentialLR\"\n  GAMMA: 0.1\nTEST:\n  EVAL_PERIOD: 500\nVIS_PERIOD: 20\nVERSION: 2"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_vgg_16_cityscapes.yaml",
    "content": "MODEL:\r\n  META_ARCHITECTURE: \"DecoupledGeneralizedRCNN\"\r\n  WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth'\r\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\r\n  PIXEL_STD: [58.395, 57.120, 57.375]\r\n  MASK_ON: False\r\n  BACKBONE:\r\n    NAME: \"build_vgg_fpn_backbone\"\r\n  ROI_HEADS:\r\n    IN_FEATURES: [\"p3\", \"p4\", \"p5\", \"p6\"]\r\n    NAME: \"DecoupledStandardROIHeads\"\r\n    NUM_CLASSES: 8\r\n  ROI_BOX_HEAD:\r\n    NAME: \"FastRCNNConvFCHead\"\r\n    NUM_FC: 2\r\n    POOLER_RESOLUTION: 7\r\n  ANCHOR_GENERATOR:\r\n    SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ]  # One size for each in feature map\r\n    ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ]  # Three aspect\r\n  RPN:\r\n    IN_FEATURES: [\"p3\", \"p4\", \"p5\", \"p6\", \"p7\"]\r\n    PRE_NMS_TOPK_TRAIN: 2000  # Per FPN level\r\n    PRE_NMS_TOPK_TEST: 1000  # Per FPN level\r\n    POST_NMS_TOPK_TRAIN: 1000\r\n    POST_NMS_TOPK_TEST: 1000\r\n  PROPOSAL_GENERATOR:\r\n    NAME: \"TLRPN\"\r\nINPUT:\r\n  FORMAT: \"RGB\"\r\n  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)\r\n  MIN_SIZE_TEST: 800\r\n  MAX_SIZE_TEST: 1280\r\n  MAX_SIZE_TRAIN: 1280\r\nDATASETS:\r\n  TRAIN: (\"cityscapes_trainval\",)\r\n  TEST: (\"cityscapes_test\",)\r\nSOLVER:\r\n  STEPS: (3999, )\r\n  MAX_ITER: 4000  # 4 epochs\r\n  WARMUP_ITERS: 100\r\n  CHECKPOINT_PERIOD: 1000\r\n  IMS_PER_BATCH: 8\r\n  BASE_LR: 0.01\r\n  LR_SCHEDULER_NAME: \"ExponentialLR\"\r\n  GAMMA: 0.1\r\nTEST:\r\n  EVAL_PERIOD: 500\r\nVIS_PERIOD: 20\r\nVERSION: 2"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/config/retinanet_R_101_FPN_voc.yaml",
    "content": "MODEL:\n  META_ARCHITECTURE: \"DecoupledRetinaNet\"\n  WEIGHTS: \"detectron2://ImageNetPretrained/MSRA/R-101.pkl\"\n  BACKBONE:\n    NAME: \"build_retinanet_resnet_fpn_backbone\"\n  MASK_ON: False\n  RESNETS:\n    DEPTH: 101\n    OUT_FEATURES: [ \"res4\", \"res5\" ]\n  ANCHOR_GENERATOR:\n    SIZES: !!python/object/apply:eval [ \"[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [64, 128, 256, 512 ]]\" ]\n  RETINANET:\n    NUM_CLASSES: 20\n    IN_FEATURES: [\"p4\", \"p5\", \"p6\", \"p7\"]\n  FPN:\n    IN_FEATURES: [\"res4\", \"res5\"]\nINPUT:\n  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, )\n  MIN_SIZE_TEST: 608\n  MAX_SIZE_TRAIN: 1166\nDATASETS:\n  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')\n  TEST: ('voc_2007_test',)\nSOLVER:\n  STEPS: (3999, )\n  MAX_ITER: 4000  # 16 epochs\n  WARMUP_ITERS: 100\n  CHECKPOINT_PERIOD: 1000\n  IMS_PER_BATCH: 8\n  BASE_LR: 0.001\nTEST:\n  EVAL_PERIOD: 500\nVIS_PERIOD: 20\nVERSION: 2"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/d_adapt.py",
    "content": "\"\"\"\n`D-adapt: Decoupled Adaptation for Cross-Domain Object Detection <https://openreview.net/pdf?id=VNqaB1g9393>`_.\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport logging\nimport os\nimport argparse\nimport sys\nimport pprint\nimport numpy as np\n\nimport torch\nfrom torch.nn.parallel import DistributedDataParallel\nfrom detectron2.engine import default_writers, launch\nfrom detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer\nimport detectron2.utils.comm as comm\nfrom detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping\nfrom detectron2.data import (\n    build_detection_train_loader,\n    build_detection_test_loader,\n    MetadataCatalog\n)\nfrom detectron2.utils.events import EventStorage\nfrom detectron2.evaluation import inference_on_dataset\n\nsys.path.append('../../../..')\nimport tllib.alignment.d_adapt.modeling.meta_arch as models\nfrom tllib.alignment.d_adapt.proposal import ProposalGenerator, ProposalMapper, PersistentProposalList, flatten\nfrom tllib.alignment.d_adapt.feedback import get_detection_dataset_dicts, DatasetMapper\n\nsys.path.append('..')\nimport utils\n\nimport category_adaptation\nimport bbox_adaptation\n\n\ndef generate_proposals(model, num_classes, dataset_names, cache_root, cfg):\n    \"\"\"Generate foreground proposals and background proposals from `model` and save them to the disk\"\"\"\n    fg_proposals_list = PersistentProposalList(os.path.join(cache_root, \"{}_fg.json\".format(dataset_names[0])))\n    bg_proposals_list = PersistentProposalList(os.path.join(cache_root, \"{}_bg.json\".format(dataset_names[0])))\n    if not (fg_proposals_list.load() and bg_proposals_list.load()):\n        for dataset_name in dataset_names:\n            data_loader = build_detection_test_loader(cfg, dataset_name, mapper=ProposalMapper(cfg, False))\n            generator = ProposalGenerator(num_classes=num_classes)\n            fg_proposals_list_data, bg_proposals_list_data = inference_on_dataset(model, data_loader, generator)\n            fg_proposals_list.extend(fg_proposals_list_data)\n            bg_proposals_list.extend(bg_proposals_list_data)\n        fg_proposals_list.flush()\n        bg_proposals_list.flush()\n    return fg_proposals_list, bg_proposals_list\n\n\ndef generate_category_labels(prop, category_adaptor, cache_filename):\n    \"\"\"Generate category labels for each proposal in `prop` and save them to the disk\"\"\"\n    prop_w_category = PersistentProposalList(cache_filename)\n    if not prop_w_category.load():\n        for p in prop:\n            prop_w_category.append(p)\n\n        data_loader_test = category_adaptor.prepare_test_data(flatten(prop_w_category))\n        predictions = category_adaptor.predict(data_loader_test)\n        for p in prop_w_category:\n            p.pred_classes = np.array([predictions.popleft() for _ in range(len(p))])\n        prop_w_category.flush()\n    return prop_w_category\n\n\ndef generate_bounding_box_labels(prop, bbox_adaptor, class_names, cache_filename):\n    \"\"\"Generate bounding box labels for each proposal in `prop` and save them to the disk\"\"\"\n    prop_w_bbox = PersistentProposalList(cache_filename)\n    if not prop_w_bbox.load():\n        # remove (predicted) background proposals\n        for p in prop:\n            keep_indices = (0 <= p.pred_classes) & (p.pred_classes < len(class_names))\n            prop_w_bbox.append(p[keep_indices])\n\n        data_loader_test = 
bbox_adaptor.prepare_test_data(flatten(prop_w_bbox))\n        predictions = bbox_adaptor.predict(data_loader_test)\n        for p in prop_w_bbox:\n            p.pred_boxes = np.array([predictions.popleft() for _ in range(len(p))])\n        prop_w_bbox.flush()\n    return prop_w_bbox\n\n\ndef train(model, logger, cfg, args, args_cls, args_box):\n    model.train()\n    distributed = comm.get_world_size() > 1\n    if distributed:\n        model_without_parallel = model.module\n    else:\n        model_without_parallel = model\n\n    # define optimizer and lr scheduler\n    params = []\n    for module, lr in model_without_parallel.get_parameters(cfg.SOLVER.BASE_LR):\n        params.extend(\n            get_default_optimizer_params(\n                module,\n                base_lr=lr,\n                weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,\n                bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,\n                weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,\n            )\n        )\n    optimizer = maybe_add_gradient_clipping(cfg, torch.optim.SGD)(\n        params,\n        lr=cfg.SOLVER.BASE_LR,\n        momentum=cfg.SOLVER.MOMENTUM,\n        nesterov=cfg.SOLVER.NESTEROV,\n        weight_decay=cfg.SOLVER.WEIGHT_DECAY,\n    )\n    scheduler = utils.build_lr_scheduler(cfg, optimizer)\n\n    # resume from the last checkpoint\n    checkpointer = DetectionCheckpointer(\n        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler\n    )\n    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)\n    start_iter = 0\n    max_iter = cfg.SOLVER.MAX_ITER\n\n    periodic_checkpointer = PeriodicCheckpointer(\n        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter\n    )\n\n    writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []\n\n    # generate proposals from detector\n    classes = MetadataCatalog.get(args.targets[0]).thing_classes\n    cache_proposal_root = os.path.join(cfg.OUTPUT_DIR, \"cache\", \"proposal\")\n    prop_t_fg, prop_t_bg = generate_proposals(model, len(classes), args.targets, cache_proposal_root, cfg)\n    prop_s_fg, prop_s_bg = generate_proposals(model, len(classes), args.sources, cache_proposal_root, cfg)\n    model = model.to(torch.device('cpu'))\n\n    # train the category adaptor\n    category_adaptor = category_adaptation.CategoryAdaptor(classes, os.path.join(cfg.OUTPUT_DIR, \"cls\"), args_cls)\n    if not category_adaptor.load_checkpoint():\n        data_loader_source = category_adaptor.prepare_training_data(prop_s_fg + prop_s_bg, True)\n        data_loader_target = category_adaptor.prepare_training_data(prop_t_fg + prop_t_bg, False)\n        data_loader_validation = category_adaptor.prepare_validation_data(prop_t_fg + prop_t_bg)\n        category_adaptor.fit(data_loader_source, data_loader_target, data_loader_validation)\n\n    # generate category labels for each proposal\n    cache_feedback_root = os.path.join(cfg.OUTPUT_DIR, \"cache\", \"feedback\")\n    prop_t_fg = generate_category_labels(\n        prop_t_fg, category_adaptor, os.path.join(cache_feedback_root, \"{}_fg.json\".format(args.targets[0]))\n    )\n    prop_t_bg = generate_category_labels(\n        prop_t_bg, category_adaptor, os.path.join(cache_feedback_root, \"{}_bg.json\".format(args.targets[0]))\n    )\n    category_adaptor.model.to(torch.device(\"cpu\"))\n\n    if args.bbox_refine:\n        # train the bbox adaptor\n        
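# it fits on labeled source proposals and unlabeled target proposals; refined target boxes are generated below\n        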
bbox_adaptor = bbox_adaptation.BoundingBoxAdaptor(classes, os.path.join(cfg.OUTPUT_DIR, \"bbox\"), args_box)\n        if not bbox_adaptor.load_checkpoint():\n            data_loader_source = bbox_adaptor.prepare_training_data(prop_s_fg, True)\n            data_loader_target = bbox_adaptor.prepare_training_data(prop_t_fg, False)\n            data_loader_validation = bbox_adaptor.prepare_validation_data(prop_t_fg)\n            bbox_adaptor.validate_baseline(data_loader_validation)\n            bbox_adaptor.fit(data_loader_source, data_loader_target, data_loader_validation)\n\n        # generate bounding box labels for each proposal\n        cache_feedback_root = os.path.join(cfg.OUTPUT_DIR, \"cache\", \"feedback_bbox\")\n        prop_t_fg_refined = generate_bounding_box_labels(\n            prop_t_fg, bbox_adaptor, classes,\n            os.path.join(cache_feedback_root, \"{}_fg.json\".format(args.targets[0]))\n        )\n        prop_t_bg_refined = generate_bounding_box_labels(\n            prop_t_bg, bbox_adaptor, classes,\n            os.path.join(cache_feedback_root, \"{}_bg.json\".format(args.targets[0]))\n        )\n        prop_t_fg += prop_t_fg_refined\n        prop_t_bg += prop_t_bg_refined\n        bbox_adaptor.model.to(torch.device(\"cpu\"))\n\n    if args.reduce_proposals:\n        # remove low-quality proposals: keep only bg proposals also classified as background, and cap fg proposals at 20 per image\n        prop_t_bg_new = []\n        for p in prop_t_bg:\n            keep_indices = p.pred_classes == len(classes)\n            prop_t_bg_new.append(p[keep_indices])\n        prop_t_bg = prop_t_bg_new\n\n        prop_t_fg_new = []\n        for p in prop_t_fg:\n            prop_t_fg_new.append(p[:20])\n        prop_t_fg = prop_t_fg_new\n\n    model = model.to(torch.device(cfg.MODEL.DEVICE))\n    # Data loading code\n    train_source_dataset = get_detection_dataset_dicts(args.sources)\n    train_source_loader = build_detection_train_loader(dataset=train_source_dataset, cfg=cfg)\n    train_target_dataset = get_detection_dataset_dicts(args.targets, proposals_list=prop_t_fg+prop_t_bg)\n\n    mapper = DatasetMapper(cfg, precomputed_proposal_topk=1000, augmentations=utils.build_augmentation(cfg, True))\n    train_target_loader = build_detection_train_loader(dataset=train_target_dataset, cfg=cfg, mapper=mapper,\n                                                       total_batch_size=cfg.SOLVER.IMS_PER_BATCH)\n\n    # train the object detector\n    logger.info(\"Starting training from iteration {}\".format(start_iter))\n    with EventStorage(start_iter) as storage:\n        for data_s, data_t, iteration in zip(train_source_loader, train_target_loader, range(start_iter, max_iter)):\n            storage.iter = iteration\n            optimizer.zero_grad()\n\n            # compute losses and gradient on source domain\n            loss_dict_s = model(data_s)\n            losses_s = sum(loss_dict_s.values())\n            assert torch.isfinite(losses_s).all(), loss_dict_s\n\n            loss_dict_reduced_s = {\"{}_s\".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_s).items()}\n            losses_reduced_s = sum(loss for loss in loss_dict_reduced_s.values())\n            losses_s.backward()\n\n            # compute losses and gradient on target domain\n            loss_dict_t = model(data_t, labeled=False)\n            losses_t = sum(loss_dict_t.values())\n            assert torch.isfinite(losses_t).all()\n\n            loss_dict_reduced_t = {\"{}_t\".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_t).items()}\n            (losses_t * args.trade_off).backward()\n\n            if comm.is_main_process():\n                
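# log the per-loss scalars, reduced across workers, from both domains\n                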
storage.put_scalars(total_loss_s=losses_reduced_s, **loss_dict_reduced_s, **loss_dict_reduced_t)\n\n            # do SGD step\n            optimizer.step()\n            storage.put_scalar(\"lr\", optimizer.param_groups[0][\"lr\"], smoothing_hint=False)\n            scheduler.step()\n\n            # evaluate on validation set\n            if (\n                    cfg.TEST.EVAL_PERIOD > 0\n                    and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0\n                    and iteration != max_iter - 1\n            ):\n                utils.validate(model, logger, cfg, args)\n                comm.synchronize()\n\n            if iteration - start_iter > 5 and (\n                    (iteration + 1) % 20 == 0 or iteration == max_iter - 1\n            ):\n                for writer in writers:\n                    writer.write()\n            periodic_checkpointer.step(iteration)\n\n\ndef main(args, args_cls, args_box):\n    logger = logging.getLogger(\"detectron2\")\n    cfg = utils.setup(args)\n\n    # dataset\n    args.sources = utils.build_dataset(args.sources[::2], args.sources[1::2])\n    args.targets = utils.build_dataset(args.targets[::2], args.targets[1::2])\n    args.test = utils.build_dataset(args.test[::2], args.test[1::2])\n\n    # create model\n    model = models.__dict__[cfg.MODEL.META_ARCHITECTURE](cfg, finetune=args.finetune)\n    model.to(torch.device(cfg.MODEL.DEVICE))\n    logger.info(\"Model:\\n{}\".format(model))\n\n    if args.eval_only:\n        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(\n            cfg.MODEL.WEIGHTS, resume=args.resume\n        )\n        return utils.validate(model, logger, cfg, args)\n\n    distributed = comm.get_world_size() > 1\n    if distributed:\n        model = DistributedDataParallel(\n            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False\n        )\n\n    train(model, logger, cfg, args, args_cls, args_box)\n\n    # evaluate on validation set\n    return utils.validate(model, logger, cfg, args)\n\n\nif __name__ == \"__main__\":\n    args_cls, argv = category_adaptation.CategoryAdaptor.get_parser().parse_known_args()\n    print(\"Category Adaptation Args:\")\n    pprint.pprint(args_cls)\n\n    args_box, argv = bbox_adaptation.BoundingBoxAdaptor.get_parser().parse_known_args(args=argv)\n    print(\"Bounding Box Adaptation Args:\")\n    pprint.pprint(args_box)\n\n    parser = argparse.ArgumentParser(add_help=True)\n    # dataset parameters\n    parser.add_argument('-s', '--sources', nargs='+', help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', help='target domain(s)')\n    parser.add_argument('--test', nargs='+', help='test domain(s)')\n    # model parameters\n    parser.add_argument('--finetune', action='store_true',\n                        help='whether to use a 10x smaller learning rate for the backbone')\n    parser.add_argument(\n        \"--resume\",\n        action=\"store_true\",\n        help=\"Whether to attempt to resume from the checkpoint directory. 
\"\n             \"See documentation of `DefaultTrainer.resume_or_load()` for what it means.\",\n    )\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='trade-off hyper-parameter for losses on target domain')\n    parser.add_argument('--bbox-refine', action='store_true',\n                        help='whether to perform bounding box refinement')\n    parser.add_argument('--reduce-proposals', action='store_true',\n                        help='whether to remove some low-quality proposals. '\n                             'Helpful for RetinaNet')\n    # training parameters\n    parser.add_argument(\"--config-file\", default=\"\", metavar=\"FILE\", help=\"path to config file\")\n    parser.add_argument(\"--eval-only\", action=\"store_true\", help=\"perform evaluation only\")\n    parser.add_argument(\"--num-gpus\", type=int, default=1, help=\"number of gpus *per machine*\")\n    parser.add_argument(\"--num-machines\", type=int, default=1, help=\"total number of machines\")\n    parser.add_argument(\"--machine-rank\", type=int, default=0,\n                        help=\"the rank of this machine (unique per machine)\")\n    # PyTorch still may leave orphan processes in multi-gpu training.\n    # Therefore we use a deterministic way to obtain port,\n    # so that users are aware of orphan processes by seeing the port occupied.\n    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != \"win32\" else 1) % 2 ** 14\n    parser.add_argument(\n        \"--dist-url\",\n        default=\"tcp://127.0.0.1:{}\".format(port),\n        help=\"initialization URL for pytorch distributed backend. See \"\n             \"https://pytorch.org/docs/stable/distributed.html for details.\",\n    )\n    parser.add_argument(\n        \"opts\",\n        help=\"Modify config options by adding 'KEY VALUE' pairs at the end of the command. \"\n             \"See config references at \"\n             \"https://detectron2.readthedocs.io/modules/config.html#config-references\",\n        default=None,\n        nargs=argparse.REMAINDER,\n    )\n    args, argv = parser.parse_known_args(argv)\n    print(\"Detection Args:\")\n    pprint.pprint(args)\n\n    launch(\n        main,\n        args.num_gpus,\n        num_machines=args.num_machines,\n        machine_rank=args.machine_rank,\n        dist_url=args.dist_url,\n        args=(args, args_cls, args_box),\n    )\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/d_adapt/d_adapt.sh",
    "content": "# ResNet101 Based Faster RCNN: VOC->Clipart\n# 44.8\npretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012  \\\n  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \\\n  --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# 47.9\npretrained_models=logs/faster_rcnn_R_101_C4/voc2clipart/phase1/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012  \\\n  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \\\n  --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# 49.0\npretrained_models=logs/faster_rcnn_R_101_C4/voc2clipart/phase2/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.2 \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012  \\\n  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \\\n  --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase3 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# ResNet101 Based Faster RCNN: VOC->WaterColor\n# 54.1\npretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012  \\\n  -t WaterColor ../datasets/watercolor --test WaterColorTest ../datasets/watercolor --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2watercolor/phase1 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# 57.5\npretrained_models=logs/faster_rcnn_R_101_C4/voc2watercolor/phase1/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012  \\\n  -t WaterColor ../datasets/watercolor --test WaterColorTest ../datasets/watercolor --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2watercolor/phase2 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# ResNet101 Based Faster RCNN: VOC->Comic\n# 39.7\npretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012  \\\n  -t Comic ../datasets/comic --test ComicTest ../datasets/comic --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2comic/phase1 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# 41.0\npretrained_models=logs/faster_rcnn_R_101_C4/voc2comic/phase1/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012  \\\n  -t Comic ../datasets/comic 
--test ComicTest ../datasets/comic --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2comic/phase2 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# ResNet101 Based Faster RCNN: Cityscapes -> Foggy Cityscapes\n# 40.1\npretrained_models=../logs/source_only/faster_rcnn_R_101_C4/cityscapes2foggy/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/  \\\n  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# 42.4\npretrained_models=logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase1/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.1 \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/  \\\n  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n\n# VGG Based Faster RCNN: Cityscapes -> Foggy Cityscapes\n# 33.3\npretrained_models=../logs/source_only/faster_rcnn_vgg_16/cityscapes2foggy/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/  \\\n  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# 37.0\npretrained_models=logs/faster_rcnn_vgg_16/cityscapes2foggy/phase1/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.1 \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/  \\\n  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n#  38.9\npretrained_models=logs/faster_rcnn_vgg_16/cityscapes2foggy/phase2/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.2 \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/  \\\n  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase3 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car\n# 51.9\npretrained_models=../logs/source_only/faster_rcnn_R_101_C4/sim10k2cityscapes_car/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 8 
--ignored-scores-c 0.05 0.5 --bottleneck-dim-c 256 --bottleneck-dim-b 256 \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s Sim10kCar ../datasets/sim10k -t CityscapesCar ../datasets/cityscapes_in_voc/  \\\n  --test CityscapesCarTest ../datasets/cityscapes_in_voc/ --finetune --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_R_101_C4/sim10k2cityscapes_car/phase1 MODEL.ROI_HEADS.NUM_CLASSES 1 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# VGG Based Faster RCNN: Sim10k -> Cityscapes Car\n# 49.3\npretrained_models=../logs/source_only/faster_rcnn_vgg_16/sim10k2cityscapes_car/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 8 --ignored-scores-c 0.05 0.5 --bottleneck-dim-c 256 --bottleneck-dim-b 256 \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Sim10kCar ../datasets/sim10k -t CityscapesCar ../datasets/cityscapes_in_voc/  \\\n  --test CityscapesCarTest ../datasets/cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \\\n  OUTPUT_DIR logs/faster_rcnn_vgg_16/sim10k2cityscapes_car/phase1 MODEL.ROI_HEADS.NUM_CLASSES 1 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# RetinaNet: VOC->Clipart\n# 44.7\npretrained_models=../logs/source_only/retinanet_R_101_FPN/voc2clipart/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --reduce-proposals \\\n  --config-file config/retinanet_R_101_FPN_voc.yaml \\\n  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012  \\\n  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \\\n  --finetune --bbox-refine \\\n  OUTPUT_DIR logs/retinanet_R_101_FPN/voc2clipart/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0\n\n# 46.3\npretrained_models=logs/retinanet_R_101_FPN/voc2clipart/phase1/model_final.pth\nCUDA_VISIBLE_DEVICES=0 python d_adapt.py --reduce-proposals --confidence-ratio-c 0.1 \\\n  --config-file config/retinanet_R_101_FPN_voc.yaml \\\n  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012  \\\n  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \\\n  --finetune --bbox-refine \\\n  OUTPUT_DIR logs/retinanet_R_101_FPN/voc2clipart/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/oracle.sh",
    "content": "# Faster RCNN: WaterColor\r\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\r\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\r\n  -s WaterColor datasets/watercolor -t WaterColor datasets/watercolor \\\r\n  --test WaterColorTest datasets/watercolor  --finetune \\\r\n  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/watercolor MODEL.ROI_HEADS.NUM_CLASSES 6\r\n\r\n# Faster RCNN: Comic\r\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\r\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\r\n  -s Comic datasets/comic -t Comic datasets/comic \\\r\n  --test ComicTest datasets/comic  --finetune \\\r\n  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/comic MODEL.ROI_HEADS.NUM_CLASSES 6\r\n\r\n# ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes\r\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\r\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\r\n  -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \\\r\n  --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \\\r\n  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes2foggy\r\n\r\n# VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes\r\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\r\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\r\n  -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \\\r\n  --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \\\r\n  OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes2foggy\r\n\r\n# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car\r\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\r\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\r\n  -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/  \\\r\n  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \\\r\n  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1\r\n\r\n# VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car\r\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\r\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\r\n -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/  \\\r\n  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \\\r\n  OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1\r\n\r\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/prepare_cityscapes_to_voc.py",
    "content": "from pascal_voc_writer import Writer\r\nimport matplotlib.pyplot as plt\r\nimport numpy as np\r\nimport os\r\nimport json\r\nimport glob\r\nimport time\r\nfrom shutil import move, copy\r\nimport tqdm\r\n\r\nclasses = {'bicycle': 'bicycle', 'bus': 'bus', 'car': 'car', 'motorcycle': 'motorcycle',\r\n           'person': 'person', 'rider': 'rider', 'train': 'train', 'truck': 'truck'}\r\nclasses_keys = list(classes.keys())\r\n\r\n\r\ndef make_dir(path):\r\n    if not os.path.isdir(path):\r\n        os.makedirs(path)\r\n\r\n#----------------------------------------------------------------------------------------------------------------\r\n#convert polygon to bounding box\r\n#code from:\r\n#https://stackoverflow.com/questions/46335488/how-to-efficiently-find-the-bounding-box-of-a-collection-of-points\r\n#----------------------------------------------------------------------------------------------------------------\r\ndef polygon_to_bbox(polygon):\r\n    x_coordinates, y_coordinates = zip(*polygon)\r\n    return [min(x_coordinates), min(y_coordinates), max(x_coordinates), max(y_coordinates)]\r\n\r\n\r\n# --------------------------------------------\r\n# read a json file and convert to voc format\r\n# --------------------------------------------\r\ndef read_json(file):\r\n    # if no relevant objects found in the image,\r\n    # don't save the xml for the image\r\n    relevant_file = False\r\n\r\n    data = []\r\n    with open(file, 'r') as f:\r\n        file_data = json.load(f)\r\n\r\n        for object in file_data['objects']:\r\n            label, polygon = object['label'], object['polygon']\r\n\r\n            # process only if label found in voc\r\n            if label in classes_keys:\r\n                polygon = np.array([x for x in polygon])\r\n                bbox = polygon_to_bbox(polygon)\r\n                data.append([classes[label]] + bbox)\r\n\r\n        # if relevant objects found in image, set the flag to True\r\n        if data:\r\n            relevant_file = True\r\n\r\n    return data, relevant_file\r\n\r\n\r\n#---------------------------\r\n#function to save xml file\r\n#---------------------------\r\ndef save_xml(img_path, img_shape, data, save_path):\r\n    writer = Writer(img_path,img_shape[1], img_shape[0])\r\n    for element in data:\r\n        writer.addObject(element[0],element[1],element[2],element[3],element[4])\r\n    writer.save(save_path)\r\n\r\n\r\ndef prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir):\r\n    cityscapes_dir_gt = os.path.join(cityscapes_dir, 'gtFine')\r\n\r\n    # ------------------------------------------\r\n    # reading json files from each subdirectory\r\n    # ------------------------------------------\r\n    valid_files = []\r\n    trainval_files = []\r\n    test_files = []\r\n\r\n    # make Annotations target directory if already doesn't exist\r\n    ann_dir = os.path.join(save_path, 'Annotations')\r\n    make_dir(ann_dir)\r\n\r\n    start = time.time()\r\n    for category in os.listdir(cityscapes_dir_gt):\r\n\r\n        # # no GT for test data\r\n        # if category == 'test':\r\n        #     continue\r\n\r\n        for city in tqdm.tqdm(os.listdir(os.path.join(cityscapes_dir_gt, category))):\r\n\r\n            # read files\r\n            files = glob.glob(os.path.join(cityscapes_dir, 'gtFine', category, city) + '/*.json')\r\n\r\n            # process json files\r\n            for file in files:\r\n                data, relevant_file = read_json(file)\r\n\r\n                if relevant_file:\r\n      
              base_filename = os.path.basename(file)[:-21]\r\n                    xml_filepath = os.path.join(ann_dir, base_filename + '{}.xml'.format(suffix))\r\n                    img_name = base_filename + '{}.png'.format(suffix)\r\n                    img_path = os.path.join(cityscapes_dir, image_dir, category, city,\r\n                                            base_filename + '{}.png'.format(suffix))\r\n                    img_shape = plt.imread(img_path).shape\r\n                    valid_files.append([img_path, img_name])\r\n\r\n                    # make list of trainval and test files for voc format\r\n                    # lists will be stored in txt files\r\n                    trainval_files.append(img_name[:-4]) if category == 'train' else test_files.append(img_name[:-4])\r\n\r\n                    # save xml file\r\n                    save_xml(os.path.join(image_dir, category, city,\r\n                                            base_filename + '{}.png'.format(suffix)), img_shape, data, xml_filepath)\r\n\r\n    end = time.time() - start\r\n    print('Total Time taken: ', end)\r\n\r\n    # ----------------------------\r\n    # copy files into target path\r\n    # ----------------------------\r\n    images_savepath = os.path.join(save_path, 'JPEGImages')\r\n    make_dir(images_savepath)\r\n\r\n    start = time.time()\r\n    for file in valid_files:\r\n        copy(file[0], os.path.join(images_savepath, file[1]))\r\n\r\n    print('Total Time taken: ', end)\r\n\r\n    # ---------------------------------------------\r\n    # create text files of trainval and test files\r\n    # ---------------------------------------------\r\n    textfiles_savepath = os.path.join(save_path, 'ImageSets', 'Main')\r\n    make_dir(textfiles_savepath)\r\n\r\n    traival_files_wr = [x + '\\n' for x in trainval_files]\r\n    test_files_wr = [x + '\\n' for x in test_files]\r\n\r\n    with open(os.path.join(textfiles_savepath, 'trainval.txt'), 'w') as f:\r\n        f.writelines(traival_files_wr)\r\n\r\n    with open(os.path.join(textfiles_savepath, 'test.txt'), 'w') as f:\r\n        f.writelines(test_files_wr)\r\n\r\n\r\nif __name__ == '__main__':\r\n    cityscapes_dir = 'datasets/cityscapes/'\r\n    if not os.path.exists(cityscapes_dir):\r\n        print(\"Please put cityscapes datasets in: {}\".format(cityscapes_dir))\r\n        exit(0)\r\n    save_path = 'datasets/cityscapes_in_voc/'\r\n    suffix = \"_leftImg8bit\"\r\n    image_dir = \"leftImg8bit\"\r\n    prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir)\r\n\r\n    save_path = 'datasets/foggy_cityscapes_in_voc/'\r\n    suffix = \"_leftImg8bit_foggy_beta_0.02\"\r\n    image_dir = \"leftImg8bit_foggy\"\r\n    prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir)\r\n\r\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/requirements.txt",
    "content": "mmcv\ntimm\nprettytable\npascal_voc_writer\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/source_only.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport logging\nimport os\nimport argparse\nimport sys\n\nimport torch\nfrom torch.nn.parallel import DistributedDataParallel\nfrom detectron2.engine import default_writers, launch\nfrom detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer\nimport detectron2.utils.comm as comm\nfrom detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping\nfrom detectron2.solver import build_lr_scheduler\nfrom detectron2.data import (\n    build_detection_train_loader,\n    get_detection_dataset_dicts,\n)\nfrom detectron2.utils.events import EventStorage\n\nsys.path.append('../../..')\nimport tllib.vision.models.object_detection.meta_arch as models\n\nimport utils\n\n\ndef train(model, logger, cfg, args):\n    model.train()\n    distributed = comm.get_world_size() > 1\n    if distributed:\n        model_without_parallel = model.module\n    else:\n        model_without_parallel = model\n\n    # define optimizer and lr scheduler\n    params = []\n    for module, lr in model_without_parallel.get_parameters(cfg.SOLVER.BASE_LR):\n        params.extend(\n            get_default_optimizer_params(\n                module,\n                base_lr=lr,\n                weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,\n                bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,\n                weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,\n            )\n        )\n    optimizer = maybe_add_gradient_clipping(cfg, torch.optim.SGD)(\n        params,\n        lr=cfg.SOLVER.BASE_LR,\n        momentum=cfg.SOLVER.MOMENTUM,\n        nesterov=cfg.SOLVER.NESTEROV,\n        weight_decay=cfg.SOLVER.WEIGHT_DECAY,\n    )\n    scheduler = build_lr_scheduler(cfg, optimizer)\n\n    # resume from the last checkpoint\n    checkpointer = DetectionCheckpointer(\n        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler\n    )\n    start_iter = (\n        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume).get(\"iteration\", -1) + 1\n    )\n    max_iter = cfg.SOLVER.MAX_ITER\n\n    periodic_checkpointer = PeriodicCheckpointer(\n        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter\n    )\n\n    writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []\n\n    # Data loading code\n    train_source_dataset = get_detection_dataset_dicts(args.source)\n    train_source_loader = build_detection_train_loader(dataset=train_source_dataset, cfg=cfg)\n\n    # start training\n    logger.info(\"Starting training from iteration {}\".format(start_iter))\n    with EventStorage(start_iter) as storage:\n        for data_s, iteration in zip(train_source_loader, range(start_iter, max_iter)):\n            storage.iter = iteration\n\n            # compute output\n            _, loss_dict_s = model(data_s)\n            losses_s = sum(loss_dict_s.values())\n            assert torch.isfinite(losses_s).all(), loss_dict_s\n\n            loss_dict_reduced_s = {\"{}_s\".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_s).items()}\n            losses_reduced_s = sum(loss for loss in loss_dict_reduced_s.values())\n            if comm.is_main_process():\n                storage.put_scalars(total_loss_s=losses_reduced_s, **loss_dict_reduced_s)\n\n            # compute gradient and do SGD step\n            optimizer.zero_grad()\n            losses_s.backward()\n            optimizer.step()\n            storage.put_scalar(\"lr\", 
optimizer.param_groups[0][\"lr\"], smoothing_hint=False)\n            scheduler.step()\n\n            # evaluate on validation set\n            if (\n                    cfg.TEST.EVAL_PERIOD > 0\n                    and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0\n                    and iteration != max_iter - 1\n            ):\n                utils.validate(model, logger, cfg, args)\n                comm.synchronize()\n\n            if iteration - start_iter > 5 and (\n                    (iteration + 1) % 20 == 0 or iteration == max_iter - 1\n            ):\n                for writer in writers:\n                    writer.write()\n            periodic_checkpointer.step(iteration)\n\n\ndef main(args):\n    logger = logging.getLogger(\"detectron2\")\n    cfg = utils.setup(args)\n\n    # dataset\n    args.source = utils.build_dataset(args.source[::2], args.source[1::2])\n    args.target = utils.build_dataset(args.target[::2], args.target[1::2])\n    args.test = utils.build_dataset(args.test[::2], args.test[1::2])\n\n    # create model\n    model = models.__dict__[cfg.MODEL.META_ARCHITECTURE](cfg, finetune=args.finetune)\n    model.to(torch.device(cfg.MODEL.DEVICE))\n    logger.info(\"Model:\\n{}\".format(model))\n\n    if args.eval_only:\n        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(\n            cfg.MODEL.WEIGHTS, resume=args.resume\n        )\n        return utils.validate(model, logger, cfg, args)\n\n    distributed = comm.get_world_size() > 1\n    if distributed:\n        model = DistributedDataParallel(\n            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False\n        )\n\n    train(model, logger, cfg, args)\n\n    # evaluate on validation set\n    return utils.validate(model, logger, cfg, args)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    # dataset parameters\n    parser.add_argument('-s', '--source', nargs='+', help='source domain(s)')\n    parser.add_argument('-t', '--target', nargs='+', help='target domain(s)')\n    parser.add_argument('--test', nargs='+', help='test domain(s)')\n    # model parameters\n    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller learning rate for backbone')\n    parser.add_argument(\n        \"--resume\",\n        action=\"store_true\",\n        help=\"Whether to attempt to resume from the checkpoint directory. \"\n             \"See documentation of `DefaultTrainer.resume_or_load()` for what it means.\",\n    )\n    # training parameters\n    parser.add_argument(\"--config-file\", default=\"\", metavar=\"FILE\", help=\"path to config file\")\n    parser.add_argument(\"--eval-only\", action=\"store_true\", help=\"perform evaluation only\")\n    parser.add_argument(\"--num-gpus\", type=int, default=1, help=\"number of gpus *per machine*\")\n    parser.add_argument(\"--num-machines\", type=int, default=1, help=\"total number of machines\")\n    parser.add_argument(\"--machine-rank\", type=int, default=0, help=\"the rank of this machine (unique per machine)\")\n    # PyTorch still may leave orphan processes in multi-gpu training.\n    # Therefore we use a deterministic way to obtain port,\n    # so that users are aware of orphan processes by seeing the port occupied.\n    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != \"win32\" else 1) % 2 ** 14\n    parser.add_argument(\n        \"--dist-url\",\n        default=\"tcp://127.0.0.1:{}\".format(port),\n        help=\"initialization URL for pytorch distributed backend. 
See \"\n             \"https://pytorch.org/docs/stable/distributed.html for details.\",\n    )\n    parser.add_argument(\n        \"opts\",\n        help=\"Modify config options by adding 'KEY VALUE' pairs at the end of the command. \"\n             \"See config references at \"\n             \"https://detectron2.readthedocs.io/modules/config.html#config-references\",\n        default=None,\n        nargs=argparse.REMAINDER,\n    )\n    args = parser.parse_args()\n    print(\"Command Line Args:\", args)\n\n    launch(\n        main,\n        args.num_gpus,\n        num_machines=args.num_machines,\n        machine_rank=args.machine_rank,\n        dist_url=args.dist_url,\n        args=(args,),\n    )\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/source_only.sh",
    "content": "# Faster RCNN: VOC->Clipart\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \\\n  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2clipart\n\n# Faster RCNN: VOC->WaterColor, Comic\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 -t WaterColor datasets/watercolor Comic datasets/comic \\\n  --test VOC2007PartialTest datasets/VOC2007 WaterColorTest datasets/watercolor ComicTest datasets/comic  --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic MODEL.ROI_HEADS.NUM_CLASSES 6\n\n# ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s Cityscapes datasets/cityscapes_in_voc/ -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \\\n  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/cityscapes2foggy\n\n# VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Cityscapes datasets/cityscapes_in_voc/ -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \\\n  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_vgg_16/cityscapes2foggy\n\n# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s Sim10kCar datasets/sim10k -t CityscapesCar datasets/cityscapes_in_voc/  \\\n  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1\n\n# VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \\\n  -s Sim10kCar datasets/sim10k -t CityscapesCar datasets/cityscapes_in_voc/  \\\n  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_vgg_16/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1\n\n# Faster RCNN: GTA5 -> Cityscapes\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \\\n  -s GTA5 datasets/synscapes_detection -t Cityscapes datasets/cityscapes_in_voc/  \\\n  --test CityscapesTest datasets/cityscapes_in_voc/ --finetune \\\n  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/gta52cityscapes\n\n# RetinaNet: VOC->Clipart\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/retinanet_R_101_FPN_voc.yaml \\\n  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \\\n  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \\\n  OUTPUT_DIR logs/source_only/retinanet_R_101_FPN/voc2clipart\n\n# RetinaNet: VOC->WaterColor, Comic\nCUDA_VISIBLE_DEVICES=0 python source_only.py \\\n  --config-file config/retinanet_R_101_FPN_voc.yaml \\\n  -s VOC2007Partial 
datasets/VOC2007 VOC2012Partial datasets/VOC2012 -t WaterColor datasets/watercolor Comic datasets/comic \\\n  --test VOC2007PartialTest datasets/VOC2007 WaterColorTest datasets/watercolor ComicTest datasets/comic --finetune \\\n  OUTPUT_DIR logs/source_only/retinanet_R_101_FPN/voc2watercolor_comic MODEL.RETINANET.NUM_CLASSES 6\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/utils.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport numpy as np\nimport os\nimport prettytable\nfrom typing import *\nfrom collections import OrderedDict, defaultdict\nimport tempfile\nimport logging\nimport matplotlib as mpl\n\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms as T\nimport detectron2.utils.comm as comm\nfrom detectron2.evaluation import PascalVOCDetectionEvaluator, inference_on_dataset\nfrom detectron2.evaluation.pascal_voc_evaluation import voc_eval\nfrom detectron2.config import get_cfg, CfgNode\nfrom detectron2.engine import default_setup\nfrom detectron2.data import (\n    build_detection_test_loader,\n)\nfrom detectron2.data.transforms.augmentation import Augmentation\nfrom detectron2.data.transforms import BlendTransform, ColorTransform\nfrom detectron2.solver.lr_scheduler import LRMultiplier, WarmupParamScheduler\nfrom detectron2.utils.visualizer import Visualizer\nfrom detectron2.utils.colormap import random_color\nfrom fvcore.common.param_scheduler import *\nimport timm\n\nimport tllib.vision.datasets.object_detection as datasets\nimport tllib.vision.models as models\n\n\nclass PascalVOCDetectionPerClassEvaluator(PascalVOCDetectionEvaluator):\n    \"\"\"\n    Evaluate Pascal VOC style AP with per-class AP for Pascal VOC dataset.\n    It contains a synchronization, therefore has to be called from all ranks.\n\n    Note that the concept of AP can be implemented in different ways and may not\n    produce identical results. This class mimics the implementation of the official\n    Pascal VOC Matlab API, and should produce similar but not identical results to the\n    official API.\n    \"\"\"\n\n    def evaluate(self):\n        \"\"\"\n        Returns:\n            dict: has a key \"segm\", whose value is a dict of \"AP\", \"AP50\", and \"AP75\".\n        \"\"\"\n        all_predictions = comm.gather(self._predictions, dst=0)\n        if not comm.is_main_process():\n            return\n        predictions = defaultdict(list)\n        for predictions_per_rank in all_predictions:\n            for clsid, lines in predictions_per_rank.items():\n                predictions[clsid].extend(lines)\n        del all_predictions\n\n        self._logger.info(\n            \"Evaluating {} using {} metric. 
\"\n            \"Note that results do not use the official Matlab API.\".format(\n                self._dataset_name, 2007 if self._is_2007 else 2012\n            )\n        )\n\n        with tempfile.TemporaryDirectory(prefix=\"pascal_voc_eval_\") as dirname:\n            res_file_template = os.path.join(dirname, \"{}.txt\")\n\n            aps = defaultdict(list)  # iou -> ap per class\n            for cls_id, cls_name in enumerate(self._class_names):\n                lines = predictions.get(cls_id, [\"\"])\n\n                with open(res_file_template.format(cls_name), \"w\") as f:\n                    f.write(\"\\n\".join(lines))\n\n                for thresh in range(50, 100, 5):\n                    rec, prec, ap = voc_eval(\n                        res_file_template,\n                        self._anno_file_template,\n                        self._image_set_path,\n                        cls_name,\n                        ovthresh=thresh / 100.0,\n                        use_07_metric=self._is_2007,\n                    )\n                    aps[thresh].append(ap * 100)\n\n        ret = OrderedDict()\n        mAP = {iou: np.mean(x) for iou, x in aps.items()}\n        ret[\"bbox\"] = {\"AP\": np.mean(list(mAP.values())), \"AP50\": mAP[50], \"AP75\": mAP[75]}\n        for cls_name, ap in zip(self._class_names, aps[50]):\n            ret[\"bbox\"][cls_name] = ap\n        return ret\n\n\ndef validate(model, logger, cfg, args):\n    results = OrderedDict()\n    for dataset_name in args.test:\n        data_loader = build_detection_test_loader(cfg, dataset_name)\n        evaluator = PascalVOCDetectionPerClassEvaluator(dataset_name)\n        results_i = inference_on_dataset(model, data_loader, evaluator)\n        results[dataset_name] = results_i\n        if comm.is_main_process():\n            logger.info(results_i)\n            table = prettytable.PrettyTable([\"class\", \"AP\"])\n            for class_name, ap in results_i[\"bbox\"].items():\n                table.add_row([class_name, ap])\n            logger.info(table.get_string())\n    if len(results) == 1:\n        results = list(results.values())[0]\n    return results\n\n\ndef build_dataset(dataset_categories, dataset_roots):\n    \"\"\"\n    Give a sequence of dataset class name and a sequence of dataset root directory,\n    return a sequence of built datasets\n    \"\"\"\n    dataset_lists = []\n    for dataset_category, root in zip(dataset_categories, dataset_roots):\n        dataset_lists.append(datasets.__dict__[dataset_category](root).name)\n    return dataset_lists\n\n\ndef rgb2gray(rgb):\n    return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])[:, :, np.newaxis].repeat(3, axis=2).astype(rgb.dtype)\n\n\nclass Grayscale(Augmentation):\n    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):\n        super().__init__()\n        self._init(locals())\n        self._transform = T.Grayscale()\n\n    def get_transform(self, image):\n        return ColorTransform(lambda x: rgb2gray(x))\n\n\ndef build_augmentation(cfg, is_train):\n    \"\"\"\n    Create a list of default :class:`Augmentation` from config.\n    Now it includes resizing and flipping.\n\n    Returns:\n        list[Augmentation]\n    \"\"\"\n    import detectron2.data.transforms as T\n    if is_train:\n        min_size = cfg.INPUT.MIN_SIZE_TRAIN\n        max_size = cfg.INPUT.MAX_SIZE_TRAIN\n        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING\n    else:\n        min_size = cfg.INPUT.MIN_SIZE_TEST\n        max_size = cfg.INPUT.MAX_SIZE_TEST\n        
sample_style = \"choice\"\n    augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]\n    if is_train and cfg.INPUT.RANDOM_FLIP != \"none\":\n        augmentation.append(\n            T.RandomFlip(\n                horizontal=cfg.INPUT.RANDOM_FLIP == \"horizontal\",\n                vertical=cfg.INPUT.RANDOM_FLIP == \"vertical\",\n            )\n        )\n        augmentation.append(\n            T.RandomApply(T.AugmentationList(\n                [\n                    T.RandomContrast(0.6, 1.4),\n                    T.RandomBrightness(0.6, 1.4),\n                    T.RandomSaturation(0.6, 1.4),\n                    T.RandomLighting(0.1)\n                ]\n            ), prob=0.8)\n        )\n        augmentation.append(\n            T.RandomApply(Grayscale(), prob=0.2)\n        )\n    return augmentation\n\n\ndef setup(args):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg = get_cfg()\n    cfg.merge_from_file(args.config_file)\n    cfg.merge_from_list(args.opts)\n    cfg.freeze()\n    default_setup(\n        cfg, args\n    )  # if you don't like any of the default setup, write your own setup code\n    return cfg\n\n\ndef build_lr_scheduler(\n        cfg: CfgNode, optimizer: torch.optim.Optimizer\n) -> torch.optim.lr_scheduler._LRScheduler:\n    \"\"\"\n    Build a LR scheduler from config.\n    \"\"\"\n    name = cfg.SOLVER.LR_SCHEDULER_NAME\n\n    if name == \"WarmupMultiStepLR\":\n        steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]\n        if len(steps) != len(cfg.SOLVER.STEPS):\n            logger = logging.getLogger(__name__)\n            logger.warning(\n                \"SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. \"\n                \"These values will be ignored.\"\n            )\n        sched = MultiStepParamScheduler(\n            values=[cfg.SOLVER.GAMMA ** k for k in range(len(steps) + 1)],\n            milestones=steps,\n            num_updates=cfg.SOLVER.MAX_ITER,\n        )\n    elif name == \"WarmupCosineLR\":\n        sched = CosineParamScheduler(1, 0)\n    elif name == \"ExponentialLR\":\n        sched = ExponentialParamScheduler(1, cfg.SOLVER.GAMMA)\n        return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)\n    else:\n        raise ValueError(\"Unknown LR scheduler: {}\".format(name))\n\n    sched = WarmupParamScheduler(\n        sched,\n        cfg.SOLVER.WARMUP_FACTOR,\n        cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER,\n        cfg.SOLVER.WARMUP_METHOD,\n    )\n    return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)\n\n\ndef get_model_names():\n    return sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    ) + timm.list_models()\n\n\ndef get_model(model_name, pretrain=True):\n    if model_name in models.__dict__:\n        # load models from common.vision.models\n        backbone = models.__dict__[model_name](pretrained=pretrain)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=pretrain)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n    return backbone\n\n\nclass VisualizerWithoutAreaSorting(Visualizer):\n    \"\"\"\n    Visualizer in 
This subclass removes that sorting so that boxes with lower confidence\n    do not cover boxes with higher confidence.\n    \"\"\"\n\n    def __init__(self, *args, flip=False, **kwargs):\n        super(VisualizerWithoutAreaSorting, self).__init__(*args, **kwargs)\n        self.flip = flip\n\n    def overlay_instances(\n            self,\n            *,\n            boxes=None,\n            labels=None,\n            masks=None,\n            keypoints=None,\n            assigned_colors=None,\n            alpha=1\n    ):\n        \"\"\"\n        Args:\n            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,\n                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,\n                or a :class:`RotatedBoxes`,\n                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format\n                for the N objects in a single image,\n            labels (list[str]): the text to be displayed for each instance.\n            masks (masks-like object): Supported types are:\n\n                * :class:`detectron2.structures.PolygonMasks`,\n                  :class:`detectron2.structures.BitMasks`.\n                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.\n                  The first level of the list corresponds to individual instances. The second\n                  level to all the polygons that compose the instance, and the third level\n                  to the polygon coordinates. The third level should have the format of\n                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).\n                * list[ndarray]: each ndarray is a binary mask of shape (H, W).\n                * list[dict]: each dict is a COCO-style RLE.\n            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),\n                where the N is the number of instances and K is the number of keypoints.\n                The last dimension corresponds to (x, y, visibility or score).\n            assigned_colors (list[matplotlib.colors]): a list of colors, where each color\n                corresponds to each mask or box in the image. 
Refer to 'matplotlib.colors'\n                for full list of formats that the colors are accepted in.\n\n        Returns:\n            output (VisImage): image object with visualizations.\n        \"\"\"\n        num_instances = None\n        if boxes is not None:\n            boxes = self._convert_boxes(boxes)\n            num_instances = len(boxes)\n        if masks is not None:\n            masks = self._convert_masks(masks)\n            if num_instances:\n                assert len(masks) == num_instances\n            else:\n                num_instances = len(masks)\n        if keypoints is not None:\n            if num_instances:\n                assert len(keypoints) == num_instances\n            else:\n                num_instances = len(keypoints)\n            keypoints = self._convert_keypoints(keypoints)\n        if labels is not None:\n            assert len(labels) == num_instances\n        if assigned_colors is None:\n            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]\n        if num_instances == 0:\n            return self.output\n        if boxes is not None and boxes.shape[1] == 5:\n            return self.overlay_rotated_instances(\n                boxes=boxes, labels=labels, assigned_colors=assigned_colors\n            )\n\n        for i in range(num_instances):\n            color = assigned_colors[i]\n            if boxes is not None:\n                self.draw_box(boxes[i], edge_color=color)\n\n            if masks is not None:\n                for segment in masks[i].polygons:\n                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)\n\n            if labels is not None:\n                # first get a box\n                if boxes is not None:\n                    x0, y0, x1, y1 = boxes[i]\n                    text_pos = (x0 - 3, y0)  # if drawing boxes, put text on the box corner.\n                    horiz_align = \"left\"\n                elif masks is not None:\n                    # skip small mask without polygon\n                    if len(masks[i].polygons) == 0:\n                        continue\n\n                    x0, y0, x1, y1 = masks[i].bbox()\n\n                    # draw text in the center (defined by median) when box is not drawn\n                    # median is less sensitive to outliers.\n                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]\n                    horiz_align = \"center\"\n                else:\n                    continue  # drawing the box confidence for keypoints isn't very useful.\n                # for small objects, draw text at the side to avoid occlusion\n                instance_area = (y1 - y0) * (x1 - x0)\n                if (\n                        instance_area < 1000 * self.output.scale\n                        or y1 - y0 < 40 * self.output.scale\n                ):\n                    if y1 >= self.output.height - 5:\n                        text_pos = (x1, y0)\n                    else:\n                        text_pos = (x0, y1)\n\n                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)\n                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)\n                font_size = (\n                        np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)\n                        * 1\n                        * self._default_font_size\n                )\n                if self.flip:\n                    text_pos = (x1 - 3, y0 - 30)  # if drawing boxes, 
put text on the box corner.\n                self.draw_text(\n                    labels[i],\n                    text_pos,\n                    color=lighter_color,\n                    horizontal_alignment=horiz_align,\n                    font_size=font_size,\n                )\n\n        # draw keypoints\n        if keypoints is not None:\n            for keypoints_per_instance in keypoints:\n                self.draw_and_connect_keypoints(keypoints_per_instance)\n\n        return self.output\n\n    def draw_box(self, box_coord, alpha=1, edge_color=\"g\", line_style=\"-\"):\n        x0, y0, x1, y1 = box_coord\n        width = x1 - x0\n        height = y1 - y0\n\n        linewidth = max(self._default_font_size / 4, 3)\n\n        self.output.ax.add_patch(\n            mpl.patches.Rectangle(\n                (x0, y0),\n                width,\n                height,\n                fill=False,\n                edgecolor=edge_color,\n                linewidth=linewidth * self.output.scale,\n                alpha=alpha,\n                linestyle=line_style,\n            )\n        )\n        return self.output\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/visualize.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport os\nimport argparse\nimport sys\n\nimport torch\nfrom detectron2.checkpoint import DetectionCheckpointer\nfrom detectron2.config import get_cfg\nfrom detectron2.data import (\n    build_detection_test_loader,\n    MetadataCatalog\n)\nfrom detectron2.data import detection_utils\nfrom detectron2.engine import default_setup, launch\nfrom detectron2.utils.visualizer import ColorMode\n\nsys.path.append('../../..')\nimport tllib.vision.models.object_detection.meta_arch as models\n\nimport utils\n\n\ndef visualize(cfg, args, model):\n    for dataset_name in args.test:\n        data_loader = build_detection_test_loader(cfg, dataset_name)\n\n        # create folder\n        dirname = os.path.join(args.save_path, dataset_name)\n        os.makedirs(dirname, exist_ok=True)\n\n        metadata = MetadataCatalog.get(dataset_name)\n        n_current = 0\n\n        # switch to eval mode\n        model.eval()\n        with torch.no_grad():\n            for batch in data_loader:\n                if n_current >= args.n_visualizations:\n                    break\n                batch_predictions = model(batch)\n                for per_image, predictions in zip(batch, batch_predictions):\n                    instances = predictions[\"instances\"].to(torch.device(\"cpu\"))\n                    # only visualize boxes with highest confidence\n                    instances = instances[0: args.n_bboxes]\n                    # only visualize boxes with confidence exceeding the threshold\n                    instances = instances[instances.scores > args.threshold]\n                    # visualize in reverse order of confidence\n                    index = [i for i in range(len(instances))]\n                    index.reverse()\n                    instances = instances[index]\n                    img = per_image[\"image\"].permute(1, 2, 0).cpu().detach().numpy()\n                    img = detection_utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)\n\n                    # scale pred_box to original resolution\n                    ori_height, ori_width, _ = img.shape\n                    height, width = instances.image_size\n                    ratio = ori_width / width\n                    for i in range(len(instances.pred_boxes)):\n                        instances.pred_boxes[i].scale(ratio, ratio)\n\n                    # save original image\n                    visualizer = utils.VisualizerWithoutAreaSorting(img, metadata=metadata,\n                                                                    instance_mode=ColorMode.IMAGE)\n                    output = visualizer.draw_instance_predictions(predictions=instances)\n\n                    filepath = str(n_current) + \".png\"\n                    filepath = os.path.join(dirname, filepath)\n                    output.save(filepath)\n\n                    n_current += 1\n                    if n_current >= args.n_visualizations:\n                        break\n\n\ndef setup(args):\n    \"\"\"\n    Create configs and perform basic setups.\n    \"\"\"\n    cfg = get_cfg()\n    cfg.merge_from_file(args.config_file)\n    cfg.merge_from_list(args.opts)\n    cfg.freeze()\n    default_setup(\n        cfg, args\n    )  # if you don't like any of the default setup, write your own setup code\n    return cfg\n\n\ndef main(args):\n    cfg = setup(args)\n    meta_arch = cfg.MODEL.META_ARCHITECTURE\n    model = models.__dict__[meta_arch](cfg, finetune=True)\n    
model.to(torch.device(cfg.MODEL.DEVICE))\n    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(\n        cfg.MODEL.WEIGHTS, resume=False\n    )\n    visualize(cfg, args, model)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config-file\", default=\"\", metavar=\"FILE\", help=\"path to config file\")\n    parser.add_argument(\"--num-gpus\", type=int, default=1, help=\"number of gpus *per machine*\")\n    parser.add_argument(\"--num-machines\", type=int, default=1, help=\"total number of machines\")\n    parser.add_argument(\n        \"--machine-rank\", type=int, default=0, help=\"the rank of this machine (unique per machine)\"\n    )\n\n    # PyTorch still may leave orphan processes in multi-gpu training.\n    # Therefore we use a deterministic way to obtain port,\n    # so that users are aware of orphan processes by seeing the port occupied.\n    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != \"win32\" else 1) % 2 ** 14\n    parser.add_argument(\n        \"--dist-url\",\n        default=\"tcp://127.0.0.1:{}\".format(port),\n        help=\"initialization URL for pytorch distributed backend. See \"\n             \"https://pytorch.org/docs/stable/distributed.html for details.\",\n    )\n    parser.add_argument(\n        \"opts\",\n        help=\"Modify config options by adding 'KEY VALUE' pairs at the end of the command. \"\n             \"See config references at \"\n             \"https://detectron2.readthedocs.io/modules/config.html#config-references\",\n        default=None,\n        nargs=argparse.REMAINDER,\n    )\n    parser.add_argument('--test', nargs='+', help='test domain(s)')\n    parser.add_argument('--save-path', type=str,\n                        help='where to save visualization results ')\n    parser.add_argument('--n-visualizations', default=100, type=int,\n                        help='maximum number of images to visualize (default: 100)')\n    parser.add_argument('--threshold', default=0.5, type=float,\n                        help='confidence threshold of bounding boxes to visualize (default: 0.5)')\n    parser.add_argument('--n-bboxes', default=10, type=int,\n                        help='maximum number of bounding boxes to visualize in a single image (default: 10)')\n    args = parser.parse_args()\n    print(\"Command Line Args:\", args)\n    args.test = utils.build_dataset(args.test[::2], args.test[1::2])\n    launch(\n        main,\n        args.num_gpus,\n        num_machines=args.num_machines,\n        machine_rank=args.machine_rank,\n        dist_url=args.dist_url,\n        args=(args,),\n    )\n"
  },
  {
    "path": "examples/domain_adaptation/object_detection/visualize.sh",
    "content": "# Source Only Faster RCNN: VOC->Clipart\nCUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  --test Clipart datasets/clipart --save-path visualizations/source_only/voc2clipart \\\n  MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth\n\n# Source Only Faster RCNN: VOC->WaterColor, Comic\nCUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \\\n  --test WaterColorTest datasets/watercolor ComicTest datasets/comic --save-path visualizations/source_only/voc2comic_watercolor \\\n  MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth\n"
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/README.md",
    "content": "# Open-set Domain Adaptation for Image Classification\n\n## Installation\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).\nYou also need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)\n- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)\n- [VisDA2017](http://ai.bu.edu/visda-2017/)\n\n## Supported Methods\n\nSupported methods include:\n\n- [Open Set Domain Adaptation (OSBP)](https://arxiv.org/abs/1804.10427)\n\n## Experiment and Results\n\nThe shell files give the script to reproduce the benchmark with specified hyper-parameters.\nFor example, if you want to train DANN on Office31, use the following script\n\n```shell script\n# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.\n# Assume you have put the datasets under the path `data/office-31`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W\n```\n\n**Notations**\n- ``Origin`` means the accuracy reported by the original paper.\n- ``Avg`` is the accuracy reported by `TLlib`.\n- ``ERM`` refers to the model trained with data from the source domain.\n\nWe report ``HOS`` used in [ROS (ECCV 2020)](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610409.pdf) to better measure the abilities of different open set domain adaptation algorithms.\n\nWe report the best ``HOS`` in all epochs.\nDANN (baseline model) will degrade performance as training progresses, thus the\nfinal ``HOS`` will be much lower than reported.\nIn contrast, OSBP will improve performance stably.\n\n\n### Office-31 H-Score on ResNet-50\n\n| Methods     | Avg  | A → W | D → W | W → D | A → D | D → A | W → A |\n|-------------|------|-------|-------|-------|-------|-------|-------|\n| ERM         | 75.9 | 67.7  | 85.7  | 91.4  | 72.1  | 68.4  | 67.8  |\n| DANN        | 80.4 | 81.4  | 89.1  | 92.0  | 82.5  | 66.7  | 70.4  |\n| OSBP        | 87.8 | 90.7  | 96.4  | 97.5  | 88.7  | 77.0  | 76.7  |\n\n### Office-Home HOS on ResNet-50\n\n| Methods     | Origin | Avg  | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |\n|-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|\n| Source Only | /      | 59.8 | 55.2    | 65.2    | 71.4    | 52.8    | 59.6    | 65.2    | 55.8    | 44.8    | 68.0    | 63.8    | 49.4    | 68.0    |\n| DANN        | /      | 64.8 | 55.2    | 65.2    | 71.4    | 52.8    | 59.6    | 65.2    | 55.8    | 44.8    | 68.0    | 63.8    | 49.4    | 68.0    |\n| OSBP        | 64.7   | 68.6 | 62.0    | 70.8    | 76.5    | 66.4    | 68.8    | 73.8    | 65.8    | 57.1    | 75.4    | 70.6    | 60.6    | 75.9    |\n\n### VisDA-2017 performance on ResNet-50\n| Methods     | HOS  | OS   | OS*  | UNK  | bcycl | bus  | car  | mcycl | train | truck |\n|-------------|------|------|------|------|-------|------|------|-------|-------|-------|\n| Source Only | 42.6 | 37.6 | 34.7 | 55.1 | 42.6  | 6.4  | 30.5 | 67.1  | 84.0  | 0.2   |\n| DANN        | 57.8 | 50.4 | 45.6 
| 78.9 | 20.1  | 71.4 | 29.5 | 74.4  | 67.8  | 10.4  |\n| OSBP        | 75.4 | 67.3 | 62.9 | 94.3 | 63.7  | 75.9 | 49.6 | 74.4  | 86.2  | 27.3  |\n\n## Citation\nIf you use these methods in your research, please consider citing.\n\n```\n@InProceedings{OSBP,\n    author = {Saito, Kuniaki and Yamamoto, Shohei and Ushiku, Yoshitaka and Harada, Tatsuya},\n    title = {Open Set Domain Adaptation by Backpropagation},\n    booktitle = {ECCV},\n    year = {2018}\n}\n```\n"
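\n**Computing HOS (sketch)**\n\nFor reference, ``HOS`` is the harmonic mean of the mean accuracy over the known classes and\nthe accuracy of the unknown class, mirroring the computation in the `validate` function of `dann.py`.\nBelow is a minimal sketch with toy numbers; the helper name `h_score` is illustrative only and\nnot part of the library API.\n\n```python\nimport numpy as np\n\n\ndef h_score(per_class_accs):\n    # per_class_accs: accuracy (%) of each class, with the unknown\n    # class in the last position, as in validate()\n    known = float(np.mean(per_class_accs[:-1]))\n    unknown = float(per_class_accs[-1])\n    # harmonic mean of known-class and unknown-class accuracy\n    return 2 * known * unknown / (known + unknown)\n\n\nprint(h_score([80.0, 70.0, 90.0, 55.0]))  # toy numbers -> ~65.2\n```\n"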
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/dann.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.modules.classifier import Classifier\nfrom tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer).to(\n        device)\n    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),\n                    args.lr, momentum=args.momentum, 
weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    domain_adv = DomainAdversarialLoss(domain_discri).to(device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        h_score = validate(test_loader, classifier, args)\n        print(h_score)\n        return\n\n    # start training\n    best_h_score = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        h_score = validate(val_loader, classifier, args)\n\n        # remember the best h-score and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if h_score > best_h_score:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_h_score = max(h_score, best_h_score)\n\n    print(\"best_h_score = {:3.1f}\".format(best_h_score))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    h_score = validate(test_loader, classifier, args)\n    print(\"test_h_score = {:3.1f}\".format(h_score))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    losses = AverageMeter('Loss', ':6.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    domain_adv.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_t, labels_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data 
loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = domain_adv(f_s, f_t)\n        domain_acc = domain_adv.domain_discriminator_accuracy\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n        tgt_acc = accuracy(y_t, labels_t)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        tgt_accs.update(tgt_acc.item(), x_s.size(0))\n        domain_accs.update(domain_acc.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\ndef validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:\n    batch_time = AverageMeter('Time', ':6.3f')\n    classes = val_loader.dataset.classes\n    confmat = ConfusionMatrix(len(classes))\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            softmax_output = F.softmax(output, dim=1)\n            # set the probability of the \"unknown\" (last) class to the threshold, so a sample\n            # is predicted as unknown only when every known-class probability falls below it\n            softmax_output[:, -1] = args.threshold\n\n            # measure accuracy\n            confmat.update(target, softmax_output.argmax(1))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        acc_global, accs, iu = confmat.compute()\n        all_acc = torch.mean(accs).item() * 100\n        known = torch.mean(accs[:-1]).item() * 100\n        unknown = accs[-1].item() * 100\n        h_score = 2 * known * unknown / (known + unknown)\n        if args.per_class_eval:\n            print(confmat.format(classes))\n        print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'\n              .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))\n\n    return h_score\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='DANN for Open-Set Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', 
metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--threshold', default=0.8, type=float,\n                        help='When every known-class confidence is less than the given threshold, '\n                             'the model will output \"unknown\" (default: 0.8)')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='dann',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/dann.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Pr\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \\\n    --epochs 30 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/dann/VisDA2017_S2R\n"
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/erm.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, _, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # using shuffled val loader\n        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(val_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = validate(test_loader, classifier, args)\n        print(acc1)\n        return\n\n    # start training\n    best_h_score = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, classifier, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        h_score = validate(val_loader, classifier, args)\n\n        # remember the best h-score and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if h_score > best_h_score:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_h_score = max(h_score, best_h_score)\n\n    print(\"best_h_score = {:3.1f}\".format(best_h_score))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    h_score = validate(test_loader, classifier, args)\n    print(\"test_h_score = {:3.1f}\".format(h_score))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_s = x_s.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        loss = cls_loss\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        
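# note: optimizer.step() is called before lr_scheduler.step(); since PyTorch 1.1\n        # the reverse order would skip the first value of the learning rate schedule\n        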
optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\ndef validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:\n    batch_time = AverageMeter('Time', ':6.3f')\n    classes = val_loader.dataset.classes\n    confmat = ConfusionMatrix(len(classes))\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            softmax_output = F.softmax(output, dim=1)\n            # the last class stands for \"unknown\": fixing its score at the threshold\n            # makes argmax predict \"unknown\" whenever every known-class probability\n            # falls below the threshold\n            softmax_output[:, -1] = args.threshold\n\n            # measure accuracy\n            confmat.update(target, softmax_output.argmax(1))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        acc_global, accs, iu = confmat.compute()\n        all_acc = torch.mean(accs).item() * 100\n        known = torch.mean(accs[:-1]).item() * 100\n        unknown = accs[-1].item() * 100\n        h_score = 2 * known * unknown / (known + unknown)\n        if args.per_class_eval:\n            print(confmat.format(classes))\n        print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'\n              .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))\n\n    return h_score\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Source Only for Openset Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--threshold', default=0.8, type=float,\n                        help='When class confidence is less than the given threshold, '\n                             'the model will output \"unknown\" (default: 0.8)')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', 
'--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='src_only',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/erm.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Pr\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \\\n    --epochs 30 -i 500 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/erm/VisDA2017_S2R\n"
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/osbp.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.alignment.osbp import ImageClassifier as Classifier, UnknownClassBinaryCrossEntropy\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer).to(device)\n    print(classifier)\n    unknown_bce = UnknownClassBinaryCrossEntropy(t=0.5)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = validate(test_loader, classifier, args)\n        print(acc1)\n        return\n\n    # start training\n    best_h_score = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, unknown_bce, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        h_score = validate(val_loader, classifier, args)\n\n        # remember the best h-score and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if h_score > best_h_score:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_h_score = max(h_score, best_h_score)\n\n    print(\"best_h_score = {:3.1f}\".format(best_h_score))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    h_score = validate(test_loader, classifier, args)\n    print(\"test_h_score = {:3.1f}\".format(h_score))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: Classifier,\n          unknown_bce: UnknownClassBinaryCrossEntropy, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')\n    trans_losses = AverageMeter('Trans Loss', ':3.2f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_t, labels_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, _ = model(x_s, grad_reverse=False)\n        y_t, _ = model(x_t, grad_reverse=True)\n\n        cls_loss = F.cross_entropy(y_s, 
labels_s)\n        trans_loss = unknown_bce(y_t)\n        loss = cls_loss + trans_loss\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n        tgt_acc = accuracy(y_t, labels_t)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        trans_losses.update(trans_loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        tgt_accs.update(tgt_acc.item(), x_t.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\ndef validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:\n    batch_time = AverageMeter('Time', ':6.3f')\n    classes = val_loader.dataset.classes\n    confmat = ConfusionMatrix(len(classes))\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n\n            # measure accuracy\n            confmat.update(target, output.argmax(1))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        acc_global, accs, iu = confmat.compute()\n        all_acc = torch.mean(accs).item() * 100\n        known = torch.mean(accs[:-1]).item() * 100\n        unknown = accs[-1].item() * 100\n        h_score = 2 * known * unknown / (known + unknown)\n        if args.per_class_eval:\n            print(confmat.format(classes))\n        print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'\n              .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))\n\n    return h_score\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='OSBP for Openset Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n            
            help='Dimension of bottleneck')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='osbp',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/osbp.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Pr\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python osbp.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \\\n    --epochs 30 -i 1000 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/osbp/VisDA2017_S2R"
  },
  {
    "path": "examples/domain_adaptation/openset_domain_adaptation/utils.py",
    "content": "import sys\nimport timm\nimport torch.nn as nn\nimport torchvision.transforms as T\n\nsys.path.append('../../..')\nimport tllib.vision.datasets.openset as datasets\nfrom tllib.vision.datasets.openset import default_open_set as open_set\nimport tllib.vision.models as models\nfrom tllib.vision.transforms import ResizeImage\n\n\ndef get_model_names():\n    return sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    ) + timm.list_models()\n\n\ndef get_model(model_name):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=True)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=True)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n            backbone.copy_head = backbone.get_classifier\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n            backbone.copy_head = lambda x: x.head\n    return backbone\n\n\ndef get_dataset_names():\n    return sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n\ndef get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None):\n    if train_target_transform is None:\n        train_target_transform = train_source_transform\n    # load datasets from tllib.vision.datasets\n    dataset = datasets.__dict__[dataset_name]\n    source_dataset = open_set(dataset, source=True)\n    target_dataset = open_set(dataset, source=False)\n\n    train_source_dataset = source_dataset(root=root, task=source, download=True, transform=train_source_transform)\n    train_target_dataset = target_dataset(root=root, task=target, download=True, transform=train_target_transform)\n    val_dataset = target_dataset(root=root, task=target, download=True, transform=val_transform)\n    if dataset_name == 'DomainNet':\n        test_dataset = target_dataset(root=root, task=target, split='test', download=True, transform=val_transform)\n    else:\n        test_dataset = val_dataset\n    class_names = train_source_dataset.classes\n    num_classes = len(class_names)\n    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names\n\n\ndef get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 256 and take a random resized crop of size 224;\n        - cen.crop: resize the image to 256 and take the center crop of size 224;\n        - res: resize the image to 224;\n        - res.|crop: resize the image to 256 and take a random crop of size 224;\n        - res.sma|crop: resize the image keeping its aspect ratio such that the\n            smaller side is 256, then take a random crop of size 224;\n        – inc.crop: “inception crop” from (Szegedy et al., 2015);\n        – cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.RandomResizedCrop(224)\n        ])\n    elif resizing == 'cen.crop':\n        
transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224)\n        ])\n    elif resizing == 'ran.crop':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.RandomCrop(224)\n        ])\n    elif resizing == 'res.':\n        transform = T.Resize(224)\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize((256, 256)),\n            T.RandomCrop(224)\n        ])\n    elif resizing == \"res.sma|crop\":\n        transform = T.Compose([\n            T.Resize(256),\n            T.RandomCrop(224)\n        ])\n    elif resizing == 'inc.crop':\n        transform = T.RandomResizedCrop(224)\n    elif resizing == 'cif.crop':\n        transform = T.Compose([\n            T.Resize((224, 224)),\n            T.Pad(28),\n            T.RandomCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        transforms.append(T.RandomHorizontalFlip())\n    if random_color_jitter:\n        transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n    return T.Compose(transforms)\n\n\ndef get_val_transform(resizing='default'):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 256 and take the center crop of size 224;\n        - res.: resize the image to (224, 224);\n        - res.|crop: resize the image such that the smaller side is of size 256 and\n            then take a central crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224),\n        ])\n    elif resizing == 'res.':\n        transform = T.Resize((224, 224))\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize(256),\n            T.CenterCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    return T.Compose([\n        transform,\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/README.md",
    "content": "# Partial Domain Adaptation for Image Classification\n\n## Installation\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).\nYou also need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)\n- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)\n- [VisDA2017](http://ai.bu.edu/visda-2017/)\n\n## Supported Methods\n\nSupported methods include:\n\n- [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818)\n- [Partial Adversarial Domain Adaptation (PADA)](https://arxiv.org/abs/1808.04205)\n- [Importance Weighted Adversarial Nets (IWAN)](https://arxiv.org/abs/1803.09210)\n- [Adaptive Feature Norm (AFN)](https://arxiv.org/pdf/1811.07456v2.pdf)\n\n## Experiment and Results\n\nThe shell files give the script to reproduce the benchmark with specified hyper-parameters.\nFor example, if you want to train DANN on Office31, use the following script\n\n```shell script\n# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.\n# Assume you have put the datasets under the path `data/office-31`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W\n```\n\n**Notations**\n- ``Origin`` means the accuracy reported by the original paper.\n- ``Avg`` is the accuracy reported by `TLlib`.\n- ``ERM`` refers to the model trained with data from the source domain.\n- ``Oracle`` refers to the model trained with data from the target domain.\n\n\nWe found that the accuracies of adversarial methods (including DANN) are not stable\neven after the random seed is fixed, thus we repeat running adversarial methods on *Office-31* and *VisDA-2017*\nfor three times and report their average accuracy.\n\n### Office-31 accuracy on ResNet-50\n| Methods     | Origin | Avg  | A → W | D → W | W → D | A → D | D → A | W → A | \n|-------------|--------|------|-------|-------|-------|-------|-------|-------|\n| ERM | 75.6   | 90.1 | 78.3  | 98.3  | 99.4  | 87.3  | 88.5  | 88.8  | 84.0  |\n| DANN        | 43.4   | 82.4 | 60.0  | 94.9  | 98.1  | 71.3  | 84.9  | 85.0  | \n| PADA        | 92.7   | 93.8 | 86.4  | 100.0 | 100.0 | 87.3  | 93.8  | 95.4  |\n| IWAN        | 94.7   | 94.8 | 91.2  | 99.7  | 99.4  | 89.8  | 94.2  | 94.3  |\n| AFN         | /      | 93.1 | 87.8  | 95.6  | 99.4  | 87.9  | 93.9  | 94.1  |\n\n### Office-Home accuracy on ResNet-50\n\n| Methods     | Origin | Avg  | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |\n|-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|\n| ERM | 53.7   | 60.1 | 42.0    | 66.9    | 78.5    | 56.4    | 55.2    | 65.4    | 57.9    | 36.0    | 75.5    | 68.7    | 43.6    | 74.8    |\n| DANN        | 47.4   | 57.0 | 46.2    | 59.3    | 76.9    | 47.0    | 47.4    | 56.4    | 51.6    | 38.8    | 72.1    | 68.0    | 46.1    | 74.2    |\n| PADA        | 62.1   | 65.9 | 52.9    | 69.3    | 82.8    | 59.0    | 57.5    | 66.4    | 66.0    | 41.7    | 82.5    | 78.0    | 50.2    | 84.1    
|\n| IWAN        | 63.6   | 71.3 | 59.2    | 76.6    | 84.0    | 67.8    | 66.7    | 69.2    | 73.3    | 55.0    | 83.9    | 79.0    | 58.3    | 82.2    |\n| AFN         | 71.8   | 72.6 | 59.2    | 76.7    | 82.8    | 72.5    | 74.5    | 76.8    | 72.5    | 56.7    | 80.8    | 77.0    | 60.5    | 81.6    |\n\n### VisDA-2017 accuracy on ResNet-50\n| Methods     | Origin | Mean | plane | bcycl | bus  | car  | horse | knife | Avg  |\n|-------------|--------|------|-------|-------|------|------|-------|-------|------|\n| ERM         | 45.3   | 50.9 | 59.2  | 31.3  | 68.7 | 73.2 | 69.3  | 3.4   | 60.0 |\n| DANN        | 51.0   | 55.9 | 88.4  | 34.1  | 72.1 | 50.7 | 61.9  | 27.8  | 57.1 |\n| PADA        | 53.5   | 60.5 | 89.4  | 35.1  | 72.5 | 69.2 | 86.7  | 10.1  | 66.8 |\n| IWAN        | /      | 61.5 | 89.2  | 57.0  | 61.5 | 55.2 | 80.1  | 25.7  | 66.8 |\n| AFN         | 67.6   | 61.0 | 79.1  | 62.7  | 73.9 | 49.6 | 79.6  | 21.0  | 64.1 |\n\n## Citation\nIf you use these methods in your research, please consider citing.\n\n```\n@inproceedings{DANN,\n    author = {Ganin, Yaroslav and Lempitsky, Victor},\n    booktitle = {ICML},\n    title = {Unsupervised domain adaptation by backpropagation},\n    year = {2015}\n}\n\n@InProceedings{PADA,\n    author    = {Zhangjie Cao and\n               Lijia Ma and\n               Mingsheng Long and\n               Jianmin Wang},\n    title     = {Partial Adversarial Domain Adaptation},\n    booktitle = {ECCV},\n    year = {2018}\n}\n\n@InProceedings{IWAN,\n    author    = {Jing Zhang and\n               Zewei Ding and\n               Wanqing Li and\n               Philip Ogunbona},\n    title     = {Importance Weighted Adversarial Nets for Partial Domain Adaptation},\n    booktitle = {CVPR},\n    year = {2018}\n}\n\n@InProceedings{AFN,\n    author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang},\n    title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation},\n    booktitle = {ICCV},\n    year = {2019}\n}\n```\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/afn.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.normalization.afn import AdaptiveFeatureNorm, ImageClassifier\nfrom tllib.modules.entropy import entropy\nimport tllib.vision.models as models\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    backbone = models.__dict__[args.arch](pretrained=True)\n    classifier = ImageClassifier(backbone, train_source_dataset.num_classes, args.num_blocks,\n                                 bottleneck_dim=args.bottleneck_dim, dropout_p=args.dropout_p, pool_layer=pool_layer).to(device)\n    adaptive_feature_norm = AdaptiveFeatureNorm(args.delta).to(device)\n\n    # define optimizer\n    # the learning rate is fixed according to origin paper\n    optimizer = SGD(classifier.get_parameters(), args.lr, weight_decay=args.weight_decay)\n\n    # resume from 
the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, adaptive_feature_norm, optimizer, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':3.1f')\n    data_time = AverageMeter('Data', ':3.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    norm_losses = AverageMeter('Norm Loss', ':3.2f')\n    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')\n    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, cls_losses, norm_losses, src_feature_norm, tgt_feature_norm, cls_accs, tgt_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_t, labels_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n        y_t, f_t = model(x_t)\n\n      
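  # AFN objective (Xu et al., ICCV 2019): source cross-entropy plus a penalty that\n        # steps feature norms up by delta on both domains; the entropy term below is optional\n      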
  # classification loss\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        # norm loss\n        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)\n\n        loss = cls_loss + norm_loss * args.trade_off_norm\n\n        # using entropy minimization\n        if args.trade_off_entropy:\n            y_t = F.softmax(y_t, dim=1)\n            entropy_loss = entropy(y_t, reduction='mean')\n            loss += entropy_loss * args.trade_off_entropy\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # update statistics\n        cls_acc = accuracy(y_s, labels_s)[0]\n        tgt_acc = accuracy(y_t, labels_t)[0]\n\n        cls_losses.update(cls_loss.item(), x_s.size(0))\n        norm_losses.update(norm_loss.item(), x_s.size(0))\n        src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), x_s.size(0))\n        tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        tgt_accs.update(tgt_acc.item(), x_s.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='AFN for Partial Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('-n', '--num-blocks', default=1, type=int, help='Number of basic blocks for classifier')\n    parser.add_argument('--bottleneck-dim', default=1000, type=int, help='Dimension of bottleneck')\n    parser.add_argument('--dropout-p', default=0.5, type=float,\n                        help='Dropout probability')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)',\n                        dest='weight_decay')\n    
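# AFN-specific hyper-parameters\n    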
parser.add_argument('--trade-off-norm', default=0.05, type=float,\n                        help='the trade-off hyper-parameter for norm loss')\n    parser.add_argument('--trade-off-entropy', default=None, type=float,\n                        help='the trade-off hyper-parameter for entropy loss')\n    parser.add_argument('-r', '--delta', default=1, type=float, help='Increment for L2 norm')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='afn',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/afn.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Pr\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python afn.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 -r 0.3 -b 36 \\\n    --epochs 30 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/afn/VisDA2017"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/dann.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.modules.classifier import Classifier\nfrom tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    if args.data == 'ImageNetCaltech':\n        classifier = Classifier(backbone, num_classes, head=backbone.copy_head(), pool_layer=pool_layer).to(device)\n    else:\n        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim, pool_layer=pool_layer).to(device)\n\n    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = 
SGD(classifier.get_parameters() + domain_discri.get_parameters(),\n                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    domain_adv = DomainAdversarialLoss(domain_discri).to(device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    losses = AverageMeter('Loss', ':6.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    domain_adv.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_t, labels_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        x_t 
= x_t.to(device)\n        labels_s = labels_s.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        transfer_loss = domain_adv(f_s, f_t)\n        domain_acc = domain_adv.domain_discriminator_accuracy\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n        tgt_acc = accuracy(y_t, labels_t)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        tgt_accs.update(tgt_acc.item(), x_s.size(0))\n        domain_accs.update(domain_acc.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='DANN for Partial Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay 
(default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='dann',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/dann.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Pr\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \\\n    --epochs 5 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/dann/VisDA2017_S2R\n\n# ImageNet-Caltech\nCUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \\\n    --epochs 5 --seed 0 --log logs/dann/I2C\nCUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \\\n    --epochs 5 --seed 0 --bottleneck-dim 2048 --log logs/dann/C2I\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/erm.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, _, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    head = backbone.copy_head() if args.data == 'ImageNetCaltech' else None\n    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, head=head).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x:  args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # using shuffled val loader\n        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(val_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, classifier, optimizer,\n              lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_s = x_s.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n\n        cls_loss = F.cross_entropy(y_s, labels_s)\n        loss = cls_loss\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n     
   optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Source Only for Partial Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training
')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='src_only',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/erm.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Pr\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \\\n    --epochs 10 -i 500 --seed 0 --per-class-eval --log logs/erm/VisDA2017_S2R\n\n# ImageNet-Caltech\nCUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \\\n    --epochs 20 --seed 0 -i 2000 --log logs/erm/I2C\nCUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \\\n    --epochs 20 --seed 0 -i 2000 --log logs/erm/C2I\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/iwan.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.modules.entropy import entropy\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.reweight.iwan import ImportanceWeightModule, ImageClassifier\nfrom tllib.alignment.dann import DomainAdversarialLoss\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    if args.data == 'ImageNetCaltech':\n        classifier = Classifier(backbone, num_classes, head=backbone.copy_head(), pool_layer=pool_layer).to(device)\n    else:\n        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim, pool_layer=pool_layer).to(device)\n\n    # define domain classifier D, D_0\n    D = 
DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024, batch_norm=False).to(device)\n    D_0 = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024, batch_norm=False).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters() + D.get_parameters() + D_0.get_parameters(),\n                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    domain_adv_D = DomainAdversarialLoss(D).to(device)\n    domain_adv_D_0 = DomainAdversarialLoss(D_0).to(device)\n    # define importance weight module\n    importance_weight_module = ImportanceWeightModule(D, train_target_dataset.partial_classes_idx)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, domain_adv_D, domain_adv_D_0,\n              importance_weight_module, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          domain_adv_D: DomainAdversarialLoss, domain_adv_D_0: DomainAdversarialLoss,\n          importance_weight_module, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    losses = AverageMeter('Loss', ':6.2f')\n    cls_accs = AverageMeter('Cls 
Acc', ':3.1f')\n    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')\n    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')\n    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')\n    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')\n    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs, tgt_accs,\n         domain_accs_D, domain_accs_D_0, partial_classes_weights, non_partial_classes_weights],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    domain_adv_D.train()\n    domain_adv_D_0.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_t, labels_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n        # classification loss\n        cls_loss = F.cross_entropy(y_s, labels_s)\n\n        # domain adversarial loss for D\n        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())\n\n        # get importance weights\n        w_s = importance_weight_module.get_importance_weight(f_s)\n        # domain adversarial loss for D_0\n        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)\n\n        # entropy loss\n        y_t = F.softmax(y_t, dim=1)\n        entropy_loss = entropy(y_t, reduction='mean')\n\n        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \\\n               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n        tgt_acc = accuracy(y_t, labels_t)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        tgt_accs.update(tgt_acc.item(), x_s.size(0))\n        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy, x_s.size(0))\n        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy, x_s.size(0))\n\n        # debug: output class weight averaged on the partial classes and non-partial classes respectively\n        partial_class_weight, non_partial_classes_weight = \\\n            importance_weight_module.get_partial_classes_weight(w_s, labels_s)\n        partial_classes_weights.update(partial_class_weight.item(), x_s.size(0))\n        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))\n\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='IWAN for Partial Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of source (and target) dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | 
'.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--gamma', default=0.1, type=float,\n                        help='the trade-off hyper-parameter for entropy loss (default: 0.1)')\n    parser.add_argument('--trade-off', default=3, type=float,\n                        help='the trade-off hyper-parameter for transfer loss (default: 3)')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=10, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='iwan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/iwan.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s A -t W -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s D -t W -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s W -t D -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s A -t D -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s D -t A -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s W -t A -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Pr\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \\\n    --lr 0.0003 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/iwan/VisDA2017_S2R\n\n# ImageNet-Caltech\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \\\n    --seed 0 --log logs/iwan/I2C\nCUDA_VISIBLE_DEVICES=0 python iwan.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \\\n    --seed 0 --bottleneck-dim 2048 --log logs/iwan/C2I\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/pada.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.modules.classifier import Classifier\nfrom tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier\nfrom tllib.reweight.pada import AutomaticUpdateClassWeightModule\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import collect_feature, tsne, a_distance\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=False)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, drop_last=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    if args.data == 'ImageNetCaltech':\n        classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, head=backbone.copy_head()).to(device)\n    else:\n        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim, pool_layer=pool_layer).to(device)\n    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)\n    
class_weight_module = AutomaticUpdateClassWeightModule(args.class_weight_update_steps, train_target_loader,\n                                                           classifier, num_classes, device, args.temperature,\n                                                           train_target_dataset.partial_classes_idx)\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),\n                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n\n    # define loss function\n    domain_adv = DomainAdversarialLoss(domain_discri).to(device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = collect_feature(train_source_loader, feature_extractor, device)\n        target_feature = collect_feature(train_target_loader, feature_extractor, device)\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_source_iter, train_target_iter, classifier, domain_adv, class_weight_module,\n              optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test_acc1 = {:3.1f}\".format(acc1))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,\n          domain_adv: DomainAdversarialLoss, class_weight_module: AutomaticUpdateClassWeightModule,\n          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':5.2f')\n    data_time = AverageMeter('Data', ':5.2f')\n    losses = AverageMeter('Loss', ':6.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')\n    
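# debug meters: average class weight over target-shared (partial) classes vs. source-only\n    # classes; PADA is expected to assign larger weights to the partial classes\n    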
partial_classes_weights = AverageMeter('Partial Weight', ':3.1f')\n    non_partial_classes_weights = AverageMeter('Non-partial Weight', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs, domain_accs, tgt_accs, partial_classes_weights, non_partial_classes_weights],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n    domain_adv.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_s, labels_s = next(train_source_iter)\n        x_t, labels_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        x = torch.cat((x_s, x_t), dim=0)\n        y, f = model(x)\n        y_s, y_t = y.chunk(2, dim=0)\n        f_s, f_t = f.chunk(2, dim=0)\n\n        cls_loss = F.cross_entropy(y_s, labels_s, class_weight_module.get_class_weight_for_cross_entropy_loss())\n        w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s)\n        transfer_loss = domain_adv(f_s, f_t, w_s, w_t)\n        class_weight_module.step()\n        partial_classes_weight, non_partial_classes_weight = class_weight_module.get_partial_classes_weight()\n        domain_acc = domain_adv.domain_discriminator_accuracy\n        loss = cls_loss + transfer_loss * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n        tgt_acc = accuracy(y_t, labels_t)[0]\n\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n        domain_accs.update(domain_acc.item(), x_s.size(0))\n        tgt_accs.update(tgt_acc.item(), x_s.size(0))\n        partial_classes_weights.update(partial_classes_weight.item(), x_s.size(0))\n        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='PADA for Partial Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of source (and target) dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: Office31)')\n    parser.add_argument('-s', '--source', help='source domain')\n    parser.add_argument('-t', '--target', help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet18)')\n    parser.add_argument('--no-pool', 
action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--bottleneck-dim', default=256, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('-u', '--class-weight-update-steps', default=500, type=int,\n                        help='number of optimization steps between two class weight updates')\n    parser.add_argument('--temperature', default=0.1, type=float,\n                        help='temperature for softmax when calculating class weight')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,\n                        metavar='W', help='weight decay (default: 1e-3)',\n                        dest='weight_decay')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training.')\n    parser.add_argument('--per-class-eval', action='store_true',\n                        help='whether to output per-class accuracy during evaluation')\n    parser.add_argument(\"--log\", type=str, default='pada',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n\n"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/pada.sh",
    "content": "#!/usr/bin/env bash\n# Office31\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_A2W\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_D2W\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_W2D\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_A2D\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_D2A\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_W2A\n\n# Office-Home\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Cl\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Pr\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Rw\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Ar\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Pr\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Rw\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Ar\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Cl\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Rw\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Ar\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Cl\nCUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Pr\n\n\n# VisDA-2017\nCUDA_VISIBLE_DEVICES=0 python pada.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \\\n    --epochs 20 --seed 0 -u 500 -i 500 --train-resizing cen.crop --trade-off 0.4 --per-class-eval --log logs/pada/VisDA2017_S2R\n\n# ImageNet-Caltech\nCUDA_VISIBLE_DEVICES=0 python pada.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \\\n    --epochs 20 --seed 0 --lr 0.003 --temperature 0.01 -u 2000 -i 2000 --log logs/pada/I2C\nCUDA_VISIBLE_DEVICES=0 python pada.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \\\n    --epochs 20 --seed 0 --lr 0.003 --temperature 0.01 -u 2000 -i 2000 --bottleneck-dim 
2048 --log logs/pada/C2I"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/requirements.txt",
    "content": "timm"
  },
  {
    "path": "examples/domain_adaptation/partial_domain_adaptation/utils.py",
    "content": "import sys\nimport time\nimport timm\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.transforms as T\n\nsys.path.append('../../..')\nimport tllib.vision.datasets.partial as datasets\nfrom tllib.vision.datasets.partial import default_partial as partial\nimport tllib.vision.models as models\nfrom tllib.vision.transforms import ResizeImage\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\ndef get_model_names():\n    return sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    ) + timm.list_models()\n\n\ndef get_model(model_name):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=True)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=True)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n            backbone.copy_head = backbone.get_classifier\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n            backbone.copy_head = lambda x: x.head\n    return backbone\n\n\ndef get_dataset_names():\n    return sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n\ndef get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None):\n    if train_target_transform is None:\n        train_target_transform = train_source_transform\n    # load datasets from tllib.vision.datasets\n    dataset = datasets.__dict__[dataset_name]\n    partial_dataset = partial(dataset)\n\n    train_source_dataset = dataset(root=root, task=source, download=True, transform=train_source_transform)\n    train_target_dataset = partial_dataset(root=root, task=target, download=True, transform=train_target_transform)\n    val_dataset = partial_dataset(root=root, task=target, download=True, transform=val_transform)\n    if dataset_name == 'DomainNet':\n        test_dataset = partial_dataset(root=root, task=target, split='test', download=True, transform=val_transform)\n    else:\n        test_dataset = val_dataset\n    class_names = train_source_dataset.classes\n    num_classes = len(class_names)\n    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names\n\n\ndef validate(val_loader, model, args, device) -> float:\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    top1 = AverageMeter('Acc@1', ':6.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, top1],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    if args.per_class_eval:\n        confmat = ConfusionMatrix(len(args.class_names))\n    else:\n        confmat = None\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            loss = F.cross_entropy(output, target)\n\n            # measure accuracy and record loss\n        
    acc1, = accuracy(output, target, topk=(1,))\n            if confmat:\n                confmat.update(target, output.argmax(1))\n            losses.update(loss.item(), images.size(0))\n            top1.update(acc1.item(), images.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))\n        if confmat:\n            print(confmat.format(args.class_names))\n\n    return top1.avg\n\n\ndef get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 256 and take a random resized crop of size 224;\n        - cen.crop: resize the image to 256 and take the center crop of size 224;\n        - res.: resize the image such that the smaller side is of size 224;\n        - res.|crop: resize the image to (256, 256) and take a random crop of size 224;\n        - res.sma|crop: resize the image keeping its aspect ratio such that the\n            smaller side is 256, then take a random crop of size 224;\n        - inc.crop: \"inception crop\" from (Szegedy et al., 2015);\n        - cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.RandomResizedCrop(224)\n        ])\n    elif resizing == 'cen.crop':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224)\n        ])\n    elif resizing == 'res.':\n        transform = T.Resize(224)\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize((256, 256)),\n            T.RandomCrop(224)\n        ])\n    elif resizing == \"res.sma|crop\":\n        transform = T.Compose([\n            T.Resize(256),\n            T.RandomCrop(224)\n        ])\n    elif resizing == 'inc.crop':\n        transform = T.RandomResizedCrop(224)\n    elif resizing == 'cif.crop':\n        transform = T.Compose([\n            T.Resize((224, 224)),\n            T.Pad(28),\n            T.RandomCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        transforms.append(T.RandomHorizontalFlip())\n    if random_color_jitter:\n        transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n    return T.Compose(transforms)
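\n\n\n# Illustrative usage (an assumption for this example, not part of the original file):\n# pada.sh passes --train-resizing cen.crop for VisDA-2017, which corresponds to\n#   train_transform = get_train_transform('cen.crop')\n#   val_transform = get_val_transform('default')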
\n\n\ndef get_val_transform(resizing='default'):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 256 and take the center crop of size 224;\n        - res.: resize the image to (224, 224);\n        - res.|crop: resize the image such that the smaller side is of size 256 and\n            then take a central crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224),\n        ])\n    elif resizing == 'res.':\n        transform = T.Resize((224, 224))\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize(256),\n            T.CenterCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    return T.Compose([\n        transform,\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/README.md",
    "content": "# Unsupervised Domain Adaptation for Person Re-Identification\n\n## Installation\n\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You\nalso need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [Market1501](http://zheng-lab.cecs.anu.edu.au/Project/project_reid.html)\n- [DukeMTMC](https://exposing.ai/duke_mtmc/)\n- [MSMT17](https://arxiv.org/pdf/1711.08565.pdf)\n\n## Supported Methods\n\nSupported methods include:\n\n- [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)\n- [Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification (MMT, 2020 ICLR)](https://arxiv.org/abs/2001.01526)\n- [Similarity Preserving Generative Adversarial Network (SPGAN, 2018 CVPR)](https://arxiv.org/pdf/1811.10551.pdf)\n\n## Usage\n\nThe shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to\ntrain MMT on Market1501 -> DukeMTMC task, use the following script\n\n```shell script\n# Train MMT on Market1501 -> DukeMTMC task using ResNet 50.\n# Assume you have put the datasets under the path `data/market1501` and `data/dukemtmc`, \n# or you are glad to download the datasets automatically from the Internet to this path\n\n# MMT involves two training steps:\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2DukeSeed0\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2DukeSeed1\n\n# step2: train mmt\nCUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data -t DukeMTMC -a reid_resnet50 \\\n--pretrained-model-1-path logs/baseline/Market2DukeSeed0/checkpoints/best.pth \\\n--pretrained-model-2-path logs/baseline/Market2DukeSeed1/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/mmt/Market2Duke\n```\n\n### Experiment and Results\nIn our experiments, we adopt modified resnet architecture from [MMT](https://arxiv.org/pdf/2001.01526.pdf>). For a fair comparison,\nwe use standard cross entropy loss and triplet loss in all methods. 
\n### Experiment and Results\nIn our experiments, we adopt the modified ResNet architecture from [MMT](https://arxiv.org/pdf/2001.01526.pdf). For a fair comparison,\nwe use the standard cross-entropy loss and triplet loss in all methods. For methods that utilize clustering algorithms,\nwe adopt k-means or DBSCAN and report both results.\n\n**Notations**\n- ``Avg`` is the mAP (mean average precision) reported by `TLlib`, averaged over the six tasks.\n- ``Baseline_Cluster`` represents the strong baseline in [MMT](https://arxiv.org/pdf/2001.01526.pdf).\n\n### Cross-dataset mAP on ResNet-50\n\n| Methods                  | Avg  | Market2Duke | Duke2Market | Market2MSMT | MSMT2Market | Duke2MSMT | MSMT2Duke |\n|--------------------------|------|-------------|-------------|-------------|-------------|-----------|-----------|\n| Baseline                 | 27.1 | 32.4        | 31.4        | 8.2         | 36.7        | 11.0      | 43.1      |\n| IBN                      | 30.0 | 35.2        | 36.5        | 11.3        | 38.7        | 14.1      | 44.3      |\n| SPGAN                    | 30.7 | 34.4        | 35.4        | 14.1        | 40.2        | 16.1      | 43.8      |\n| Baseline_Cluster(kmeans) | 45.1 | 52.8        | 59.5        | 19.0        | 62.6        | 20.3      | 56.2      |\n| Baseline_Cluster(dbscan) | 54.9 | 62.5        | 73.5        | 25.2        | 77.9        | 25.3      | 65.0      |\n| MMT(kmeans)              | 55.4 | 63.7        | 72.5        | 26.2        | 75.8        | 28.0      | 66.1      |\n| MMT(dbscan)              | 60.0 | 68.2        | 80.0        | 28.2        | 82.5        | 31.2      | 70.0      |\n\n## Citation\n\nIf you use these methods in your research, please consider citing:\n\n```\n@inproceedings{IBN-Net,\n    author = {Pan, Xingang and Luo, Ping and Shi, Jianping and Tang, Xiaoou},\n    title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net},\n    booktitle = {ECCV},\n    year = {2018}\n}\n\n@inproceedings{SPGAN,\n    title={Image-image domain adaptation with preserved self-similarity and domain-dissimilarity for person re-identification},\n    author={Deng, Weijian and Zheng, Liang and Ye, Qixiang and Kang, Guoliang and Yang, Yi and Jiao, Jianbin},\n    booktitle={CVPR},\n    year={2018}\n}\n\n@inproceedings{MMT,\n    title={Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification},\n    author={Yixiao Ge and Dapeng Chen and Hongsheng Li},\n    booktitle={ICLR},\n    year={2020},\n}\n```"
  },
  {
    "path": "examples/domain_adaptation/re_identification/baseline.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import DataParallel\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss\nfrom tllib.vision.models.reid.identifier import ReIdentifier\nimport tllib.vision.datasets.reid as datasets\nfrom tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset\nfrom tllib.utils.scheduler import WarmupMultiStepLR\nfrom tllib.utils.metric.reid import validate, visualize_ranked_results\nfrom tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,\n                                                random_horizontal_flip=True, random_color_jitter=False,\n                                                random_gray_scale=False, random_erasing=False)\n    val_transform = utils.get_val_transform(args.height, args.width)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    working_dir = osp.dirname(osp.abspath(__file__))\n    source_root = osp.join(working_dir, args.source_root)\n    target_root = osp.join(working_dir, args.target_root)\n\n    # source dataset\n    source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower()))\n    sampler = RandomMultipleGallerySampler(source_dataset.train, args.num_instances)\n    train_source_loader = DataLoader(\n        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    val_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),\n                                   root=source_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # target dataset\n    target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower()))\n    train_target_loader = DataLoader(\n        convert_to_pytorch_dataset(target_dataset.train, 
root=target_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n    test_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),\n                                   root=target_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # create model\n    num_classes = source_dataset.num_train_pids\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)\n    model = DataParallel(model)\n\n    # define optimizer and lr scheduler\n    optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,\n                     weight_decay=args.weight_decay)\n    lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1,\n                                     warmup_steps=args.warmup_steps)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # plot t-SNE\n        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,\n                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)\n        # visualize ranked results\n        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,\n                                 rerank=args.rerank)\n        return\n\n    if args.phase == 'test':\n        print(\"Test on source domain:\")\n        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        print(\"Test on target domain:\")\n        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        return\n\n    # define loss function\n    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)\n    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)\n\n    # start training\n    best_val_mAP = 0.\n    best_test_mAP = 0.\n    for epoch in range(args.epochs):\n        # print learning rate\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args)\n\n        # update learning rate\n        lr_scheduler.step()\n\n        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):\n\n            # evaluate on validation set\n            print(\"Validation on source domain...\")\n            _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device,\n                                  cmc_flag=True)\n\n            # remember best mAP and save checkpoint\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n
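            # model selection uses source-domain validation mAP, since target labels are assumed unavailable\n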
            if val_mAP > best_val_mAP:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n            best_val_mAP = max(val_mAP, best_val_mAP)\n\n            # evaluate on test set\n            print(\"Test on target domain...\")\n            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                                   cmc_flag=True, rerank=args.rerank)\n            best_test_mAP = max(test_mAP, best_test_mAP)\n\n    # evaluate on test set\n    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    print(\"Test on target domain:\")\n    _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                           cmc_flag=True, rerank=args.rerank)\n    print(\"test mAP on target = {}\".format(test_mAP))\n    print(\"oracle mAP on target = {}\".format(best_test_mAP))\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model,\n          criterion_ce: CrossEntropyLossWithLabelSmooth, criterion_triplet: SoftTripletLoss, optimizer: Adam,\n          epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_ce = AverageMeter('CeLoss', ':3.2f')\n    losses_triplet = AverageMeter('TripletLoss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        x_s, _, labels_s, _ = next(train_source_iter)\n        x_t, _, _, _ = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        x_t = x_t.to(device)\n        labels_s = labels_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, f_s = model(x_s)\n        y_t, f_t = model(x_t)\n\n        # cross entropy loss\n        loss_ce = criterion_ce(y_s, labels_s)\n        # triplet loss\n        loss_triplet = criterion_triplet(f_s, f_s, labels_s)\n        loss = loss_ce + loss_triplet * args.trade_off\n\n        cls_acc = accuracy(y_s, labels_s)[0]\n        losses_ce.update(loss_ce.item(), x_s.size(0))\n        losses_triplet.update(loss_triplet.item(), x_s.size(0))\n        losses.update(loss.item(), x_s.size(0))\n        cls_accs.update(cls_acc.item(), x_s.size(0))\n\n        # compute gradient and do one optimization step (Adam)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description=\"Baseline for Domain Adaptive ReID\")\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', type=str, help='source 
domain')\n    parser.add_argument('-t', '--target', type=str, help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: reid_resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use 10x smaller lr for the backbone')\n    parser.add_argument('--rate', type=float, default=0.2)\n    # training parameters\n    parser.add_argument('--trade-off', type=float, default=1,\n                        help='trade-off hyper parameter between cross entropy loss and triplet loss')\n    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')\n    parser.add_argument('-j', '--workers', type=int, default=4)\n    parser.add_argument('-b', '--batch-size', type=int, default=16)\n    parser.add_argument('--height', type=int, default=256, help=\"input height\")\n    parser.add_argument('--width', type=int, default=128, help=\"input width\")\n    parser.add_argument('--num-instances', type=int, default=4,\n                        help=\"each minibatch consists of \"\n                             \"(batch_size // num_instances) identities, and \"\n                             \"each identity has num_instances instances, \"\n                             \"default: 4\")\n    parser.add_argument('--lr', type=float, default=0.00035,\n                        help=\"initial learning rate\")\n    parser.add_argument('--weight-decay', type=float, default=5e-4)\n    parser.add_argument('--epochs', type=int, default=80)\n    parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps')\n    parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70],\n                        help='milestones for the learning rate decay')\n    parser.add_argument('--eval-step', type=int, default=40)\n    parser.add_argument('--iters-per-epoch', type=int, default=400)\n    parser.add_argument('--print-freq', type=int, default=40)\n    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')\n    parser.add_argument('--rerank', action='store_true', help=\"whether to use re-ranking during evaluation\")\n    parser.add_argument(\"--log\", type=str, default='baseline',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/baseline.sh",
    "content": "#!/usr/bin/env bash\n# Market1501 -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2Duke\n\n# Duke -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2Market\n\n# Market1501 -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMT\n\n# MSMT -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Market\n\n# Duke -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMT\n\n# MSMT -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/baseline_cluster.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import DataParallel\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.utils.data import DataLoader\nfrom sklearn.cluster import KMeans, DBSCAN\n\nimport utils\nimport tllib.vision.datasets.reid as datasets\nfrom tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset\nfrom tllib.vision.models.reid.identifier import ReIdentifier\nfrom tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss\nfrom tllib.utils.metric.reid import extract_reid_feature, validate, visualize_ranked_results\nfrom tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,\n                                                random_horizontal_flip=True, random_color_jitter=False,\n                                                random_gray_scale=False, random_erasing=True)\n    val_transform = utils.get_val_transform(args.height, args.width)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    working_dir = osp.dirname(osp.abspath(__file__))\n    source_root = osp.join(working_dir, args.source_root)\n    target_root = osp.join(working_dir, args.target_root)\n\n    # source dataset\n    source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower()))\n    val_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),\n                                   root=source_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # target dataset\n    target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower()))\n    cluster_loader = DataLoader(\n        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n    test_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),\n                                   root=target_dataset.images_dir,\n                                   
transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # create model\n    num_classes = args.num_clusters\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)\n    model = DataParallel(model)\n\n    # load pretrained weights\n    pretrained_model = torch.load(args.pretrained_model_path)\n    utils.copy_state_dict(model, pretrained_model)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        utils.copy_state_dict(model, checkpoint['model'])\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # plot t-SNE\n        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,\n                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)\n        # visualize ranked results\n        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,\n                                 rerank=args.rerank)\n        return\n\n    if args.phase == 'test':\n        print(\"Test on source domain:\")\n        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        print(\"Test on target domain:\")\n        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        return\n\n    # define loss function\n    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)\n    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        utils.copy_state_dict(model, checkpoint['model'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    # start training\n    best_test_mAP = 0.\n    for epoch in range(args.start_epoch, args.epochs):\n        # run clustering algorithm and generate pseudo labels\n        if args.clustering_algorithm == 'kmeans':\n            train_target_iter = run_kmeans(cluster_loader, model, target_dataset, train_transform, args)\n        elif args.clustering_algorithm == 'dbscan':\n            train_target_iter, num_classes = run_dbscan(cluster_loader, model, target_dataset, train_transform, args)\n\n        # define cross entropy loss with current number of classes\n        criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)\n\n        # define optimizer\n        optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,\n                         weight_decay=args.weight_decay)\n\n        # train for one epoch\n        train(train_target_iter, model, optimizer, criterion_ce, criterion_triplet, epoch, args)\n\n        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):\n            # remember best mAP and save checkpoint\n            torch.save(\n                {\n                    'model': model.state_dict(),\n                    'epoch': epoch\n                }, logger.get_checkpoint_path(epoch)\n            )\n    
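        # promote this epoch's checkpoint to 'best' if target-domain mAP improves (see check below)\n    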
        print(\"Test on target domain...\")\n            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                                   cmc_flag=True, rerank=args.rerank)\n            if test_mAP > best_test_mAP:\n                shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))\n            best_test_mAP = max(test_mAP, best_test_mAP)\n\n    print(\"best mAP on target = {}\".format(best_test_mAP))\n    logger.close()\n\n\ndef run_kmeans(cluster_loader: DataLoader, model: DataParallel, target_dataset, train_transform,\n               args: argparse.Namespace):\n    # run kmeans clustering algorithm\n    print('Clustering into {} classes'.format(args.num_clusters))\n    feature_dict = extract_reid_feature(cluster_loader, model, device, normalize=True)\n    feature = torch.stack(list(feature_dict.values())).cpu().numpy()\n    km = KMeans(n_clusters=args.num_clusters, random_state=args.seed).fit(feature)\n    cluster_labels = km.labels_\n    cluster_centers = km.cluster_centers_\n    print('Clustering finished')\n\n    # normalize cluster centers and convert to pytorch tensor\n    cluster_centers = torch.from_numpy(cluster_centers).float().to(device)\n    cluster_centers = F.normalize(cluster_centers, dim=1)\n    # reinitialize classifier head\n    model.module.head.weight.data.copy_(cluster_centers)\n\n    # generate training set with pseudo labels\n    target_train_set = []\n    for (fname, _, cid), label in zip(target_dataset.train, cluster_labels):\n        target_train_set.append((fname, int(label), cid))\n\n    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)\n    train_target_loader = DataLoader(\n        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    return train_target_iter\n\n\ndef run_dbscan(cluster_loader: DataLoader, model: DataParallel, target_dataset, train_transform,\n               args: argparse.Namespace):\n    # run dbscan clustering algorithm\n    feature_dict = extract_reid_feature(cluster_loader, model, device, normalize=True)\n    feature = torch.stack(list(feature_dict.values())).cpu()\n    rerank_dist = utils.compute_rerank_dist(feature).numpy()\n\n    print('Clustering with dbscan algorithm')\n    dbscan = DBSCAN(eps=0.6, min_samples=4, metric='precomputed', n_jobs=-1)\n    cluster_labels = dbscan.fit_predict(rerank_dist)\n    print('Clustering finished')\n\n    # generate training set with pseudo labels and calculate cluster centers\n    target_train_set = []\n    cluster_centers = {}\n    for i, ((fname, _, cid), label) in enumerate(zip(target_dataset.train, cluster_labels)):\n        if label == -1:\n            continue\n        target_train_set.append((fname, label, cid))\n\n        if label not in cluster_centers:\n            cluster_centers[label] = []\n        cluster_centers[label].append(feature[i])\n\n    cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())]\n    cluster_centers = torch.stack(cluster_centers)\n    # normalize cluster centers\n    cluster_centers = F.normalize(cluster_centers, dim=1).float().to(device)\n\n    # reinitialize classifier head\n    features_dim = model.module.features_dim\n    num_clusters = len(set(cluster_labels)) - (1 if 
-1 in cluster_labels else 0)\n    model.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)\n    model.module.head.weight.data.copy_(cluster_centers)\n\n    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)\n    train_target_loader = DataLoader(\n        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    return train_target_iter, num_clusters\n\n\ndef train(train_target_iter: ForeverDataIterator, model, optimizer, criterion_ce: CrossEntropyLossWithLabelSmooth,\n          criterion_triplet: SoftTripletLoss, epoch: int, args: argparse.Namespace):\n    # train with pseudo labels\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_ce = AverageMeter('CeLoss', ':3.2f')\n    losses_triplet = AverageMeter('TripletLoss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        x_t, _, labels_t, _ = next(train_target_iter)\n\n        x_t = x_t.to(device)\n        labels_t = labels_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_t, f_t = model(x_t)\n\n        # cross entropy loss\n        loss_ce = criterion_ce(y_t, labels_t)\n        # triplet loss\n        loss_triplet = criterion_triplet(f_t, f_t, labels_t)\n        loss = loss_ce + loss_triplet * args.trade_off\n\n        cls_acc = accuracy(y_t, labels_t)[0]\n        losses_ce.update(loss_ce.item(), x_t.size(0))\n        losses_triplet.update(loss_triplet.item(), x_t.size(0))\n        losses.update(loss.item(), x_t.size(0))\n        cls_accs.update(cls_acc.item(), x_t.size(0))\n\n        # compute gradient and do one optimization step (Adam)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description=\"Cluster Baseline for Domain Adaptive ReID\")\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', type=str, help='source domain')\n    parser.add_argument('-t', '--target', type=str, help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                 
            ' (default: reid_resnet50)')\n    parser.add_argument('--num-clusters', type=int, default=500)\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use 10x smaller lr for the backbone')\n    parser.add_argument('--rate', type=float, default=0.2)\n    # training parameters\n    parser.add_argument('--clustering-algorithm', type=str, default='dbscan', choices=['kmeans', 'dbscan'],\n                        help='clustering algorithm to run, currently supported methods: [\"kmeans\", \"dbscan\"]')\n    parser.add_argument('--resume', type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--pretrained-model-path', type=str, help='path to pretrained (source-only) model')\n    parser.add_argument('--trade-off', type=float, default=1,\n                        help='trade-off hyper parameter between cross entropy loss and triplet loss')\n    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')\n    parser.add_argument('-j', '--workers', type=int, default=4)\n    parser.add_argument('-b', '--batch-size', type=int, default=64)\n    parser.add_argument('--height', type=int, default=256, help=\"input height\")\n    parser.add_argument('--width', type=int, default=128, help=\"input width\")\n    parser.add_argument('--num-instances', type=int, default=4,\n                        help=\"each minibatch consists of \"\n                             \"(batch_size // num_instances) identities, and \"\n                             \"each identity has num_instances instances, \"\n                             \"default: 4\")\n    parser.add_argument('--lr', type=float, default=0.00035,\n                        help=\"learning rate\")\n    parser.add_argument('--weight-decay', type=float, default=5e-4)\n    parser.add_argument('--epochs', type=int, default=40)\n    parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')\n    parser.add_argument('--eval-step', type=int, default=1)\n    parser.add_argument('--iters-per-epoch', type=int, default=400)\n    parser.add_argument('--print-freq', type=int, default=40)\n    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')\n    parser.add_argument('--rerank', action='store_true', help=\"whether to use re-ranking during evaluation\")\n    parser.add_argument(\"--log\", type=str, default='baseline_cluster',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/baseline_cluster.sh",
    "content": "#!/usr/bin/env bash\n# Market1501 -> Duke\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2Duke\n# step2: train with pseudo labels assigned by cluster algorithm\nCUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--pretrained-model-path logs/baseline/Market2Duke/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/baseline_cluster/Market2Duke\n\n# Duke -> Market1501\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2Market\n# step2: train with pseudo labels assigned by cluster algorithm\nCUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--pretrained-model-path logs/baseline/Duke2Market/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/baseline_cluster/Duke2Market\n\n# Market1501 -> MSMT\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMT\n# step2: train with pseudo labels assigned by cluster algorithm\nCUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--pretrained-model-path logs/baseline/Market2MSMT/checkpoints/best.pth \\\n--num-clusters 1000 --finetune --seed 0 --log logs/baseline_cluster/Market2MSMT\n\n# MSMT -> Market1501\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Market\n# step2: train with pseudo labels assigned by cluster algorithm\nCUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--pretrained-model-path logs/baseline/MSMT2Market/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/baseline_cluster/MSMT2Market\n\n# Duke -> MSMT\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMT\n# step2: train with pseudo labels assigned by cluster algorithm\nCUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--pretrained-model-path logs/baseline/Duke2MSMT/checkpoints/best.pth \\\n--num-clusters 1000 --finetune --seed 0 --log logs/baseline_cluster/Duke2MSMT\n\n# MSMT -> Duke\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Duke\n# step2: train with pseudo labels assigned by cluster algorithm\nCUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--pretrained-model-path logs/baseline/MSMT2Duke/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/baseline_cluster/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/ibn.sh",
    "content": "#!/usr/bin/env bash\n# Market1501 -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a resnet50_ibn_a \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a resnet50_ibn_b \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2Duke\n\n# Duke -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a resnet50_ibn_a \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2Market\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a resnet50_ibn_b \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2Market\n\n# Market1501 -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a resnet50_ibn_a \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a resnet50_ibn_b \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2MSMT\n\n# MSMT -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a resnet50_ibn_a \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Market\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a resnet50_ibn_b \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Market\n\n# Duke -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a resnet50_ibn_a \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a resnet50_ibn_b \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2MSMT\n\n# MSMT -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a resnet50_ibn_a \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a resnet50_ibn_b \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/mmt.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport os.path as osp\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import DataParallel\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.utils.data import DataLoader\nfrom sklearn.cluster import KMeans, DBSCAN\n\nimport utils\nimport tllib.vision.datasets.reid as datasets\nfrom tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset\nfrom tllib.vision.models.reid.identifier import ReIdentifier\nfrom tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss, CrossEntropyLoss\nfrom tllib.self_training.mean_teacher import EMATeacher\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric.reid import extract_reid_feature, validate, visualize_ranked_results\nfrom tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,\n                                                random_horizontal_flip=True, random_color_jitter=False,\n                                                random_gray_scale=False, random_erasing=True)\n    val_transform = utils.get_val_transform(args.height, args.width)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    working_dir = osp.dirname(osp.abspath(__file__))\n    source_root = osp.join(working_dir, args.source_root)\n    target_root = osp.join(working_dir, args.target_root)\n\n    # source dataset\n    source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower()))\n    val_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),\n                                   root=source_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # target dataset\n    target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower()))\n    cluster_loader = DataLoader(\n        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n    test_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(target_dataset.query) | 
set(target_dataset.gallery)),\n                                   root=target_dataset.images_dir, transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # create model\n    model_1, model_1_ema = create_model(args, args.pretrained_model_1_path)\n    model_2, model_2_ema = create_model(args, args.pretrained_model_2_path)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        utils.copy_state_dict(model_1_ema, checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # plot t-SNE\n        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model_1_ema,\n                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)\n        # visualize ranked results\n        visualize_ranked_results(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery, device,\n                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,\n                                 rerank=args.rerank)\n        return\n\n    if args.phase == 'test':\n        print(\"Test on source domain:\")\n        validate(val_loader, model_1_ema, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        print(\"Test on target domain:\")\n        validate(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        return\n\n    # define loss function\n    num_classes = args.num_clusters\n    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)\n    criterion_ce_soft = CrossEntropyLoss().to(device)\n    criterion_triplet = SoftTripletLoss(margin=0.0).to(device)\n    criterion_triplet_soft = SoftTripletLoss(margin=None).to(device)\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        utils.copy_state_dict(model_1, checkpoint['model_1'])\n        utils.copy_state_dict(model_1_ema, checkpoint['model_1_ema'])\n        utils.copy_state_dict(model_2, checkpoint['model_2'])\n        utils.copy_state_dict(model_2_ema, checkpoint['model_2_ema'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    # start training\n    best_test_mAP = 0.\n    for epoch in range(args.start_epoch, args.epochs):\n        # run clustering algorithm and generate pseudo labels\n        if args.clustering_algorithm == 'kmeans':\n            train_target_iter = run_kmeans(cluster_loader, model_1, model_2, model_1_ema, model_2_ema, target_dataset,\n                                           train_transform, args)\n        elif args.clustering_algorithm == 'dbscan':\n            train_target_iter, num_classes = run_dbscan(cluster_loader, model_1, model_2, model_1_ema, model_2_ema,\n                                                        target_dataset, train_transform, args)\n\n        # define cross entropy loss with current number of classes\n        criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)\n\n        # define optimizer\n        optimizer = Adam(model_1.module.get_parameters(base_lr=args.lr, rate=args.rate) + model_2.module.get_parameters(\n            base_lr=args.lr, rate=args.rate), args.lr, weight_decay=args.weight_decay)\n\n        # train for one 
epoch\n        train(train_target_iter, model_1, model_1_ema, model_2, model_2_ema, optimizer, criterion_ce, criterion_ce_soft,\n              criterion_triplet, criterion_triplet_soft, epoch, args)\n\n        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):\n            # save checkpoint and remember best mAP\n            torch.save(\n                {\n                    'model_1': model_1.state_dict(),\n                    'model_1_ema': model_1_ema.state_dict(),\n                    'model_2': model_2.state_dict(),\n                    'model_2_ema': model_2_ema.state_dict(),\n                    'epoch': epoch\n                }, logger.get_checkpoint_path(epoch)\n            )\n            print(\"Test model_1 on target domain...\")\n            _, test_mAP_1 = validate(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery,\n                                     device, cmc_flag=True, rerank=args.rerank)\n            print(\"Test model_2 on target domain...\")\n            _, test_mAP_2 = validate(test_loader, model_2_ema, target_dataset.query, target_dataset.gallery,\n                                     device, cmc_flag=True, rerank=args.rerank)\n            if test_mAP_1 > test_mAP_2 and test_mAP_1 > best_test_mAP:\n                torch.save(model_1_ema.state_dict(), logger.get_checkpoint_path('best'))\n                best_test_mAP = test_mAP_1\n            if test_mAP_2 > test_mAP_1 and test_mAP_2 > best_test_mAP:\n                torch.save(model_2_ema.state_dict(), logger.get_checkpoint_path('best'))\n                best_test_mAP = test_mAP_2\n\n    print(\"best mAP on target = {}\".format(best_test_mAP))\n    logger.close()\n\n\ndef create_model(args: argparse.Namespace, pretrained_model_path: str):\n    num_classes = args.num_clusters\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)\n    model = DataParallel(model)\n\n    # load pretrained weights\n    pretrained_model = torch.load(pretrained_model_path)\n    utils.copy_state_dict(model, pretrained_model)\n\n    # EMA model\n    model_ema = EMATeacher(model, args.alpha)\n    return model, model_ema\n\n\ndef run_kmeans(cluster_loader: DataLoader, model_1: DataParallel, model_2: DataParallel, model_1_ema: EMATeacher,\n               model_2_ema: EMATeacher, target_dataset, train_transform, args: argparse.Namespace):\n    # run kmeans clustering algorithm\n    print('Clustering into {} classes'.format(args.num_clusters))\n    # collect feature with different ema teachers\n    feature_dict_1 = extract_reid_feature(cluster_loader, model_1_ema, device, normalize=True)\n    feature_1 = torch.stack(list(feature_dict_1.values())).cpu().numpy()\n    feature_dict_2 = extract_reid_feature(cluster_loader, model_2_ema, device, normalize=True)\n    feature_2 = torch.stack(list(feature_dict_2.values())).cpu().numpy()\n    # average feature_1, feature_2 to create final feature\n    feature = (feature_1 + feature_2) / 2\n\n    km = KMeans(n_clusters=args.num_clusters, random_state=args.seed).fit(feature)\n    cluster_labels = km.labels_\n    cluster_centers = km.cluster_centers_\n    print('Clustering finished')\n\n    # normalize cluster centers and convert to pytorch tensor\n    cluster_centers = torch.from_numpy(cluster_centers).float().to(device)\n    cluster_centers = F.normalize(cluster_centers, dim=1)\n\n    # reinitialize classifier head\n   
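 # note: the classifier head acts as a set of class prototypes, so copying the L2-normalized k-means centers\n    # into its weight rows warm-starts the classifier to agree with the freshly assigned pseudo labels\n   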
 model_1.module.head.weight.data.copy_(cluster_centers)\n    model_2.module.head.weight.data.copy_(cluster_centers)\n    model_1_ema.module.head.weight.data.copy_(cluster_centers)\n    model_2_ema.module.head.weight.data.copy_(cluster_centers)\n\n    # generate training set with pseudo labels\n    target_train_set = []\n    for (fname, _, cid), label in zip(target_dataset.train, cluster_labels):\n        target_train_set.append((fname, int(label), cid))\n\n    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)\n    train_target_loader = DataLoader(\n        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir,\n                                   transform=MultipleApply([train_transform, train_transform])),\n        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    return train_target_iter\n\n\ndef run_dbscan(cluster_loader: DataLoader, model_1: DataParallel, model_2: DataParallel, model_1_ema: EMATeacher,\n               model_2_ema: EMATeacher, target_dataset, train_transform, args: argparse.Namespace):\n    # run dbscan clustering algorithm\n\n    # collect feature with different ema teachers\n    feature_dict_1 = extract_reid_feature(cluster_loader, model_1_ema, device, normalize=True)\n    feature_1 = torch.stack(list(feature_dict_1.values())).cpu()\n    feature_dict_2 = extract_reid_feature(cluster_loader, model_2_ema, device, normalize=True)\n    feature_2 = torch.stack(list(feature_dict_2.values())).cpu()\n    # average feature_1, feature_2 to create final feature\n    feature = (feature_1 + feature_2) / 2\n    feature = F.normalize(feature, dim=1)\n    rerank_dist = utils.compute_rerank_dist(feature).numpy()\n\n    print('Clustering with dbscan algorithm')\n    dbscan = DBSCAN(eps=0.7, min_samples=4, metric='precomputed', n_jobs=-1)\n    cluster_labels = dbscan.fit_predict(rerank_dist)\n    print('Clustering finished')\n\n    # generate training set with pseudo labels and calculate cluster centers\n    target_train_set = []\n    cluster_centers = {}\n    for i, ((fname, _, cid), label) in enumerate(zip(target_dataset.train, cluster_labels)):\n        if label == -1:\n            continue\n        target_train_set.append((fname, label, cid))\n\n        if label not in cluster_centers:\n            cluster_centers[label] = []\n        cluster_centers[label].append(feature[i])\n\n    cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())]\n    cluster_centers = torch.stack(cluster_centers)\n    # normalize cluster centers\n    cluster_centers = F.normalize(cluster_centers, dim=1).float().to(device)\n\n    # reinitialize classifier head\n    features_dim = model_1.module.features_dim\n    num_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)\n\n    model_1.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)\n    model_2.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)\n    model_1_ema.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)\n    model_2_ema.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)\n\n    model_1.module.head.weight.data.copy_(cluster_centers)\n    model_2.module.head.weight.data.copy_(cluster_centers)\n    model_1_ema.module.head.weight.data.copy_(cluster_centers)\n    
model_2_ema.module.head.weight.data.copy_(cluster_centers)\n\n    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)\n    train_target_loader = DataLoader(\n        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir,\n                                   transform=MultipleApply([train_transform, train_transform])),\n        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    return train_target_iter, num_clusters\n\n\ndef train(train_target_iter: ForeverDataIterator, model_1: DataParallel, model_1_ema: EMATeacher, model_2: DataParallel,\n          model_2_ema: EMATeacher, optimizer: Adam, criterion_ce: CrossEntropyLossWithLabelSmooth,\n          criterion_ce_soft: CrossEntropyLoss, criterion_triplet: SoftTripletLoss,\n          criterion_triplet_soft: SoftTripletLoss, epoch: int, args: argparse.Namespace):\n    # train with pseudo labels\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    # statistics for model_1\n    losses_ce_1 = AverageMeter('Model_1 CELoss', ':3.2f')\n    losses_triplet_1 = AverageMeter('Model_1 TripletLoss', ':3.2f')\n    cls_accs_1 = AverageMeter('Model_1 Cls Acc', ':3.1f')\n    # statistics for model_2\n    losses_ce_2 = AverageMeter('Model_2 CELoss', ':3.2f')\n    losses_triplet_2 = AverageMeter('Model_2 TripletLoss', ':3.2f')\n    cls_accs_2 = AverageMeter('Model_2 Cls Acc', ':3.1f')\n\n    losses_ce_soft = AverageMeter('Soft CELoss', ':3.2f')\n    losses_triplet_soft = AverageMeter('Soft TripletLoss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_ce_1, losses_triplet_1, cls_accs_1, losses_ce_2, losses_triplet_2, cls_accs_2,\n         losses_ce_soft, losses_triplet_soft, losses],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model_1.train()\n    model_2.train()\n    model_1_ema.train()\n    model_2_ema.train()\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        # below we ignore subscript `t` and use `x_1`, `x_2` to denote different augmented versions of the original\n        # samples `x_t` from the target domain\n        (x_1, x_2), _, labels, _ = next(train_target_iter)\n\n        x_1 = x_1.to(device)\n        x_2 = x_2.to(device)\n        labels = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_1, f_1 = model_1(x_1)\n        y_2, f_2 = model_2(x_2)\n        # compute output by ema-teacher\n        y_1_teacher, f_1_teacher = model_1_ema(x_1)\n        y_2_teacher, f_2_teacher = model_2_ema(x_2)\n\n        # cross entropy loss\n        loss_ce_1 = criterion_ce(y_1, labels)\n        loss_ce_2 = criterion_ce(y_2, labels)\n        # triplet loss\n        loss_triplet_1 = criterion_triplet(f_1, f_1, labels)\n        loss_triplet_2 = criterion_triplet(f_2, f_2, labels)\n        # soft cross entropy loss\n        loss_ce_soft = criterion_ce_soft(y_1, y_2_teacher) + \\\n                       criterion_ce_soft(y_2, y_1_teacher)\n        # soft triplet loss\n        loss_triplet_soft = criterion_triplet_soft(f_1, f_2_teacher, labels) + \\\n                            criterion_triplet_soft(f_2, f_1_teacher, labels)\n        # final objective\n        loss = (loss_ce_1 + loss_ce_2) * (1 - 
args.trade_off_ce_soft) + \\\n               (loss_triplet_1 + loss_triplet_2) * (1 - args.trade_off_triplet_soft) + \\\n               loss_ce_soft * args.trade_off_ce_soft + \\\n               loss_triplet_soft * args.trade_off_triplet_soft\n\n        # update statistics\n        batch_size = args.batch_size\n        cls_acc_1 = accuracy(y_1, labels)[0]\n        cls_acc_2 = accuracy(y_2, labels)[0]\n        # model 1\n        losses_ce_1.update(loss_ce_1.item(), batch_size)\n        losses_triplet_1.update(loss_triplet_1.item(), batch_size)\n        cls_accs_1.update(cls_acc_1.item(), batch_size)\n        # model 2\n        losses_ce_2.update(loss_ce_2.item(), batch_size)\n        losses_triplet_2.update(loss_triplet_2.item(), batch_size)\n        cls_accs_2.update(cls_acc_2.item(), batch_size)\n\n        losses_ce_soft.update(loss_ce_soft.item(), batch_size)\n        losses_triplet_soft.update(loss_triplet_soft.item(), batch_size)\n        losses.update(loss.item(), batch_size)\n\n        # compute gradient and do optimization step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # update teacher\n        global_step = epoch * args.iters_per_epoch + i + 1\n        model_1_ema.set_alpha(min(args.alpha, 1 - 1 / global_step))\n        model_2_ema.set_alpha(min(args.alpha, 1 - 1 / global_step))\n        model_1_ema.update()\n        model_2_ema.update()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description=\"MMT for Domain Adaptive ReID\")\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', type=str, help='source domain')\n    parser.add_argument('-t', '--target', type=str, help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: reid_resnet50)')\n    parser.add_argument('--num-clusters', type=int, default=500)\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--alpha', type=float, default=0.999, help='ema alpha')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--rate', type=float, default=0.2)\n    # training parameters\n    parser.add_argument('--clustering-algorithm', type=str, default='dbscan', choices=['kmeans', 'dbscan'],\n                        help='clustering algorithm to run, currently supported methods: [\"kmeans\", \"dbscan\"]')\n    parser.add_argument('--resume', type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--pretrained-model-1-path', type=str, help='path to pretrained (source-only) model_1')\n    
parser.add_argument('--pretrained-model-2-path', type=str, help='path to pretrained (source-only) model_2')\n    parser.add_argument('--trade-off-ce-soft', type=float, default=0.5,\n                        help='the trade-off hyper-parameter between cross entropy loss and soft cross entropy loss')\n    parser.add_argument('--trade-off-triplet-soft', type=float, default=0.8,\n                        help='the trade-off hyper-parameter between triplet loss and soft triplet loss')\n    parser.add_argument('-j', '--workers', type=int, default=4)\n    parser.add_argument('-b', '--batch-size', type=int, default=64)\n    parser.add_argument('--height', type=int, default=256, help=\"input height\")\n    parser.add_argument('--width', type=int, default=128, help=\"input width\")\n    parser.add_argument('--num-instances', type=int, default=4,\n                        help=\"each minibatch consists of \"\n                             \"(batch_size // num_instances) identities, and \"\n                             \"each identity has num_instances instances, \"\n                             \"default: 4\")\n    parser.add_argument('--lr', type=float, default=0.00035,\n                        help=\"learning rate\")\n    parser.add_argument('--weight-decay', type=float, default=5e-4)\n    parser.add_argument('--epochs', type=int, default=40)\n    parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')\n    parser.add_argument('--eval-step', type=int, default=1)\n    parser.add_argument('--iters-per-epoch', type=int, default=400)\n    parser.add_argument('--print-freq', type=int, default=40)\n    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')\n    parser.add_argument('--rerank', action='store_true', help=\"whether to use re-ranking during evaluation\")\n    parser.add_argument(\"--log\", type=str, default='mmt',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/mmt.sh",
    "content": "#!/usr/bin/env bash\n# Market1501 -> Duke\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2DukeSeed0\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2DukeSeed1\n# step2: train mmt\nCUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--pretrained-model-1-path logs/baseline/Market2DukeSeed0/checkpoints/best.pth \\\n--pretrained-model-2-path logs/baseline/Market2DukeSeed1/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/mmt/Market2Duke\n\n# Duke -> Market1501\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MarketSeed0\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Duke2MarketSeed1\n# step2: train mmt\nCUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--pretrained-model-1-path logs/baseline/Duke2MarketSeed0/checkpoints/best.pth \\\n--pretrained-model-2-path logs/baseline/Duke2MarketSeed1/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/mmt/Duke2Market\n\n# Market1501 -> MSMT\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMTSeed0\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2MSMTSeed1\n# step2: train mmt\nCUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--pretrained-model-1-path logs/baseline/Market2MSMTSeed0/checkpoints/best.pth \\\n--pretrained-model-2-path logs/baseline/Market2MSMTSeed1/checkpoints/best.pth \\\n--num-clusters 1000 --finetune --seed 0 --log logs/mmt/Market2MSMT\n\n# MSMT -> Market1501\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2MarketSeed0\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/MSMT2MarketSeed1\n# step2: train mmt\nCUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--pretrained-model-1-path logs/baseline/MSMT2MarketSeed0/checkpoints/best.pth \\\n--pretrained-model-2-path logs/baseline/MSMT2MarketSeed1/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/mmt/MSMT2Market\n\n# Duke -> MSMT\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMTSeed0\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Duke2MSMTSeed1\n# step2: train 
mmt\nCUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--pretrained-model-1-path logs/baseline/Duke2MSMTSeed0/checkpoints/best.pth \\\n--pretrained-model-2-path logs/baseline/Duke2MSMTSeed1/checkpoints/best.pth \\\n--num-clusters 1000 --finetune --seed 0 --log logs/mmt/Duke2MSMT\n\n# MSMT -> Duke\n# step1: pretrain\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2DukeSeed0\nCUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/MSMT2DukeSeed1\n# step2: train mmt\nCUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--pretrained-model-1-path logs/baseline/MSMT2DukeSeed0/checkpoints/best.pth \\\n--pretrained-model-2-path logs/baseline/MSMT2DukeSeed1/checkpoints/best.pth \\\n--finetune --seed 0 --log logs/mmt/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/requirements.txt",
    "content": "timm\nopencv-python"
  },
  {
    "path": "examples/domain_adaptation/re_identification/spgan.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport itertools\nimport os.path as osp\nfrom PIL import Image\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as T\n\nsys.path.append('../../..')\nimport tllib.translation.cyclegan as cyclegan\nimport tllib.translation.spgan as spgan\nfrom tllib.translation.cyclegan.util import ImagePool, set_requires_grad\nimport tllib.vision.datasets.reid as datasets\nfrom tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset\nfrom tllib.vision.transforms import Denormalize\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = T.Compose([\n        T.Resize(args.load_size, Image.BICUBIC),\n        T.RandomCrop(args.input_size),\n        T.RandomHorizontalFlip(),\n        T.ToTensor(),\n        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n    ])\n\n    working_dir = osp.dirname(osp.abspath(__file__))\n    root = osp.join(working_dir, args.root)\n\n    source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower()))\n    train_source_loader = DataLoader(\n        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)\n\n    target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower()))\n    train_target_loader = DataLoader(\n        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # define networks (generators, discriminators and siamese network)\n    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n    siamese_net = spgan.SiameseNetwork(nsf=args.nsf).to(device)\n\n    # create image buffer to store previously generated images\n    fake_S_pool = ImagePool(args.pool_size)\n 
   fake_T_pool = ImagePool(args.pool_size)\n\n    # define optimizer and lr scheduler\n    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr,\n                       betas=(args.beta1, 0.999))\n    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))\n    optimizer_siamese = Adam(siamese_net.parameters(), lr=args.lr, betas=(args.beta1, 0.999))\n\n    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)\n    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)\n    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)\n    lr_scheduler_siamese = LambdaLR(optimizer_siamese, lr_lambda=lr_decay_function)\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        print(\"Resume from\", args.resume)\n        checkpoint = torch.load(args.resume, map_location='cpu')\n\n        netG_S2T.load_state_dict(checkpoint['netG_S2T'])\n        netG_T2S.load_state_dict(checkpoint['netG_T2S'])\n        netD_S.load_state_dict(checkpoint['netD_S'])\n        netD_T.load_state_dict(checkpoint['netD_T'])\n        siamese_net.load_state_dict(checkpoint['siamese_net'])\n\n        optimizer_G.load_state_dict(checkpoint['optimizer_G'])\n        optimizer_D.load_state_dict(checkpoint['optimizer_D'])\n        optimizer_siamese.load_state_dict(checkpoint['optimizer_siamese'])\n        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])\n        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])\n        lr_scheduler_siamese.load_state_dict(checkpoint['lr_scheduler_siamese'])\n\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    if args.phase == 'test':\n        transform = T.Compose([\n            T.Resize(args.test_input_size, Image.BICUBIC),\n            cyclegan.transform.Translation(netG_S2T, device)\n        ])\n        source_dataset.translate(transform, osp.join(args.translated_root, args.source.lower()))\n        return\n\n    # define loss function\n    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()\n    criterion_cycle = nn.L1Loss()\n    criterion_identity = nn.L1Loss()\n    criterion_contrastive = spgan.ContrastiveLoss(margin=args.margin)\n\n    # define visualization function\n    tensor_to_image = T.Compose([\n        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n        T.ToPILImage()\n    ])\n\n    def visualize(image, name):\n        \"\"\"\n        Args:\n            image (tensor): image in shape 3 x H x W\n            name: name of the image to save\n        \"\"\"\n        tensor_to_image(image).save(logger.get_image_path(\"{}.png\".format(name)))\n\n    # start training\n    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):\n        logger.set_epoch(epoch)\n        print(lr_scheduler_G.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, siamese_net,\n              criterion_gan, criterion_cycle, criterion_identity, criterion_contrastive,\n              optimizer_G, optimizer_D, optimizer_siamese,\n              fake_S_pool, fake_T_pool, epoch, visualize, args)\n\n        # update learning rates\n        lr_scheduler_G.step()\n        lr_scheduler_D.step()\n        lr_scheduler_siamese.step()\n\n        # save checkpoint\n        torch.save(\n            {\n                'netG_S2T': netG_S2T.state_dict(),\n         
       'netD_S': netD_S.state_dict(),\n                'netD_T': netD_T.state_dict(),\n                'siamese_net': siamese_net.state_dict(),\n                'optimizer_G': optimizer_G.state_dict(),\n                'optimizer_D': optimizer_D.state_dict(),\n                'optimizer_siamese': optimizer_siamese.state_dict(),\n                'lr_scheduler_G': lr_scheduler_G.state_dict(),\n                'lr_scheduler_D': lr_scheduler_D.state_dict(),\n                'lr_scheduler_siamese': lr_scheduler_siamese.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n\n    if args.translated_root is not None:\n        transform = T.Compose([\n            T.Resize(args.test_input_size, Image.BICUBIC),\n            cyclegan.transform.Translation(netG_S2T, device)\n        ])\n        source_dataset.translate(transform, osp.join(args.translated_root, args.source.lower()))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          netG_S2T, netG_T2S, netD_S, netD_T, siamese_net: spgan.SiameseNetwork,\n          criterion_gan: cyclegan.LeastSquaresGenerativeAdversarialLoss,\n          criterion_cycle: nn.L1Loss, criterion_identity: nn.L1Loss,\n          criterion_contrastive: spgan.ContrastiveLoss,\n          optimizer_G: Adam, optimizer_D: Adam, optimizer_siamese: Adam,\n          fake_S_pool: ImagePool, fake_T_pool: ImagePool, epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')\n    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')\n    losses_D_S = AverageMeter('D_S', ':3.2f')\n    losses_D_T = AverageMeter('D_T', ':3.2f')\n    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')\n    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')\n    losses_identity_S = AverageMeter('idt_S', ':3.2f')\n    losses_identity_T = AverageMeter('idt_T', ':3.2f')\n    losses_contrastive_G = AverageMeter('contrastive_G', ':3.2f')\n    losses_contrastive_siamese = AverageMeter('contrastive_siamese', ':3.2f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,\n         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T,\n         losses_contrastive_G, losses_contrastive_siamese],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        real_S, _, _, _ = next(train_source_iter)\n        real_T, _, _, _ = next(train_target_iter)\n\n        real_S = real_S.to(device)\n        real_T = real_T.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # Compute fake images and reconstruction images.\n        fake_T = netG_S2T(real_S)\n        rec_S = netG_T2S(fake_T)\n        fake_S = netG_T2S(real_T)\n        rec_T = netG_S2T(fake_S)\n\n        # ===============================================\n        # train the generators (every two iterations)\n        # ===============================================\n        if i % 2 == 0:\n            # save memory\n            set_requires_grad(netD_S, False)\n            set_requires_grad(netD_T, False)\n            set_requires_grad(siamese_net, False)\n            # GAN loss D_T(G_S2T(S))\n            loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)\n      
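      # real=True asks the least-squares GAN criterion to score the fakes as real, i.e. this term measures how well\n            # G_S2T currently fools the target-domain discriminator D_T\n      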
      # GAN loss D_S(G_T2S(T))\n            loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)\n            # Cycle loss || G_T2S(G_S2T(S)) - S||\n            loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle\n            # Cycle loss || G_S2T(G_T2S(T)) - T||\n            loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle\n            # Identity loss\n            # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||\n            identity_T = netG_S2T(real_T)\n            loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity\n            # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||\n            identity_S = netG_T2S(real_S)\n            loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity\n\n            # siamese network output\n            f_real_S = siamese_net(real_S)\n            f_fake_T = siamese_net(fake_T)\n            f_real_T = siamese_net(real_T)\n            f_fake_S = siamese_net(fake_S)\n\n            # positive pair\n            loss_contrastive_p_G = criterion_contrastive(f_real_S, f_fake_T, 0) + \\\n                                   criterion_contrastive(f_real_T, f_fake_S, 0)\n            # negative pair\n            loss_contrastive_n_G = criterion_contrastive(f_fake_T, f_real_T, 1) + \\\n                                   criterion_contrastive(f_fake_S, f_real_S, 1) + \\\n                                   criterion_contrastive(f_real_S, f_real_T, 1)\n            # contrastive loss\n            loss_contrastive_G = (loss_contrastive_p_G + 0.5 * loss_contrastive_n_G) / 4 * args.trade_off_contrastive\n\n            # combined loss and calculate gradients\n            loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T\n            if epoch > 1:\n                loss_G += loss_contrastive_G\n            netG_S2T.zero_grad()\n            netG_T2S.zero_grad()\n            loss_G.backward()\n            optimizer_G.step()\n\n            # update corresponding statistics\n            losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))\n            losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))\n            losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))\n            losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))\n            losses_identity_S.update(loss_identity_S.item(), real_S.size(0))\n            losses_identity_T.update(loss_identity_T.item(), real_S.size(0))\n            if epoch > 1:\n                losses_contrastive_G.update(loss_contrastive_G.item(), real_S.size(0))\n\n        # ===============================================\n        # train the siamese network (when epoch > 0)\n        # ===============================================\n        if epoch > 0:\n            set_requires_grad(siamese_net, True)\n            # siamese network output\n            f_real_S = siamese_net(real_S)\n            f_fake_T = siamese_net(fake_T.detach())\n            f_real_T = siamese_net(real_T)\n            f_fake_S = siamese_net(fake_S.detach())\n\n            # positive pair\n            loss_contrastive_p_siamese = criterion_contrastive(f_real_S, f_fake_T, 0) + \\\n                                         criterion_contrastive(f_real_T, f_fake_S, 0)\n            # negative pair\n            loss_contrastive_n_siamese = criterion_contrastive(f_real_S, f_real_T, 1)\n            # contrastive loss\n            loss_contrastive_siamese = 
(loss_contrastive_p_siamese + 2 * loss_contrastive_n_siamese) / 3\n\n            # update siamese network\n            siamese_net.zero_grad()\n            loss_contrastive_siamese.backward()\n            optimizer_siamese.step()\n\n            # update corresponding statistics\n            losses_contrastive_siamese.update(loss_contrastive_siamese.item(), real_S.size(0))\n\n        # ===============================================\n        # train the discriminators\n        # ===============================================\n\n        set_requires_grad(netD_S, True)\n        set_requires_grad(netD_T, True)\n        # Calculate GAN loss for discriminator D_S\n        fake_S_ = fake_S_pool.query(fake_S.detach())\n        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))\n        # Calculate GAN loss for discriminator D_T\n        fake_T_ = fake_T_pool.query(fake_T.detach())\n        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))\n\n        # update discriminators\n        netD_S.zero_grad()\n        netD_T.zero_grad()\n        loss_D_S.backward()\n        loss_D_T.backward()\n        optimizer_D.step()\n\n        # update corresponding statistics\n        losses_D_S.update(loss_D_S.item(), real_S.size(0))\n        losses_D_T.update(loss_D_T.item(), real_S.size(0))\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T],\n                                    [\"real_S\", \"real_T\", \"fake_S\", \"fake_T\", \"rec_S\",\n                                     \"rec_T\", \"identity_S\", \"identity_T\"]):\n                visualize(tensor[0], \"{}_{}\".format(i, name))\n\n\nif __name__ == '__main__':\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description='SPGAN for Domain Adaptive ReID')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-s', '--source', type=str, help='source domain')\n    parser.add_argument('-t', '--target', type=str, help='target domain')\n    parser.add_argument('--load-size', nargs='+', type=int, default=(286, 144), help='loading image size')\n    parser.add_argument('--input-size', nargs='+', type=int, default=(256, 128),\n                        help='the input and output image size during training')\n    # model parameters\n    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n    parser.add_argument('--nsf', type=int, default=64, help='# of siamese net filters in the first conv layer')\n    parser.add_argument('--netD', type=str, default='patch',\n                        help='specify discriminator architecture [patch | pixel]. 
The basic model is a 70x70 PatchGAN.')\n    parser.add_argument('--netG', type=str, default='resnet_9',\n                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')\n    parser.add_argument('--norm', type=str, default='instance',\n                        help='instance normalization or batch normalization [instance | batch | none]')\n    # training parameters\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')\n    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')\n    parser.add_argument('--trade-off-contrastive', type=float, default=2.0, help='trade off for contrastive loss')\n    parser.add_argument('--margin', type=float, default=2,\n                        help='margin for contrastive loss')\n    parser.add_argument('-b', '--batch-size', default=8, type=int,\n                        metavar='N', help='mini-batch size (default: 8)')\n    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=15, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--epochs-decay', type=int, default=15,\n                        help='number of epochs to linearly decay learning rate to zero')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('-i', '--iters-per-epoch', default=2000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('--pool-size', type=int, default=50,\n                        help='the size of image buffer that stores previously generated images')\n    parser.add_argument('-p', '--print-freq', default=500, type=int,\n                        metavar='N', help='print frequency (default: 500)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='spgan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    # test parameters\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--translated-root', type=str, default=None,\n                        help=\"Root directory where the translated dataset will be saved\")\n    parser.add_argument('--test-input-size', nargs='+', type=int, default=(256, 128),\n                        help='the input image size during testing')\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/spgan.sh",
    "content": "# Market1501 -> Duke\n# step1: train SPGAN\nCUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t DukeMTMC \\\n--log logs/spgan/Market2Duke --translated-root data/spganM2D --seed 0\n# step2: train baseline on translated source dataset\nCUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2D data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2Duke\n\n# Duke -> Market1501\n# step1: train SPGAN\nCUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t Market1501 \\\n--log logs/spgan/Duke2Market --translated-root data/spganD2M --seed 0\n# step2: train baseline on translated source dataset\nCUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2M data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2Market\n\n# Market1501 -> MSMT17\n# step1: train SPGAN\nCUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t MSMT17 \\\n--log logs/spgan/Market2MSMT --translated-root data/spganM2S --seed 0\n# step2: train baseline on translated source dataset\nCUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2S data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2MSMT\n\n# MSMT -> Market1501\n# step1: train SPGAN\nCUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t Market1501 \\\n--log logs/spgan/MSMT2Market --translated-root data/spganS2M --seed 0\n# step2: train baseline on translated source dataset\nCUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2M data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Market\n\n# Duke -> MSMT\n# step1: train SPGAN\nCUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t MSMT17 \\\n--log logs/spgan/Duke2MSMT --translated-root data/spganD2S --seed 0\n# step2: train baseline on translated source dataset\nCUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2S data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2MSMT\n\n# MSMT -> Duke\n# step1: train SPGAN\nCUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t DukeMTMC \\\n--log logs/spgan/MSMT2Duke --translated-root data/spganS2D --seed 0\n# step2: train baseline on translated source dataset\nCUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2D data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_adaptation/re_identification/utils.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport sys\nimport timm\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import Parameter\nimport torchvision.transforms as T\n\nsys.path.append('../../..')\nfrom tllib.utils.metric.reid import extract_reid_feature\nfrom tllib.utils.analysis import tsne\nfrom tllib.vision.transforms import RandomErasing\nimport tllib.vision.models.reid as models\nimport tllib.normalization.ibn as ibn_models\n\n\ndef copy_state_dict(model, state_dict, strip=None):\n    \"\"\"Copy state dict into the passed in ReID model. As we are using classification loss, which means we need to output\n    different number of classes(identities) for different datasets, we will not copy the parameters of last `fc` layer.\n    \"\"\"\n    tgt_state = model.state_dict()\n    copied_names = set()\n    for name, param in state_dict.items():\n        if strip is not None and name.startswith(strip):\n            name = name[len(strip):]\n        if name not in tgt_state:\n            continue\n        if isinstance(param, Parameter):\n            param = param.data\n        if param.size() != tgt_state[name].size():\n            print('mismatch:', name, param.size(), tgt_state[name].size())\n            continue\n        tgt_state[name].copy_(param)\n        copied_names.add(name)\n\n    missing = set(tgt_state.keys()) - copied_names\n    if len(missing) > 0:\n        print(\"missing keys in state_dict:\", missing)\n\n    return model\n\n\ndef get_model_names():\n    return sorted(name for name in models.__dict__ if\n                  name.islower() and not name.startswith(\"__\") and callable(models.__dict__[name])) + \\\n           sorted(name for name in ibn_models.__dict__ if\n                  name.islower() and not name.startswith(\"__\") and callable(ibn_models.__dict__[name])) + \\\n           timm.list_models()\n\n\ndef get_model(model_name):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=True)\n    elif model_name in ibn_models.__dict__:\n        # load models (with ibn) from tllib.normalization.ibn\n        backbone = ibn_models.__dict__[model_name](pretrained=True)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=True)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n    return backbone\n\n\ndef get_train_transform(height, width, resizing='default', random_horizontal_flip=True, random_color_jitter=False,\n                        random_gray_scale=False, random_erasing=False):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to (height, width), zero-pad it by 10 on each size, the take a random crop of\n            (height, width)\n        - res: resize the image to(height, width)\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            T.Resize((height, width), interpolation=3),\n            T.Pad(10),\n            T.RandomCrop((height, width))\n        ])\n    elif resizing == 'res':\n        transform = T.Resize((height, width), interpolation=3)\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        
transforms.append(T.RandomHorizontalFlip())\n    if random_color_jitter:\n        transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))\n    if random_gray_scale:\n        transforms.append(T.RandomGrayscale())\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n    if random_erasing:\n        transforms.append(RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]))\n    return T.Compose(transforms)\n\n\ndef get_val_transform(height, width):\n    return T.Compose([\n        T.Resize((height, width), interpolation=3),\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n\n\ndef visualize_tsne(source_loader, target_loader, model, filename, device, n_data_points_per_domain=3000):\n    \"\"\"Visualize features from different domains using t-SNE. As we can have very large number of samples in each\n    domain, only `n_data_points_per_domain` number of samples are randomly selected in each domain.\n    \"\"\"\n    source_feature_dict = extract_reid_feature(source_loader, model, device, normalize=True)\n    source_feature = torch.stack(list(source_feature_dict.values())).cpu()\n    source_feature = source_feature[torch.randperm(len(source_feature))]\n    source_feature = source_feature[:n_data_points_per_domain]\n\n    target_feature_dict = extract_reid_feature(target_loader, model, device, normalize=True)\n    target_feature = torch.stack(list(target_feature_dict.values())).cpu()\n    target_feature = target_feature[torch.randperm(len(target_feature))]\n    target_feature = target_feature[:n_data_points_per_domain]\n\n    tsne.visualize(source_feature, target_feature, filename, source_color='cornflowerblue', target_color='darkorange')\n    print('T-SNE process is done, figure is saved to {}'.format(filename))\n\n\ndef k_reciprocal_neigh(initial_rank, i, k1):\n    \"\"\"Compute k-reciprocal neighbors of i-th sample. 
\"\"\"\n    forward_k_neigh_index = initial_rank[i, :k1 + 1]\n    backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]\n    fi = torch.nonzero(backward_k_neigh_index == i)[:, 0]\n    return forward_k_neigh_index[fi]\n\n\ndef compute_rerank_dist(target_features, k1=30, k2=6):\n    \"\"\"Compute distance according to `Re-ranking Person Re-identification with k-reciprocal Encoding\n    (CVPR 2017) <https://arxiv.org/pdf/1701.08398.pdf>`_.\n    \"\"\"\n    n = target_features.size(0)\n    original_dist = torch.pow(target_features, 2).sum(dim=1, keepdim=True) * 2\n    original_dist = original_dist.expand(n, n) - 2 * torch.mm(target_features, target_features.t())\n    original_dist /= original_dist.max(0)[0]\n    original_dist = original_dist.t()\n    initial_rank = torch.argsort(original_dist, dim=-1)\n    all_num = gallery_num = original_dist.size(0)\n\n    del target_features\n\n    nn_k1 = []\n    nn_k1_half = []\n    for i in range(all_num):\n        nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))\n        nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2))))\n\n    V = torch.zeros(all_num, all_num)\n    for i in range(all_num):\n        k_reciprocal_index = nn_k1[i]\n        k_reciprocal_expansion_index = k_reciprocal_index\n        for candidate in k_reciprocal_index:\n            candidate_k_reciprocal_index = nn_k1_half[candidate]\n            if (len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(\n                    candidate_k_reciprocal_index)):\n                k_reciprocal_expansion_index = torch.cat((k_reciprocal_expansion_index, candidate_k_reciprocal_index))\n\n        k_reciprocal_expansion_index = torch.unique(k_reciprocal_expansion_index)\n        weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index])\n        V[i, k_reciprocal_expansion_index] = weight / torch.sum(weight)\n\n    if k2 != 1:\n        k2_rank = initial_rank[:, :k2].clone().view(-1)\n        V_qe = V[k2_rank]\n        V_qe = V_qe.view(initial_rank.size(0), k2, -1).sum(1)\n        V_qe /= k2\n        V = V_qe\n        del V_qe\n    del initial_rank\n\n    invIndex = []\n    for i in range(gallery_num):\n        invIndex.append(torch.nonzero(V[:, i])[:, 0])\n\n    jaccard_dist = torch.zeros_like(original_dist)\n    for i in range(all_num):\n        temp_min = torch.zeros(1, gallery_num)\n        indNonZero = torch.nonzero(V[i, :])[:, 0]\n        indImages = [invIndex[ind] for ind in indNonZero]\n        for j in range(len(indNonZero)):\n            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + \\\n                                        torch.min(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])\n        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)\n    del invIndex\n    del V\n\n    pos_bool = (jaccard_dist < 0)\n    jaccard_dist[pos_bool] = 0.0\n    return jaccard_dist\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/README.md",
    "content": "# Unsupervised Domain Adaptation for Semantic Segmentation\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\n## Dataset\n\nYou need to prepare following datasets manually if you want to use them:\n- [Cityscapes](https://www.cityscapes-dataset.com/)\n- [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/)\n- [Synthia](https://synthia-dataset.net/)\n\n#### Cityscapes, Foggy Cityscapes\n  - Download Cityscapes and Foggy Cityscapes dataset from the [link](https://www.cityscapes-dataset.com/downloads/). Particularly, we use *leftImg8bit_trainvaltest.zip* for Cityscapes and *leftImg8bit_trainvaltest_foggy.zip* for Foggy Cityscapes.\n  - Unzip them under the directory like\n  \n ```\ndata/Cityscapes\n├── gtFine\n├── leftImg8bit\n│   ├── train\n│   ├── val\n│   └── test\n├── leftImg8bit_foggy\n│   ├── train\n│   ├── val\n│   └── test\n└── ...\n```\n\n#### GTA-5\nYou need to download GTA5 manually from [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/).\nEnsure that there exist following directories before you use this dataset.\n ```\ndata/GTA5\n├── images\n├── labels\n└── ...\n```\n\n#### Synthia\nYou need to download Synthia manually from [Synthia](https://synthia-dataset.net/).\nEnsure that there exist following directories before you use this dataset.\n ```\ndata/synthia\n├── RGB\n├── synthia_mapped_to_cityscapes\n└── ...\n```\n\n\n## Supported Methods\n\nSupported methods include:\n\n- [Cycle-Consistent Adversarial Networks (CycleGAN)](https://arxiv.org/pdf/1703.10593.pdf)\n- [CyCADA: Cycle-Consistent Adversarial Domain Adaptation](https://arxiv.org/abs/1711.03213)\n- [Adversarial Entropy Minimization (ADVENT)](https://arxiv.org/abs/1811.12833)\n- [Fourier Domain Adaptation (FDA)](https://arxiv.org/abs/2004.05498)\n\n## Experiment and Results\n\n**Notations**\n- ``Origin`` means the accuracy reported by the original paper.\n- ``mIoU`` is the accuracy reported by `TLlib`.\n- ``ERM`` refers to the model trained with data from the source domain.\n- ``Oracle`` refers to the model trained with data from the target domain.\n\n\n### GTA5->Cityscapes mIoU on deeplabv2 (ResNet-101)\n\n| GTA5        | Origin | mIoU | road | sidewalk | building | wall | fence | pole | traffic light | traffic sign | vegetation | terrian | sky  | person | rider | car  | truck | bus  | train | motorbike | bicycle |\n|-------------|--------|------|------|----------|----------|------|-------|------|---------------|--------------|------------|---------|------|--------|-------|------|-------|------|-------|-----------|---------|\n| ERM | 27.1   | 37.3 | 66.5 | 17.4     | 73.3     | 13.4 | 21.5  | 22.8 | 30.1          | 17.1         | 82.2       | 7.1     | 73.6 | 57.4   | 28.4  | 78.6 | 36.1  | 13.4 | 1.5   | 31.9      | 36.2    |\n| AdvEnt      | 43.8   | 43.8 | 89.3 | 33.9     | 80.3     | 24.0 | 25.2  | 27.8 | 36.7          | 18.2         | 84.3       | 33.9    | 81.3 | 59.8   | 28.4  | 84.3 | 34.1  | 44.4 | 0.1   | 33.2      | 12.9    |\n| FDA         | 44.6   | 45.6 | 85.5 | 31.7     | 81.8     | 27.1 | 24.9  | 28.9 | 38.1          | 23.2         | 83.7       | 40.3    | 80.6 | 60.5   | 30.3  | 79.1 | 32.8  | 45.1 | 5.0   | 32.4      | 35.2    |\n| Cycada      | 42.7   | 47.4 | 87.3 | 35.7     | 83.7     | 31.3 | 24.0  | 32.2 | 35.8          | 30.3         | 82.7       | 32.0    | 85.7 | 60.8   | 31.5  | 85.6 | 39.8  | 43.3 | 5.4   | 29.5      | 44.6    |\n| CycleGAN    |        | 47.0 | 88.4 | 41.9     | 83.6     | 34.4 
| 23.9  | 32.9 | 35.5          | 26.0         | 83.1       | 36.8    | 82.3 | 59.9   | 27.0  | 83.4 | 31.6  | 42.3 | 11.0  | 28.2      | 40.5    |\n| Oracle      | 65.1   | 70.5 | 97.4 | 79.7     | 90.1     | 53.0 | 50.0  | 48.0 | 55.5          | 67.2         | 90.2       | 60.0    | 93.0 | 72.7   | 55.2  | 92.7 | 76.5  | 78.5 | 56.0  | 54.6      | 68.8    |\n\n### Synthia->Cityscapes mIoU on deeplabv2 (ResNet-101)\n\n| Synthia     | Origin | mIoU | road | sidewalk | building | traffic light | traffic sign | vegetation | sky  | person | rider | car  | bus  | motorbike | bicycle |\n|-------------|--------|------|------|----------|----------|---------------|--------------|------------|------|--------|-------|------|------|-----------|---------|\n| ERM | 22.1   | 41.5 | 59.6 | 21.1     | 77.4     | 7.7           | 17.6         | 78.0       | 84.5 | 53.2   | 16.9  | 65.9 | 24.9 | 8.5       | 24.8    |\n| AdvEnt      | 47.6   | 47.9 | 88.3 | 44.9     | 80.5     | 4.5           | 9.1          | 81.3       | 86.2 | 52.9   | 21.0  | 82.0 | 30.3 | 11.9      | 30.2    |\n| FDA         | -      | 43.9 | 62.5 | 23.7     | 78.5     | 9.4           | 15.7         | 78.3       | 81.1 | 52.3   | 18.7  | 79.8 | 32.5 | 8.7       | 29.6    |\n| Oracle      | 71.7   | 76.6 | 97.4 | 79.7     | 90.1     | 55.5          | 67.2         | 90.2       | 93.0 | 72.7   | 55.2  | 92.7 | 78.5 | 54.6      | 68.8    |\n\n### Cityscapes->Foggy Cityscapes mIoU on deeplabv2 (ResNet-101)\n\n| Foggy       | Origin | mIoU | road | sidewalk | building | wall | fence | pole | traffic light | traffic sign | vegetation | terrain | sky  | person | rider | car  | truck | bus  | train | motorbike | bicycle |\n|-------------|--------|------|------|----------|----------|------|-------|------|---------------|--------------|------------|---------|------|--------|-------|------|-------|------|-------|-----------|---------|\n| ERM |        | 51.2 | 95.3 | 70.2     | 64.1     | 31.9 | 35.2  | 30.7 | 33.3          | 51.1         | 42.3       | 44.0    | 32.1 | 64.4   | 47.0  | 86.0 | 64.4  | 56.4 | 21.1  | 43.1      | 60.8    |\n| AdvEnt      |        | 61.8 | 96.8 | 75.1     | 76.4     | 46.2 | 42.6  | 39.3 | 43.6          | 58.9         | 74.3       | 50.1    | 75.9 | 67.3   | 51.0  | 89.4 | 70.5  | 64.7 | 39.9  | 47.9      | 65.0    |\n| FDA         |        | 61.9 | 96.9 | 77.2     | 75.3     | 46.5 | 42.0  | 39.8 | 47.1          | 61.0         | 72.7       | 54.6    | 63.8 | 68.4   | 50.1  | 90.1 | 72.8  | 68.0 | 35.5  | 50.8      | 64.2    |\n| Cycada      |        | 63.3 | 96.8 | 75.5     | 79.1     | 38.0 | 40.3  | 42.1 | 48.2          | 61.2         | 76.9       | 52.1    | 77.6 | 68.6   | 51.7  | 90.4 | 71.7  | 70.4 | 43.3  | 52.6      | 65.7    |\n| CycleGAN    |        | 66.0 | 97.1 | 77.6     | 84.3     | 42.7 | 46.3  | 42.8 | 47.5          | 61.0         | 84.0       | 55.2    | 83.4 | 69.4   | 51.8  | 90.7 | 73.7  | 76.2 | 54.2  | 50.7      | 65.6    |\n| Oracle      |        | 66.9 | 97.4 | 78.6     | 88.1     | 50.7 | 50.5  | 46.2 | 51.3          | 64.4         | 88.1       | 55.3    | 87.4 | 70.9   | 52.7  | 91.6 | 72.4  | 73.2 | 31.8  | 52.2      | 67.4    |\n\n## Visualization\nIf you want to visualize the segmentation results during training, you should set ``--debug``.\n\n```\nCUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes --log logs/src_only/gtav2cityscapes --debug\n```\n\nThen you can find images, predictions and labels in the directory 
``logs/src_only/gtav2cityscapes/visualize/``.\n\n<img src=\"./fig/segmentation_image.png\" width=\"300\"/>\n<img src=\"./fig/segmentation_pred.png\" width=\"300\"/>\n<img src=\"./fig/segmentation_label.png\" width=\"300\"/>\n\n\nTranslation models such as CycleGAN save images by default. Here are a source-style image and its translated version.\n\n<img src=\"./fig/cyclegan_real_S.png\" width=\"300\"/>\n<img src=\"./fig/cyclegan_fake_T.png\" width=\"300\"/>\n\n\n## TODO\nSupport more methods: AdaptSeg\n\n## Citation\nIf you use these methods in your research, please consider citing.\n\n```\n@inproceedings{CycleGAN,\n    title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},\n    author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},\n    booktitle={ICCV},\n    year={2017}\n}\n\n@inproceedings{cycada,\n    title={Cycada: Cycle-consistent adversarial domain adaptation},\n    author={Hoffman, Judy and Tzeng, Eric and Park, Taesung and Zhu, Jun-Yan and Isola, Phillip and Saenko, Kate and Efros, Alexei and Darrell, Trevor},\n    booktitle={ICML},\n    year={2018},\n}\n\n@inproceedings{Advent,\n    author = {Vu, Tuan-Hung and Jain, Himalaya and Bucher, Maxime and Cord, Matthieu and Perez, Patrick},\n    title = {ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation},\n    booktitle = {CVPR},\n    year = {2019}\n}\n\n@inproceedings{FDA,\n    author    = {Yanchao Yang and\n               Stefano Soatto},\n    title     = {{FDA:} Fourier Domain Adaptation for Semantic Segmentation},\n    booktitle = {CVPR},\n    year = {2020}\n}\n```\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/advent.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nfrom PIL import Image\nimport numpy as np\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD, Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nsys.path.append('../../..')\nfrom tllib.alignment.advent import Discriminator, DomainAdversarialEntropyLoss\nimport tllib.vision.models.segmentation as models\nimport tllib.vision.datasets.segmentation as datasets\nimport tllib.vision.transforms.segmentation as T\nfrom tllib.vision.transforms import DeNormalizeAndTranspose\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter, Meter\nfrom tllib.utils.logger import CompleteLogger\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(\n        root=args.source_root,\n        transforms=T.Compose([\n            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),\n            T.ColorJitter(brightness=0.3, contrast=0.3),\n            T.RandomHorizontalFlip(),\n            T.NormalizeAndTranspose(),\n        ]),\n    )\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    target_dataset = datasets.__dict__[args.target]\n    train_target_dataset = target_dataset(\n        root=args.target_root,\n        transforms=T.Compose([\n            T.RandomResizedCrop(size=args.train_size, ratio=(2., 2.), scale=(0.5, 1.)),\n            T.RandomHorizontalFlip(),\n            T.NormalizeAndTranspose(),\n        ]),\n    )\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_target_dataset = target_dataset(\n        root=args.target_root, split='val',\n        transforms=T.Compose([\n            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),\n            T.NormalizeAndTranspose(),\n        ]),\n    )\n    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    num_classes = train_source_dataset.num_classes\n    model = models.__dict__[args.arch](num_classes=num_classes).to(device)\n    discriminator = 
Discriminator(num_classes=num_classes).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)\n    optimizer_d = Adam(discriminator.parameters(), lr=args.lr_d, betas=(0.9, 0.99))\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))\n    lr_scheduler_d = LambdaLR(optimizer_d, lambda x: (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n        discriminator.load_state_dict(checkpoint['discriminator'])\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n        optimizer_d.load_state_dict(checkpoint['optimizer_d'])\n        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    # define loss function (criterion)\n    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)\n    dann = DomainAdversarialEntropyLoss(discriminator)\n    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)\n    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)\n\n    # define visualization function\n    decode = train_source_dataset.decode_target\n\n    def visualize(image, pred, label, prefix):\n        \"\"\"\n        Args:\n            image (tensor): 3 x H x W\n            pred (tensor): C x H x W\n            label (tensor): H x W\n            prefix: prefix of the saved image\n        \"\"\"\n        image = image.detach().cpu().numpy()\n        pred = pred.detach().max(dim=0)[1].cpu().numpy()\n        label = label.cpu().numpy()\n        for tensor, name in [\n            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), \"image\"),\n            (decode(label), \"label\"),\n            (decode(pred), \"pred\")\n        ]:\n            tensor.save(logger.get_image_path(\"{}_{}.png\".format(prefix, name)))\n\n    if args.phase == 'test':\n        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)\n        print(confmat)\n        return\n\n    # start training\n    best_iou = 0.\n    for epoch in range(args.start_epoch, args.epochs):\n        logger.set_epoch(epoch)\n        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())\n        # train for one epoch\n        train(train_source_iter, train_target_iter, model, interp_train, criterion, dann, optimizer,\n              lr_scheduler, optimizer_d, lr_scheduler_d, epoch, visualize if args.debug else None, args)\n\n        # evaluate on validation set\n        confmat = validate(val_target_loader, model, interp_val, criterion, None, args)\n        print(confmat.format(train_source_dataset.classes))\n        acc_global, acc, iu = confmat.compute()\n\n        # calculate the mean iou over partial classes\n        indexes = [train_source_dataset.classes.index(name) for name\n                   in train_source_dataset.evaluate_classes]\n        iu = iu[indexes]\n        mean_iou = iu.mean()\n\n        # remember the best mean IoU and save checkpoint\n        torch.save(\n            {\n                
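# save every state needed to resume training from this epoch\n                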
'model': model.state_dict(),\n                'discriminator': discriminator.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'optimizer_d': optimizer_d.state_dict(),\n                'lr_scheduler': lr_scheduler.state_dict(),\n                'lr_scheduler_d': lr_scheduler_d.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n        if mean_iou > best_iou:\n            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))\n        best_iou = max(best_iou, mean_iou)\n        print(\"Target: {} Best: {}\".format(mean_iou, best_iou))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model, interp, criterion, dann,\n          optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: Adam, lr_scheduler_d: LambdaLR,\n          epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', ':3.2f')\n    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')\n    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')\n    accuracies_s = Meter('Acc (s)', ':3.2f')\n    accuracies_t = Meter('Acc (t)', ':3.2f')\n    iou_s = Meter('IoU (s)', ':3.2f')\n    iou_t = Meter('IoU (t)', ':3.2f')\n\n    confmat_s = ConfusionMatrix(model.num_classes)\n    confmat_t = ConfusionMatrix(model.num_classes)\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s, losses_transfer, losses_discriminator,\n         accuracies_s, accuracies_t, iou_s, iou_t],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        x_s, label_s = next(train_source_iter)\n        x_t, label_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        label_s = label_s.long().to(device)\n        x_t = x_t.to(device)\n        label_t = label_t.long().to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        optimizer.zero_grad()\n        optimizer_d.zero_grad()\n\n        # Step 1: Train the segmentation network, freeze the discriminator\n        dann.eval()\n        y_s = model(x_s)\n        pred_s = interp(y_s)\n        loss_cls_s = criterion(pred_s, label_s)\n        loss_cls_s.backward()\n\n        # adversarial training to fool the discriminator\n        y_t = model(x_t)\n        pred_t = interp(y_t)\n        loss_transfer = dann(pred_t, 'source')\n        (loss_transfer * args.trade_off).backward()\n\n        # Step 2: Train the discriminator\n        dann.train()\n        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') + dann(pred_t.detach(), 'target'))\n        loss_discriminator.backward()\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        optimizer_d.step()\n        lr_scheduler.step()\n        lr_scheduler_d.step()\n\n        # measure accuracy and record loss\n        losses_s.update(loss_cls_s.item(), x_s.size(0))\n        losses_transfer.update(loss_transfer.item(), x_s.size(0))\n        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))\n\n        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())\n        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())\n        acc_global_s, acc_s, iu_s = confmat_s.compute()\n        
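# both confusion matrices accumulate over the whole epoch, so these are running metrics\n        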
acc_global_t, acc_t, iu_t = confmat_t.compute()\n        accuracies_s.update(acc_s.mean().item())\n        accuracies_t.update(acc_t.mean().item())\n        iou_s.update(iu_s.mean().item())\n        iou_t.update(iu_t.mean().item())\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n            if visualize is not None:\n                visualize(x_s[0], pred_s[0], label_s[0], \"source_{}\".format(i))\n                visualize(x_t[0], pred_t[0], label_t[0], \"target_{}\".format(i))\n\n\ndef validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    acc = Meter('Acc', ':3.2f')\n    iou = Meter('IoU', ':3.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, acc, iou],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    confmat = ConfusionMatrix(model.num_classes)\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (x, label) in enumerate(val_loader):\n            x = x.to(device)\n            label = label.long().to(device)\n\n            # compute output\n            output = interp(model(x))\n            loss = criterion(output, label)\n\n            # measure accuracy and record loss\n            losses.update(loss.item(), x.size(0))\n            confmat.update(label.flatten(), output.argmax(1).flatten())\n            acc_global, accs, iu = confmat.compute()\n            acc.update(accs.mean().item())\n            iou.update(iu.mean().item())\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n                if visualize is not None:\n                    visualize(x[0], output[0], label[0], \"val_{}\".format(i))\n\n    return confmat\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='ADVENT for Segmentation Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),\n                        help='the resize ratio for the random resize crop')\n    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input and output image size during training')\n    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input image size during test')\n    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),\n                        help='the output image size during test')\n    # model parameters\n    
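# architecture choices are discovered from tllib.vision.models.segmentation\n    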
parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: deeplabv2_resnet101)')\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--trade-off', type=float, default=0.001,\n                        help='trade-off parameter for the advent loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=2, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 2)')\n    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument(\"--momentum\", type=float, default=0.9, help=\"Momentum component of the optimiser.\")\n    parser.add_argument(\"--weight-decay\", type=float, default=0.0005, help=\"Regularisation parameter for L2-loss.\")\n    parser.add_argument(\"--lr-power\", type=float, default=0.9,\n                        help=\"Decay parameter to compute the learning rate (only for deeplab).\")\n    parser.add_argument(\"--lr-d\", default=1e-4, type=float,\n                        metavar='LR', help='initial learning rate for discriminator')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--ignore-label\", type=int, default=255,\n                        help=\"The index of the label to ignore during the training.\")\n    parser.add_argument(\"--log\", type=str, default='advent',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--debug', action=\"store_true\",\n                        help='In debug mode, save images and predictions during training')\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/advent.sh",
    "content": "# GTA5 to Cityscapes\nCUDA_VISIBLE_DEVICES=0 python advent.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \\\n    --log logs/advent/gtav2cityscapes\n\n# Synthia to Cityscapes\nCUDA_VISIBLE_DEVICES=0 python advent.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \\\n    --log logs/advent/synthia2cityscapes\n\n# Cityscapes to Foggy\nCUDA_VISIBLE_DEVICES=0 python advent.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \\\n    --log logs/advent/cityscapes2foggy"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/cycada.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport itertools\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nfrom torchvision.transforms import ToPILImage, Compose, Lambda\n\nsys.path.append('../../..')\nimport tllib.translation.cyclegan as cyclegan\nfrom tllib.translation.cyclegan.util import ImagePool, set_requires_grad\nfrom tllib.translation.cycada import SemanticConsistency\nimport tllib.vision.models.segmentation as models\nimport tllib.vision.datasets.segmentation as datasets\nfrom tllib.vision.transforms import Denormalize, NormalizeAndTranspose\nimport tllib.vision.transforms.segmentation as T\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = T.Compose([\n        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),\n        T.RandomHorizontalFlip(),\n        T.ToTensor(),\n        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n    ])\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    target_dataset = datasets.__dict__[args.target]\n    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # define networks (both generators and discriminators)\n    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n\n    # create image buffer to store previously generated images\n    fake_S_pool = ImagePool(args.pool_size)\n    fake_T_pool = ImagePool(args.pool_size)\n\n    # define optimizer and lr scheduler\n    optimizer_G = 
Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr, betas=(args.beta1, 0.999))\n    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))\n    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)\n    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)\n    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        print(\"Resume from\", args.resume)\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        netG_S2T.load_state_dict(checkpoint['netG_S2T'])\n        netG_T2S.load_state_dict(checkpoint['netG_T2S'])\n        netD_S.load_state_dict(checkpoint['netD_S'])\n        netD_T.load_state_dict(checkpoint['netD_T'])\n        optimizer_G.load_state_dict(checkpoint['optimizer_G'])\n        optimizer_D.load_state_dict(checkpoint['optimizer_D'])\n        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])\n        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    if args.phase == 'test':\n        transform = T.Compose([\n            T.Resize(image_size=args.test_input_size),\n            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),\n        ])\n        train_source_dataset.translate(transform, args.translated_root)\n        return\n\n    # define loss function\n    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()\n    criterion_cycle = nn.L1Loss()\n    criterion_identity = nn.L1Loss()\n    criterion_semantic = SemanticConsistency(ignore_index=[args.ignore_label]+train_source_dataset.ignore_classes).to(device)\n    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)\n\n    # define segmentation model and predict function\n    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)\n    if args.pretrain:\n        print(\"Loading pretrained segmentation model from\", args.pretrain)\n        checkpoint = torch.load(args.pretrain, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n    model.eval()\n\n    cycle_gan_tensor_to_segmentation_tensor = Compose([\n        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n        Lambda(lambda image: image.mul(255).permute((1, 2, 0))),\n        NormalizeAndTranspose(),\n    ])\n\n    def predict(image):\n        image = cycle_gan_tensor_to_segmentation_tensor(image.squeeze())\n        image = image.unsqueeze(dim=0).to(device)\n        prediction = model(image)\n        return interp_train(prediction)\n\n    # define visualization function\n    tensor_to_image = Compose([\n        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n        ToPILImage()\n    ])\n    decode = train_source_dataset.decode_target\n\n    def visualize(image, name, pred=None):\n        \"\"\"\n        Args:\n            image (tensor): image in shape 3 x H x W\n            name: name of the saved image\n            pred (tensor): predictions in shape C x H x W\n        \"\"\"\n        tensor_to_image(image).save(logger.get_image_path(\"{}.png\".format(name)))\n        if pred is not None:\n            pred = pred.detach().max(dim=0).indices.cpu().numpy()\n            pred = decode(pred)\n            pred.save(logger.get_image_path(\"pred_{}.png\".format(name)))\n\n    # start training\n    
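# run args.epochs epochs at the initial learning rate, then args.epochs_decay epochs of linear lr decay to zero\n    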
for epoch in range(args.start_epoch, args.epochs+args.epochs_decay):\n        logger.set_epoch(epoch)\n        print(lr_scheduler_G.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, predict,\n              criterion_gan, criterion_cycle, criterion_identity, criterion_semantic, optimizer_G, optimizer_D,\n              fake_S_pool, fake_T_pool, epoch, visualize, args)\n\n        # update learning rates\n        lr_scheduler_G.step()\n        lr_scheduler_D.step()\n\n        # save checkpoint\n        torch.save(\n            {\n                'netG_S2T': netG_S2T.state_dict(),\n                'netG_T2S': netG_T2S.state_dict(),\n                'netD_S': netD_S.state_dict(),\n                'netD_T': netD_T.state_dict(),\n                'optimizer_G': optimizer_G.state_dict(),\n                'optimizer_D': optimizer_D.state_dict(),\n                'lr_scheduler_G': lr_scheduler_G.state_dict(),\n                'lr_scheduler_D': lr_scheduler_D.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n\n    if args.translated_root is not None:\n        transform = T.Compose([\n            T.Resize(image_size=args.test_input_size),\n            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),\n        ])\n        train_source_dataset.translate(transform, args.translated_root)\n\n    logger.close()\n\n\ndef train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, predict,\n          criterion_gan, criterion_cycle, criterion_identity, criterion_semantic,\n          optimizer_G, optimizer_D, fake_S_pool, fake_T_pool,\n          epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')\n    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')\n    losses_D_S = AverageMeter('D_S', ':3.2f')\n    losses_D_T = AverageMeter('D_T', ':3.2f')\n    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')\n    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')\n    losses_identity_S = AverageMeter('idt_S', ':3.2f')\n    losses_identity_T = AverageMeter('idt_T', ':3.2f')\n    losses_semantic_S2T = AverageMeter('sem_S2T', ':3.2f')\n    losses_semantic_T2S = AverageMeter('sem_T2S', ':3.2f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,\n         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T,\n         losses_semantic_S2T, losses_semantic_T2S],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        real_S, label_s = next(train_source_iter)\n        real_T, _ = next(train_target_iter)\n\n        real_S = real_S.to(device)\n        real_T = real_T.to(device)\n        label_s = label_s.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # Compute fake images and reconstruction images.\n        fake_T = netG_S2T(real_S)\n        rec_S = netG_T2S(fake_T)\n        fake_S = netG_T2S(real_T)\n        rec_T = netG_S2T(fake_S)\n\n        # Optimize generators\n        # discriminators require no gradients\n        set_requires_grad(netD_S, False)\n        set_requires_grad(netD_T, False)\n\n        optimizer_G.zero_grad()\n        # GAN loss D_T(G_S2T(S))\n        loss_G_S2T = 
criterion_gan(netD_T(fake_T), real=True)\n        # GAN loss D_S(G_T2S(T))\n        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)\n        # Cycle loss || G_T2S(G_S2T(S)) - S||\n        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle\n        # Cycle loss || G_S2T(G_T2S(T)) - T||\n        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle\n        # Identity loss\n        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||\n        identity_T = netG_S2T(real_T)\n        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity\n        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||\n        identity_S = netG_T2S(real_S)\n        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity\n        # Semantic loss\n        pred_fake_T = predict(fake_T)\n        pred_real_S = predict(real_S)\n        loss_semantic_S2T = criterion_semantic(pred_fake_T, label_s) * args.trade_off_semantic\n        pred_fake_S = predict(fake_S)\n        pred_real_T = predict(real_T)\n        loss_semantic_T2S = criterion_semantic(pred_fake_S, pred_real_T.max(1).indices) * args.trade_off_semantic\n        # combined loss and calculate gradients\n        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + \\\n                 loss_identity_S + loss_identity_T + loss_semantic_S2T + loss_semantic_T2S\n        loss_G.backward()\n        optimizer_G.step()\n\n        # Optimize discriminator\n        set_requires_grad(netD_S, True)\n        set_requires_grad(netD_T, True)\n        optimizer_D.zero_grad()\n        # Calculate GAN loss for discriminator D_S\n        fake_S_ = fake_S_pool.query(fake_S.detach())\n        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))\n        loss_D_S.backward()\n        # Calculate GAN loss for discriminator D_T\n        fake_T_ = fake_T_pool.query(fake_T.detach())\n        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))\n        loss_D_T.backward()\n        optimizer_D.step()\n\n        # record losses and measure elapsed time\n        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))\n        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))\n        losses_D_S.update(loss_D_S.item(), real_S.size(0))\n        losses_D_T.update(loss_D_T.item(), real_S.size(0))\n        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))\n        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))\n        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))\n        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))\n        losses_semantic_S2T.update(loss_semantic_S2T.item(), real_S.size(0))\n        losses_semantic_T2S.update(loss_semantic_T2S.item(), real_S.size(0))\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n            for image, prediction, name in zip([real_S, real_T, fake_S, fake_T],\n                                               [pred_real_S, pred_real_T, pred_fake_S, pred_fake_T],\n                                               [\"real_S\", \"real_T\", \"fake_S\", \"fake_T\"]):\n                visualize(image[0], \"{}_{}\".format(i, name), prediction[0])\n            for image, name in zip([rec_S, rec_T, identity_S, identity_T],\n                                   [\"rec_S\", \"rec_T\", \"identity_S\", 
\"identity_T\"]):\n                visualize(image[0], \"{}_{}\".format(i, name))\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    # dataset parameters\n    parser = argparse.ArgumentParser(description='Cycada for Segmentation Domain Adaptation')\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),\n                        help='the resize ratio for the random resize crop')\n    parser.add_argument('--train-size', nargs='+', type=int, default=(512, 256),\n                        help='the input and output image size during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: deeplabv2_resnet101)')\n    parser.add_argument('--pretrain', type=str, default=None,\n                        help='pretrain checkpoints for segementation model')\n    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n    parser.add_argument('--netD', type=str, default='patch',\n                        help='specify discriminator architecture [patch | pixel]. 
The basic model is a 70x70 PatchGAN.')\n    parser.add_argument('--netG', type=str, default='unet_256',\n                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')\n    parser.add_argument('--norm', type=str, default='instance',\n                        help='instance normalization or batch normalization [instance | batch | none]')\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore cyclegan model parameters from.\")\n    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')\n    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')\n    parser.add_argument('--trade-off-semantic', type=float, default=1.0, help='trade off for semantic loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=1, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 1)')\n    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--epochs-decay', type=int, default=20,\n                        help='number of epochs to linearly decay learning rate to zero')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('-i', '--iters-per-epoch', default=5000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('--pool-size', type=int, default=50,\n                        help='the size of image buffer that stores previously generated images')\n    parser.add_argument('-p', '--print-freq', default=400, type=int,\n                        metavar='N', help='print frequency (default: 400)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--ignore-label\", type=int, default=255,\n                        help=\"The index of the label to ignore during the training.\")\n    parser.add_argument(\"--log\", type=str, default='cycada',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    # test parameters\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--translated-root', type=str, default=None,\n                        help=\"The root directory for the translated dataset\")\n    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input image size during test')\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/cycada.sh",
    "content": "# GTA5 to Cityscapes\n# First, train the CycleGAN\nCUDA_VISIBLE_DEVICES=0 python cycada.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \\\n    --log logs/cycada/gtav2cityscapes --pretrain logs/src_only/gtav2cityscapes/checkpoints/59.pth \\\n    --translated-root data/GTA52Cityscapes/cycada_39\n# Then, train the src_only model on the translated source dataset\nCUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/cycada_39 data/Cityscapes \\\n    -s GTA5 -t Cityscapes --log logs/cycada_src_only/gtav2cityscapes\n\n\n## Synthia to Cityscapes\n# First, train the Cycada\nCUDA_VISIBLE_DEVICES=0 python cycada.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \\\n    --log logs/cycada/synthia2cityscapes --pretrain logs/src_only/synthia2cityscapes/checkpoints/59.pth \\\n    --translated-root data/Synthia2Cityscapes/cycada_39\n# Then, train the src_only model on the translated source dataset\nCUDA_VISIBLE_DEVICES=0 python source_only.py data/Synthia2Cityscapes/cycada_39 data/Cityscapes \\\n    -s Synthia -t Cityscapes --log logs/cycada_src_only/synthia2cityscapes\n\n\n# Cityscapes to FoggyCityscapes\n# First, train the CycleGAN\nCUDA_VISIBLE_DEVICES=0 python cycada.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \\\n    --log logs/cycada/cityscapes2foggy --pretrain logs/src_only/cityscapes2foggy/checkpoints/59.pth \\\n    --translated-root data/Cityscapes2Foggy/cycada_39\n# Then, train the src_only model on the translated source dataset\nCUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/cycada_39 data/Cityscapes \\\n    -s Cityscapes -t FoggyCityscapes --log logs/cycada_src_only/cityscapes2foggy\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/cycle_gan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport itertools\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\nfrom torchvision.transforms import ToPILImage, Compose\n\nsys.path.append('../../..')\nimport tllib.translation.cyclegan as cyclegan\nfrom tllib.translation.cyclegan.util import ImagePool, set_requires_grad\nimport tllib.vision.datasets.segmentation as datasets\nfrom tllib.vision.transforms import Denormalize\nimport tllib.vision.transforms.segmentation as T\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = T.Compose([\n        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),\n        T.RandomHorizontalFlip(),\n        T.ToTensor(),\n        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n    ])\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform)\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    target_dataset = datasets.__dict__[args.target]\n    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform)\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # define networks (both generators and discriminators)\n    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)\n    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)\n\n    # create image buffer to store previously generated images\n    fake_S_pool = ImagePool(args.pool_size)\n    fake_T_pool = ImagePool(args.pool_size)\n\n    # define optimizer and lr scheduler\n    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr, betas=(args.beta1, 0.999))\n    optimizer_D = Adam(itertools.chain(netD_S.parameters(), 
netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))\n    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)\n    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)\n    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        print(\"Resume from\", args.resume)\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        netG_S2T.load_state_dict(checkpoint['netG_S2T'])\n        netG_T2S.load_state_dict(checkpoint['netG_T2S'])\n        netD_S.load_state_dict(checkpoint['netD_S'])\n        netD_T.load_state_dict(checkpoint['netD_T'])\n        optimizer_G.load_state_dict(checkpoint['optimizer_G'])\n        optimizer_D.load_state_dict(checkpoint['optimizer_D'])\n        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])\n        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    if args.phase == 'test':\n        transform = T.Compose([\n            T.Resize(image_size=args.test_input_size),\n            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),\n        ])\n        train_source_dataset.translate(transform, args.translated_root)\n        return\n\n    # define loss function\n    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()\n    criterion_cycle = nn.L1Loss()\n    criterion_identity = nn.L1Loss()\n\n    # define visualization function\n    tensor_to_image = Compose([\n        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n        ToPILImage()\n    ])\n\n    def visualize(image, name):\n        \"\"\"\n        Args:\n            image (tensor): image in shape 3 x H x W\n            name: name of the saved image\n        \"\"\"\n        tensor_to_image(image).save(logger.get_image_path(\"{}.png\".format(name)))\n\n    # start training\n    for epoch in range(args.start_epoch, args.epochs+args.epochs_decay):\n        logger.set_epoch(epoch)\n        print(lr_scheduler_G.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,\n              criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,\n              fake_S_pool, fake_T_pool, epoch, visualize, args)\n\n        # update learning rates\n        lr_scheduler_G.step()\n        lr_scheduler_D.step()\n\n        # save checkpoint\n        torch.save(\n            {\n                'netG_S2T': netG_S2T.state_dict(),\n                'netG_T2S': netG_T2S.state_dict(),\n                'netD_S': netD_S.state_dict(),\n                'netD_T': netD_T.state_dict(),\n                'optimizer_G': optimizer_G.state_dict(),\n                'optimizer_D': optimizer_D.state_dict(),\n                'lr_scheduler_G': lr_scheduler_G.state_dict(),\n                'lr_scheduler_D': lr_scheduler_D.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n\n    if args.translated_root is not None:\n        transform = T.Compose([\n            T.Resize(image_size=args.test_input_size),\n            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),\n        ])\n        train_source_dataset.translate(transform, args.translated_root)\n\n    logger.close()\n\n\ndef train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,\n          criterion_gan, criterion_cycle, 
criterion_identity, optimizer_G, optimizer_D,\n          fake_S_pool, fake_T_pool, epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')\n    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')\n    losses_D_S = AverageMeter('D_S', ':3.2f')\n    losses_D_T = AverageMeter('D_T', ':3.2f')\n    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')\n    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')\n    losses_identity_S = AverageMeter('idt_S', ':3.2f')\n    losses_identity_T = AverageMeter('idt_T', ':3.2f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,\n         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        real_S, _ = next(train_source_iter)\n        real_T, _ = next(train_target_iter)\n\n        real_S = real_S.to(device)\n        real_T = real_T.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # Compute fake images and reconstruction images.\n        fake_T = netG_S2T(real_S)\n        rec_S = netG_T2S(fake_T)\n        fake_S = netG_T2S(real_T)\n        rec_T = netG_S2T(fake_S)\n\n        # Optimize generators\n        # discriminators require no gradients\n        set_requires_grad(netD_S, False)\n        set_requires_grad(netD_T, False)\n\n        optimizer_G.zero_grad()\n        # GAN loss D_T(G_S2T(S))\n        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)\n        # GAN loss D_S(G_T2S(T))\n        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)\n        # Cycle loss || G_T2S(G_S2T(S)) - S||\n        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle\n        # Cycle loss || G_S2T(G_T2S(T)) - T||\n        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle\n        # Identity loss\n        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||\n        identity_T = netG_S2T(real_T)\n        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity\n        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||\n        identity_S = netG_T2S(real_S)\n        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity\n        # combined loss and calculate gradients\n        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T\n        loss_G.backward()\n        optimizer_G.step()\n\n        # Optimize discriminators\n        set_requires_grad(netD_S, True)\n        set_requires_grad(netD_T, True)\n        optimizer_D.zero_grad()\n        # Calculate GAN loss for discriminator D_S\n        fake_S_ = fake_S_pool.query(fake_S.detach())\n        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))\n        loss_D_S.backward()\n        # Calculate GAN loss for discriminator D_T\n        fake_T_ = fake_T_pool.query(fake_T.detach())\n        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))\n        loss_D_T.backward()\n        optimizer_D.step()\n\n        # update statistics and measure elapsed time\n        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))\n        losses_G_T2S.update(loss_G_T2S.item(), 
real_S.size(0))\n        losses_D_S.update(loss_D_S.item(), real_S.size(0))\n        losses_D_T.update(loss_D_T.item(), real_S.size(0))\n        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))\n        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))\n        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))\n        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T],\n                                    [\"real_S\", \"real_T\", \"fake_S\", \"fake_T\", \"rec_S\",\n                                     \"rec_T\", \"identity_S\", \"identity_T\"]):\n                visualize(tensor[0], \"{}_{}\".format(i, name))\n\n\nif __name__ == '__main__':\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description='CycleGAN for Segmentation Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),\n                        help='the resize ratio for the random resize crop')\n    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input and output image size during training')\n    # model parameters\n    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n    parser.add_argument('--netD', type=str, default='patch',\n                        help='specify discriminator architecture [patch | pixel]. 
The basic model is a 70x70 PatchGAN.')\n    parser.add_argument('--netG', type=str, default='unet_256',\n                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')\n    parser.add_argument('--norm', type=str, default='instance',\n                        help='instance normalization or batch normalization [instance | batch | none]')\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')\n    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=1, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 1)')\n    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--epochs-decay', type=int, default=20,\n                        help='number of epochs to linearly decay learning rate to zero')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('-i', '--iters-per-epoch', default=5000, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('--pool-size', type=int, default=50,\n                        help='the size of image buffer that stores previously generated images')\n    parser.add_argument('-p', '--print-freq', default=400, type=int,\n                        metavar='N', help='print frequency (default: 400)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='cyclegan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    # test parameters\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--translated-root', type=str, default=None,\n                        help=\"The root directory to save the translated dataset\")\n    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input image size during test')\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/cycle_gan.sh",
    "content": "# GTA5 to Cityscapes\n# First, train the CycleGAN\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \\\n    --log logs/cyclegan/gtav2cityscapes --translated-root data/GTA52Cityscapes/CycleGAN_39\n# Then, train the src_only model on the translated source dataset\nCUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/CycleGAN_39 data/Cityscapes \\\n    -s GTA5 -t Cityscapes --log logs/cyclegan_src_only/gtav2cityscapes\n\n\n# Cityscapes to FoggyCityscapes\n# First, train the CycleGAN\nCUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \\\n    --log logs/cyclegan/cityscapes2foggy --translated-root data/Cityscapes2Foggy/CycleGAN_39\n# Then, train the src_only model on the translated source dataset\nCUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/CycleGAN_39 data/Cityscapes \\\n    -s Cityscapes -t FoggyCityscapes --log logs/cyclegan_src_only/cityscapes2foggy\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/erm.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nfrom PIL import Image\nimport numpy as np\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nsys.path.append('../../..')\nimport tllib.vision.models.segmentation as models\nimport tllib.vision.datasets.segmentation as datasets\nimport tllib.vision.transforms.segmentation as T\nfrom tllib.vision.transforms import DeNormalizeAndTranspose\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter, Meter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(\n        root=args.source_root,\n        transforms=T.Compose([\n            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=args.resize_scale),\n            T.ColorJitter(brightness=0.3, contrast=0.3),\n            T.RandomHorizontalFlip(),\n            T.NormalizeAndTranspose(),\n        ]),\n    )\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    target_dataset = datasets.__dict__[args.target]\n    val_target_dataset = target_dataset(\n        root=args.target_root, split='val',\n        transforms=T.Compose([\n            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),\n            T.NormalizeAndTranspose(),\n        ]),\n    )\n    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n\n    # create model\n    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)\n    # define optimizer and lr scheduler\n    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)\n    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
- float(x) / args.epochs / args.iters_per_epoch)\n                                                 ** (args.lr_power))\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    # define loss function (criterion)\n    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)\n    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)\n    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)\n\n    # define visualization function\n    decode = train_source_dataset.decode_target\n\n    def visualize(image, pred, label, prefix):\n        \"\"\"\n        Args:\n            image (tensor): 3 x H x W\n            pred (tensor): C x H x W\n            label (tensor): H x W\n            prefix: prefix of the saving image\n        \"\"\"\n        image = image.detach().cpu().numpy()\n        pred = pred.detach().max(dim=0)[1].cpu().numpy()\n        label = label.cpu().numpy()\n        for tensor, name in [\n            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), \"image\"),\n            (decode(label), \"label\"),\n            (decode(pred), \"pred\")\n        ]:\n            tensor.save(logger.get_image_path(\"{}_{}.png\".format(prefix, name)))\n\n    if args.phase == 'test':\n        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)\n        print(confmat)\n        return\n\n    # start training\n    best_iou = 0.\n    for epoch in range(args.start_epoch, args.epochs):\n        logger.set_epoch(epoch)\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, model, interp_train, criterion, optimizer,\n              lr_scheduler, epoch, visualize if args.debug else None, args)\n\n        # evaluate on validation set\n        confmat = validate(val_target_loader, model, interp_val, criterion, visualize if args.debug else None, args)\n        print(confmat.format(train_source_dataset.classes))\n        acc_global, acc, iu = confmat.compute()\n\n        # calculate the mean iou over partial classes\n        indexes = [train_source_dataset.classes.index(name) for name\n                   in train_source_dataset.evaluate_classes]\n        iu = iu[indexes]\n        mean_iou = iu.mean()\n\n        # remember best iou and save checkpoint\n        torch.save(\n            {\n                'model': model.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'lr_scheduler': lr_scheduler.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n        if mean_iou > best_iou:\n            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))\n        best_iou = max(best_iou, mean_iou)\n        print(\"Target: {} Best: {}\".format(mean_iou, best_iou))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, model, interp, criterion, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = 
AverageMeter('Loss (s)', ':3.2f')\n    accuracies_s = Meter('Acc (s)', ':3.2f')\n    iou_s = Meter('IoU (s)', ':3.2f')\n\n    confmat_s = ConfusionMatrix(model.num_classes)\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s,\n         accuracies_s, iou_s],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, label_s = next(train_source_iter)\n        x_s = x_s.to(device)\n        label_s = label_s.long().to(device)\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s = model(x_s)\n        pred_s = interp(y_s)\n        loss_cls_s = criterion(pred_s, label_s)\n        loss_cls_s.backward()\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure accuracy and record loss\n        losses_s.update(loss_cls_s.item(), x_s.size(0))\n        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())\n        acc_global_s, acc_s, iu_s = confmat_s.compute()\n        accuracies_s.update(acc_s.mean().item())\n        iou_s.update(iu_s.mean().item())\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n            if visualize is not None:\n                visualize(x_s[0], pred_s[0], label_s[0], \"source_{}\".format(i))\n\n\ndef validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    acc = Meter('Acc', ':3.2f')\n    iou = Meter('IoU', ':3.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, acc, iou],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    confmat = ConfusionMatrix(model.num_classes)\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (x, label) in enumerate(val_loader):\n            x = x.to(device)\n            label = label.long().to(device)\n\n            # compute output\n            output = interp(model(x))\n            loss = criterion(output, label)\n\n            # measure accuracy and record loss\n            losses.update(loss.item(), x.size(0))\n            confmat.update(label.flatten(), output.argmax(1).flatten())\n            acc_global, accs, iu = confmat.compute()\n            acc.update(accs.mean().item())\n            iou.update(iu.mean().item())\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n                if visualize is not None:\n                    visualize(x[0], output[0], label[0], \"val_{}\".format(i))\n\n    return confmat\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='Source Only for Segmentation Domain Adaptation')\n    # dataset 
parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),\n                        help='the resize ratio for the random resize crop')\n    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.5, 1.),\n                        help='the resize scale for the random resize crop')\n    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input and output image size during training')\n    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input image size during test')\n    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),\n                        help='the output image size during test')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: deeplabv2_resnet101)')\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=2, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 2)')\n    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument(\"--momentum\", type=float, default=0.9, help=\"Momentum component of the optimiser.\")\n    parser.add_argument(\"--weight-decay\", type=float, default=0.0005,\n                        help=\"Regularisation parameter for L2-loss.\")\n    parser.add_argument(\"--lr-power\", type=float, default=0.9,\n                        help=\"Decay parameter to compute the learning rate (only for deeplab).\")\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')
\n    parser.add_argument(\"--ignore-label\", type=int, default=255,\n                        help=\"The index of the label to ignore during training.\")\n    parser.add_argument(\"--log\", type=str, default='src_only',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--debug', action=\"store_true\",\n                        help='In debug mode, save images and predictions during training')\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/erm.sh",
    "content": "# Source Only\n# GTA5 to Cityscapes\nCUDA_VISIBLE_DEVICES=0 python erm.py data/GTA5 data/Cityscapes \\\n    -s GTA5 -t Cityscapes --log logs/erm/gtav2cityscapes\n\n# Synthia to Cityscapes\nCUDA_VISIBLE_DEVICES=0 python erm.py data/synthia data/Cityscapes \\\n    -s Synthia -t Cityscapes --log logs/erm/synthia2cityscapes\n\n# Cityscapes to FoggyCityscapes\nCUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \\\n    -s Cityscapes -t FoggyCityscapes --log logs/erm/cityscapes2foggy\n\n# Oracle\n# Oracle Results on Cityscapes\nCUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \\\n    -s Cityscapes -t Cityscapes --log logs/oracle/cityscapes\n\n# Oracle Results on Foggy Cityscapes\nCUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \\\n    -s FoggyCityscapes -t FoggyCityscapes --log logs/oracle/foggy_cityscapes\n\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/fda.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nfrom PIL import Image\nimport numpy as np\nimport os\nimport math\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nsys.path.append('../../..')\nfrom tllib.translation.fourier_transform import FourierTransform\nimport tllib.vision.models.segmentation as models\nimport tllib.vision.datasets.segmentation as datasets\nimport tllib.vision.transforms.segmentation as T\nfrom tllib.vision.transforms import DeNormalizeAndTranspose\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter, Meter\nfrom tllib.utils.logger import CompleteLogger\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef robust_entropy(y, ita=1.5, num_classes=19, reduction='mean'):\n    \"\"\" Robust entropy proposed in `FDA: Fourier Domain Adaptation for Semantic Segmentation (CVPR 2020) <https://arxiv.org/abs/2004.05498>`_\n\n    Args:\n        y (tensor): logits output of segmentation model in shape of :math:`(N, C, H, W)`\n        ita (float, optional): parameters for robust entropy. Default: 1.5\n        num_classes (int, optional): number of classes. Default: 19\n        reduction (string, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output. Default: ``'mean'``\n\n    Returns:\n        Scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, )`.\n\n    \"\"\"\n    P = F.softmax(y, dim=1)\n    logP = F.log_softmax(y, dim=1)\n    PlogP = P * logP\n    ent = -1.0 * PlogP.sum(dim=1)\n    ent = ent / math.log(num_classes)\n\n    # compute robust entropy\n    ent = ent ** 2.0 + 1e-8\n    ent = ent ** ita\n\n    if reduction == 'mean':\n        return ent.mean()\n    else:\n        return ent\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! 
\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    target_dataset = datasets.__dict__[args.target]\n    train_target_dataset = target_dataset(\n        root=args.target_root,\n        transforms=T.Compose([\n            T.Resize(image_size=args.train_size),\n            T.NormalizeAndTranspose(),\n        ]),\n    )\n    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n    val_target_dataset = target_dataset(\n        root=args.target_root, split='val',\n        transforms=T.Compose([\n            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),\n            T.NormalizeAndTranspose(),\n        ])\n    )\n    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)\n\n    # collect the absolute paths of all images in the target dataset\n    target_image_list = train_target_dataset.collect_image_paths()\n    # build a Fourier transform that translates source images to the target style\n    fourier_transform = T.wrapper(FourierTransform)(target_image_list, os.path.join(logger.root, \"amplitudes\"),\n                                                    rebuild=False, beta=args.beta)\n\n    source_dataset = datasets.__dict__[args.source]\n    train_source_dataset = source_dataset(\n        root=args.source_root,\n        transforms=T.Compose([\n            T.Resize((2048, 1024)),  # convert source images to the size of the target images before the Fourier transform\n            fourier_transform,\n            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),\n            T.ColorJitter(brightness=0.3, contrast=0.3),\n            T.RandomHorizontalFlip(),\n            T.NormalizeAndTranspose(),\n        ]),\n    )\n    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,\n                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)\n\n    train_source_iter = ForeverDataIterator(train_source_loader)\n    train_target_iter = ForeverDataIterator(train_target_loader)\n\n    # create model\n    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)\n    # define optimizer and lr scheduler\n    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)\n    lr_scheduler = LambdaLR(optimizer,
\n                            lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))\n\n    # optionally resume from a checkpoint\n    if args.resume:\n        checkpoint = torch.load(args.resume, map_location='cpu')\n        model.load_state_dict(checkpoint['model'])\n        optimizer.load_state_dict(checkpoint['optimizer'])\n        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n        args.start_epoch = checkpoint['epoch'] + 1\n\n    # define loss function (criterion)\n    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)\n    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)\n    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)\n\n    # define visualization function\n    decode = train_source_dataset.decode_target\n\n    def visualize(image, pred, label, prefix):\n        \"\"\"\n        Args:\n            image (tensor): 3 x H x W\n            pred (tensor): C x H x W\n            label (tensor): H x W\n            prefix: prefix of the saving image\n        \"\"\"\n        image = image.detach().cpu().numpy()\n        pred = pred.detach().max(dim=0)[1].cpu().numpy()\n        label = label.cpu().numpy()\n        for tensor, name in [\n            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), \"image\"),\n            (decode(label), \"label\"),\n            (decode(pred), \"pred\")\n        ]:\n            tensor.save(logger.get_image_path(\"{}_{}.png\".format(prefix, name)))\n\n    if args.phase == 'test':\n        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)\n        print(confmat)\n        return\n\n    # start training\n    best_iou = 0.\n    for epoch in range(args.start_epoch, args.epochs):\n        logger.set_epoch(epoch)\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(train_source_iter, train_target_iter, model, interp_train, criterion, optimizer,\n              lr_scheduler, epoch, visualize if args.debug else None, args)\n\n        # evaluate on validation set\n        confmat = validate(val_target_loader, model, interp_val, criterion, None, args)\n        print(confmat.format(train_source_dataset.classes))\n        acc_global, acc, iu = confmat.compute()\n\n        # calculate the mean iou over partial classes\n        indexes = [train_source_dataset.classes.index(name) for name\n                   in train_source_dataset.evaluate_classes]\n        iu = iu[indexes]\n        mean_iou = iu.mean()\n\n        # remember best iou and save checkpoint\n        torch.save(\n            {\n                'model': model.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'lr_scheduler': lr_scheduler.state_dict(),\n                'epoch': epoch,\n                'args': args\n            }, logger.get_checkpoint_path(epoch)\n        )\n        if mean_iou > best_iou:\n            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))\n        best_iou = max(best_iou, mean_iou)\n        print(\"Target: {} Best: {}\".format(mean_iou, best_iou))\n\n    logger.close()\n\n\ndef train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,\n          model, interp, criterion, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_s = 
AverageMeter('Loss (s)', ':3.2f')\n    losses_t = AverageMeter('Loss (t)', ':3.2f')\n    losses_entropy_t = AverageMeter('Entropy (t)', ':3.2f')\n    accuracies_s = Meter('Acc (s)', ':3.2f')\n    accuracies_t = Meter('Acc (t)', ':3.2f')\n    iou_s = Meter('IoU (s)', ':3.2f')\n    iou_t = Meter('IoU (t)', ':3.2f')\n\n    confmat_s = ConfusionMatrix(model.num_classes)\n    confmat_t = ConfusionMatrix(model.num_classes)\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_s, losses_t, losses_entropy_t,\n         accuracies_s, accuracies_t, iou_s, iou_t],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        optimizer.zero_grad()\n\n        x_s, label_s = next(train_source_iter)\n        x_t, label_t = next(train_target_iter)\n\n        x_s = x_s.to(device)\n        label_s = label_s.long().to(device)\n        x_t = x_t.to(device)\n        label_t = label_t.long().to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s = model(x_s)\n        pred_s = interp(y_s)\n        loss_cls_s = criterion(pred_s, label_s)\n        loss_cls_s.backward()\n\n        y_t = model(x_t)\n        pred_t = interp(y_t)\n        loss_cls_t = criterion(pred_t, label_t)\n        loss_entropy_t = robust_entropy(y_t, args.ita)\n        (args.entropy_weight * loss_entropy_t).backward()\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure accuracy and record loss\n        losses_s.update(loss_cls_s.item(), x_s.size(0))\n        losses_t.update(loss_cls_t.item(), x_s.size(0))\n        losses_entropy_t.update(loss_entropy_t.item(), x_s.size(0))\n\n        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())\n        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())\n        acc_global_s, acc_s, iu_s = confmat_s.compute()\n        acc_global_t, acc_t, iu_t = confmat_t.compute()\n        accuracies_s.update(acc_s.mean().item())\n        accuracies_t.update(acc_t.mean().item())\n        iou_s.update(iu_s.mean().item())\n        iou_t.update(iu_t.mean().item())\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n            if visualize is not None:\n                visualize(x_s[0], pred_s[0], label_s[0], \"source_{}\".format(i))\n                visualize(x_t[0], pred_t[0], label_t[0], \"target_{}\".format(i))\n\n\ndef validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    acc = Meter('Acc', ':3.2f')\n    iou = Meter('IoU', ':3.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, acc, iou],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    confmat = ConfusionMatrix(model.num_classes)\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (x, label) in enumerate(val_loader):\n            x = x.to(device)\n            label = label.long().to(device)\n\n            # compute output\n            output = interp(model(x))\n            loss = criterion(output, label)\n\n            # measure accuracy and record loss\n            
losses.update(loss.item(), x.size(0))\n            confmat.update(label.flatten(), output.argmax(1).flatten())\n            acc_global, accs, iu = confmat.compute()\n            acc.update(accs.mean().item())\n            iou.update(iu.mean().item())\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n                if visualize is not None:\n                    visualize(x[0], output[0], label[0], \"val_{}\".format(i))\n\n    return confmat\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n    parser = argparse.ArgumentParser(description='FDA for Segmentation Domain Adaptation')\n    # dataset parameters\n    parser.add_argument('source_root', help='root path of the source dataset')\n    parser.add_argument('target_root', help='root path of the target dataset')\n    parser.add_argument('-s', '--source', help='source domain(s)')\n    parser.add_argument('-t', '--target', help='target domain(s)')\n    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),\n                        help='the resize ratio for the random resize crop')\n    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input and output image size during training')\n    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),\n                        help='the input image size during test')\n    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),\n                        help='the output image size during test')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: deeplabv2_resnet101)')\n    parser.add_argument(\"--entropy-weight\", type=float, default=0., help=\"weight for entropy\")\n    parser.add_argument(\"--ita\", type=float, default=2.0, help=\"ita for robust entropy\")\n    parser.add_argument(\"--beta\", type=int, default=1, help=\"beta for FDA\")\n    parser.add_argument(\"--resume\", type=str, default=None,\n                        help=\"Where to restore model parameters from.\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=2, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 2)')\n    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument(\"--momentum\", type=float, default=0.9, help=\"Momentum component of the optimiser.\")\n    parser.add_argument(\"--weight-decay\", type=float, default=0.0005, help=\"Regularisation parameter for L2-loss.\")\n    parser.add_argument(\"--lr-power\", type=float, default=0.9,\n                        help=\"Decay parameter to compute the learning rate (only for deeplab).\")
\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--ignore-label\", type=int, default=255,\n                        help=\"The index of the label to ignore during training.\")\n    parser.add_argument(\"--log\", type=str, default='fda',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    parser.add_argument('--debug', action=\"store_true\",\n                        help='In debug mode, save images and predictions during training')\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/semantic_segmentation/fda.sh",
    "content": "# GTA5 to Cityscapes\nCUDA_VISIBLE_DEVICES=0 python fda.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \\\n    --log logs/fda/gtav2cityscapes --debug\n\n# Synthia to Cityscapes\nCUDA_VISIBLE_DEVICES=0 python fda.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \\\n    --log logs/fda/synthia2cityscapes --debug\n\n# Cityscapes to FoggyCityscapes\nCUDA_VISIBLE_DEVICES=0 python fda.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \\\n    --log logs/fda/cityscapes2foggy --debug\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/README.md",
    "content": "# Unsupervised Domain Adaptation for WILDS (Image Classification)\n\n## Installation\n\nIt’s suggested to use **pytorch==1.9.0** in order to reproduce the benchmark results.\n\nYou need to install **apex** following ``https://github.com/NVIDIA/apex``. Then run\n\n```\npip install -r requirements.txt\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [DomainNet](http://ai.bu.edu/M3SDA/)\n- [iwildcam (WILDS)](https://wilds.stanford.edu/datasets/)\n- [camelyon17 (WILDS)](https://wilds.stanford.edu/datasets/)\n- [fmow (WILDS)](https://wilds.stanford.edu/datasets/)\n\n## Supported Methods\n\nSupported methods include:\n\n- [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818)\n- [Deep Adaptation Network (DAN)](https://arxiv.org/pdf/1502.02791)\n- [Joint Adaptation Network (JAN)](https://arxiv.org/abs/1605.06636)\n- [Conditional Domain Adversarial Network (CDAN)](https://arxiv.org/abs/1705.10667)\n- [Margin Disparity Discrepancy (MDD)](https://arxiv.org/abs/1904.05801)\n- [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence (FixMatch)](https://arxiv.org/abs/2001.07685)\n\n## Usage\n\nOur code is based\non [https://github.com/NVIDIA/apex/edit/master/examples/imagenet](https://github.com/NVIDIA/apex/edit/master/examples/imagenet)\n. It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and\nVGG, on the WILDS dataset.  \nCommand-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed\nprecision \"optimization levels\" or `opt_level`s.  \nFor a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).\n\nThe shell files give all the training scripts we use, e.g.,\n\n```\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121\n```\n\n## Results\n\n### Performance on WILDS-FMoW (DenseNet-121)\n\n| Methods | Val Avg Acc | Test Avg Acc | Val Worst-region Acc | Test Worst-region Acc |\n|---------|-------------|--------------|----------------------|-----------------------|\n| ERM     | 59.8        | 53.3         | 50.2                 | 32.2                  |\n| DANN    | 60.6        | 54.2         | 49.1                 | 34.8                  |\n| DAN     | 61.7        | 55.5         | 48.3                 | 35.3                  |\n| JAN     | 61.5        | 55.3         | 50.6                 | 36.3                  |\n| CDAN    | 60.7        | 55.0         | 47.4                 | 35.5                  |\n| MDD     | 60.1        | 55.1         | 49.3                 | 35.9                  |\n| FixMatch| 61.1        | 55.1         | 51.8                 | 37.4                  |\n\n### Performance on WILDS-IWildCAM (ResNet50)\n\n| Methods | Val Avg Acc | Test Avg Acc | Val F1 macro | Test F1 macro |\n|---------|-------------|--------------|--------------|---------------|\n| ERM     | 59.9        | 72.6         | 36.3         | 32.9          |\n| DANN    | 57.4        | 70.1         | 35.8         | 32.2          |\n| DAN     | 63.7        | 69.4         | 39.1         | 31.6          |\n| JAN     | 62.4        | 68.7         | 37.6         | 31.5          |\n| CDAN    | 57.6        | 71.2         | 37.0         | 30.6          |\n| MDD     | 58.3        | 73.5         | 35.0         | 30.0          |\n\n### 
## Results\n\n### Performance on WILDS-FMoW (DenseNet-121)\n\n| Methods | Val Avg Acc | Test Avg Acc | Val Worst-region Acc | Test Worst-region Acc |\n|---------|-------------|--------------|----------------------|-----------------------|\n| ERM     | 59.8        | 53.3         | 50.2                 | 32.2                  |\n| DANN    | 60.6        | 54.2         | 49.1                 | 34.8                  |\n| DAN     | 61.7        | 55.5         | 48.3                 | 35.3                  |\n| JAN     | 61.5        | 55.3         | 50.6                 | 36.3                  |\n| CDAN    | 60.7        | 55.0         | 47.4                 | 35.5                  |\n| MDD     | 60.1        | 55.1         | 49.3                 | 35.9                  |\n| FixMatch| 61.1        | 55.1         | 51.8                 | 37.4                  |\n\n### Performance on WILDS-IWildCAM (ResNet50)\n\n| Methods | Val Avg Acc | Test Avg Acc | Val F1 macro | Test F1 macro |\n|---------|-------------|--------------|--------------|---------------|\n| ERM     | 59.9        | 72.6         | 36.3         | 32.9          |\n| DANN    | 57.4        | 70.1         | 35.8         | 32.2          |\n| DAN     | 63.7        | 69.4         | 39.1         | 31.6          |\n| JAN     | 62.4        | 68.7         | 37.6         | 31.5          |\n| CDAN    | 57.6        | 71.2         | 37.0         | 30.6          |\n| MDD     | 58.3        | 73.5         | 35.0         | 30.0          |\n\n### Visualization\n\nWe use tensorboard to record the training process and visualize the outputs of the models.\n\n```\ntensorboard --logdir=logs\n```\n\n### Distributed training\n\nWe use `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.\n\n```\nCUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 6666 erm.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 -j 8 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121_bs_128\n```\n\n## TODO\n\n1. update experiment results\n2. support DomainNet\n3. support camelyon17\n4. support self-training methods\n5. support self-supervised methods\n\n## Citation\n\nIf you use these methods in your research, please consider citing.\n\n```\n@inproceedings{DANN,\n    author = {Ganin, Yaroslav and Lempitsky, Victor},\n    booktitle = {ICML},\n    title = {Unsupervised domain adaptation by backpropagation},\n    year = {2015}\n}\n\n@inproceedings{DAN,\n    author    = {Mingsheng Long and\n    Yue Cao and\n    Jianmin Wang and\n    Michael I. Jordan},\n    title     = {Learning Transferable Features with Deep Adaptation Networks},\n    booktitle = {ICML},\n    year      = {2015},\n}\n\n@inproceedings{JAN,\n    title={Deep transfer learning with joint adaptation networks},\n    author={Long, Mingsheng and Zhu, Han and Wang, Jianmin and Jordan, Michael I},\n    booktitle={ICML},\n    year={2017},\n}\n\n@inproceedings{CDAN,\n    author    = {Mingsheng Long and\n                Zhangjie Cao and\n                Jianmin Wang and\n                Michael I. Jordan},\n    title     = {Conditional Adversarial Domain Adaptation},\n    booktitle = {NeurIPS},\n    year      = {2018}\n}\n\n@inproceedings{MDD,\n    title={Bridging theory and algorithm for domain adaptation},\n    author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael},\n    booktitle={ICML},\n    year={2019},\n}\n\n@inproceedings{FixMatch,\n    title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},\n    author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},\n    booktitle={NeurIPS},\n    year={2020}\n}\n\n```\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/cdan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\nimport math\nfrom itertools import cycle\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nimport torchvision.models as models\nfrom torch.utils.tensorboard import SummaryWriter\nfrom timm.loss.cross_entropy import LabelSmoothingCrossEntropy\nimport wilds\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier as Classifier\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    train_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        interpolation=args.interpolation,\n    )\n    val_transform = utils.get_val_transform(\n        img_size=args.img_size,\n        crop_pct=args.crop_pct,\n        interpolation=args.interpolation,\n    )\n    if args.local_rank == 0:\n        print(\"train_transform: \", train_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        
utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,\n                          train_transform, val_transform, verbose=args.local_rank == 0)\n\n    # create model\n    if args.local_rank == 0:\n        if not args.scratch:\n            print(\"=> using pre-trained model '{}'\".format(args.arch))\n        else:\n            print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,\n                       pool_layer=pool_layer, finetune=not args.scratch)\n    features_dim = model.features_dim\n\n    if args.randomized:\n        domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024, sigmoid=False)\n    else:\n        domain_discri = DomainDiscriminator(features_dim * args.num_classes, hidden_size=1024, sigmoid=False)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n    domain_discri = domain_discri.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.\n    optimizer = torch.optim.SGD(\n        model.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum,\n        weight_decay=args.weight_decay, nesterov=True)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    (model, domain_discri), optimizer = amp.initialize([model, domain_discri], optimizer,\n                                                       opt_level=args.opt_level,\n                                                       keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                                       loss_scale=args.loss_scale\n                                                       )\n\n    # Use cosine annealing learning rate strategy\n    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optimizer,\n        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)\n    )\n\n    # define loss function\n    domain_adv = ConditionalDomainAdversarialLoss(\n        domain_discri, num_classes=args.num_classes, features_dim=features_dim, randomized=args.randomized,\n        randomized_dim=args.randomized_dim, sigmoid=False\n    )\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... 
= amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n        domain_adv = DDP(domain_adv, delay_allreduce=True)\n\n    # define loss function (criterion)\n    if args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()\n    else:\n        criterion = nn.CrossEntropyLoss().cuda()\n\n    # Data loading code\n    train_labeled_sampler = None\n    train_unlabeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)\n    train_unlabeled_loader = DataLoader(\n        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n            train_unlabeled_sampler.set_epoch(epoch)\n\n        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer, epoch, writer,\n              args)\n\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            prec1 = utils.validate(d, model, epoch, writer, args)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n\ndef train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv,\n          optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', ':3.2f')\n    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n   
 num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))\n\n    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \\\n            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):\n\n        # compute output\n        n_s, n_t = len(input_s), len(input_t)\n        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)\n        output, feature = model(input)\n        output_s, output_t = output.split([n_s, n_t], dim=0)\n        feature_s, feature_t = feature.split([n_s, n_t], dim=0)\n        loss_s = criterion(output_s, target_s.cuda())\n        loss_trans = domain_adv(output_s, feature_s, output_t, feature_t)\n        loss = loss_s + loss_trans * args.trade_off\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)\n                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)\n                prec1 = utils.reduce_tensor(prec1, args.world_size)\n                domain_acc = domain_adv.module.domain_discriminator_accuracy\n            else:\n                reduced_loss_s = loss_s.data\n                reduced_loss_trans = loss_trans.data\n                domain_acc = domain_adv.domain_discriminator_accuracy\n\n            # to_python_float incurs a host<->device sync\n            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))\n            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))\n            domain_accs.update(to_python_float(domain_acc), input_s.size(0))\n            top1.update(to_python_float(prec1), input_s.size(0))\n            global_step = epoch * num_iterations + i\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                writer.add_scalar('train/top1', to_python_float(prec1), global_step)\n                writer.add_scalar(\"train/loss (s)\", to_python_float(reduced_loss_s), global_step)\n                writer.add_scalar(\"train/loss (trans)\", to_python_float(reduced_loss_trans), global_step)\n                writer.add_figure('train/predictions vs. 
actuals',\n                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,\n                                                           metadata_s, train_labeled_loader.dataset.metadata_map),\n                                  global_step=global_step)\n\n
                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\\t'\n                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\\t'\n
                      'Domain Acc {domain_acc.val:.10f} ({domain_acc.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n
                    epoch, i, num_iterations,\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n
                    batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans,\n                    domain_acc=domain_accs, top1=top1))\n\n\n
if __name__ == '__main__':\n    model_names = sorted(name for name in models.__dict__\n                         if name.islower() and not name.startswith(\"__\")\n                         and callable(models.__dict__[name]))\n\n
    parser = argparse.ArgumentParser(description='CDAN')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: fmow)')\n
    parser.add_argument('--unlabeled-list', nargs='+', default=[\"test_unlabeled\", ])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default=\"acc_worst_region\")\n
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',\n                        help='Image patch size (default: 224 224)')\n
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,\n                        metavar='N', help='Input image center crop percent (for validation only)')\n
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n
    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". (default: None)')\n
    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet50)')\n
    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n
    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--bottleneck-dim', default=512, type=int,\n                        help='Dimension of bottleneck')\n
    parser.add_argument('-r', '--randomized', action='store_true',\n                        help='using randomized multi-linear-map (default: False)')\n
    parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,\n                        help='randomized dimension when using randomized multi-linear-map (default: 1024)')\n
    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n
    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR',\n                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '\n                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')\n
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')\n
    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n
    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n
    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n
    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n
    parser.add_argument('--sync-bn', action='store_true',\n                        help='enabling apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n
    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', type=bool, default=False)\n
    parser.add_argument(\"--log\", type=str, default='cdan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n
    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n\n
    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/cdan.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/cdan/fmow/lr_0_1_aa_v0_densenet121\n\nCUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d \"iwildcam\" --aa \"v0\" --unlabeled-list \"extra_unlabeled\" --lr 1 --opt-level O1 \\\n  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric \"F1-macro_all\" \\\n  --log logs/cdan/iwildcam/lr_1_deterministic\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/dan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\nimport math\nfrom itertools import cycle\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nimport torchvision.models as models\nfrom torch.utils.tensorboard import SummaryWriter\nfrom timm.loss.cross_entropy import LabelSmoothingCrossEntropy\nimport wilds\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom tllib.alignment.dan import MultipleKernelMaximumMeanDiscrepancy, ImageClassifier as Classifier\nfrom tllib.modules.kernels import GaussianKernel\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    train_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        interpolation=args.interpolation,\n    )\n    val_transform = utils.get_val_transform(\n        img_size=args.img_size,\n        crop_pct=args.crop_pct,\n        interpolation=args.interpolation,\n    )\n    if args.local_rank == 0:\n        print(\"train_transform: \", train_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        utils.get_dataset(args.data, 
args.data_dir, args.unlabeled_list, args.test_list,\n                          train_transform, val_transform, verbose=args.local_rank == 0)\n\n    # create model\n    if args.local_rank == 0:\n        if not args.scratch:\n            print(\"=> using pre-trained model '{}'\".format(args.arch))\n        else:\n            print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,\n                       pool_layer=pool_layer, finetune=not args.scratch)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.\n    optimizer = torch.optim.SGD(\n        model.get_parameters(), args.lr, momentum=args.momentum,\n        weight_decay=args.weight_decay, nesterov=True)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    model, optimizer = amp.initialize(model, optimizer,\n                                      opt_level=args.opt_level,\n                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                      loss_scale=args.loss_scale\n                                      )\n\n    # Use cosine annealing learning rate strategy\n    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optimizer,\n        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)\n    )\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... 
= amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n
        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n\n
    # define loss function (criterion)\n    if args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()\n    else:\n        criterion = nn.CrossEntropyLoss().cuda()\n\n
    # Data loading code\n    train_labeled_sampler = None\n    train_unlabeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)\n\n
    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)\n
    train_unlabeled_loader = DataLoader(\n        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)\n\n
    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n
        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n
    # define loss function\n    mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy(\n        kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],\n        linear=not args.non_linear\n    )\n\n
    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n            train_unlabeled_sampler.set_epoch(epoch)\n\n
        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n
        # train for one epoch\n        train(train_labeled_loader, train_unlabeled_loader, model, criterion, mkmmd_loss, optimizer, epoch, writer,\n              args)\n\n
        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            prec1 = utils.validate(d, model, epoch, writer, args)\n\n
        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n\n
def train(train_labeled_loader, train_unlabeled_loader, model, criterion, mkmmd_loss, optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', ':3.2f')\n    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n
    # switch to train mode\n    model.train()\n    end = time.time()\n\n
    # One epoch covers min(len(train_labeled_loader), len(train_unlabeled_loader)) batches:\n    # zip stops at the shortest iterable, and cycling the unlabeled loader keeps a fresh\n    # target batch available at every step.\n
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))\n\n
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \\\n            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):\n\n
        # compute output\n        n_s, n_t = len(input_s), len(input_t)\n        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)\n        output, feature = model(input)\n        output_s, output_t = output.split([n_s, n_t], dim=0)\n        feature_s, feature_t = feature.split([n_s, n_t], dim=0)\n
        loss_s = criterion(output_s, target_s.cuda())\n        loss_trans = mkmmd_loss(feature_s, feature_t)\n        loss = loss_s + loss_trans * args.trade_off\n\n
        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n
        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n
            # Measure accuracy\n            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))\n\n
            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)\n                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)\n                prec1 = utils.reduce_tensor(prec1, args.world_size)\n            else:\n                reduced_loss_s = loss_s.data\n                reduced_loss_trans = loss_trans.data\n\n
            # to_python_float incurs a host<->device sync\n            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))\n            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))\n            top1.update(to_python_float(prec1), input_s.size(0))\n            global_step = epoch * num_iterations + i\n\n
            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n
            if args.local_rank == 0:\n                writer.add_scalar('train/top1', to_python_float(prec1), global_step)\n                writer.add_scalar(\"train/loss (s)\", to_python_float(reduced_loss_s), global_step)\n                writer.add_scalar(\"train/loss (trans)\", to_python_float(reduced_loss_trans), global_step)\n                writer.add_figure('train/predictions vs. 
actuals',\n                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,\n                                                           metadata_s, train_labeled_loader.dataset.metadata_map),\n                                  global_step=global_step)\n\n
                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\\t'\n                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n
                    epoch, i, num_iterations,\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n
                    batch_time=batch_time,\n                    loss_s=losses_s, loss_trans=losses_trans, top1=top1))\n\n\n
if __name__ == '__main__':\n    model_names = sorted(name for name in models.__dict__\n                         if name.islower() and not name.startswith(\"__\")\n                         and callable(models.__dict__[name]))\n\n
    parser = argparse.ArgumentParser(description='DAN')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: fmow)')\n
    parser.add_argument('--unlabeled-list', nargs='+', default=[\"test_unlabeled\", ])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default=\"acc_worst_region\")\n
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',\n                        help='Image patch size (default: 224 224)')\n
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,\n                        metavar='N', help='Input image center crop percent (for validation only)')\n
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n
    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". (default: None)')\n
    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet50)')\n
    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n
    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--bottleneck-dim', default=512, type=int,\n                        help='Dimension of bottleneck')\n
    parser.add_argument('--non-linear', default=False, action='store_true',\n                        help='whether to use the non-linear version of MK-MMD (default: linear)')\n
    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n
    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR',\n                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '\n                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')\n
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')\n
    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n
    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n
    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n
    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n
    parser.add_argument('--sync-bn', action='store_true',\n                        help='enabling apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n
    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', type=bool, default=False)\n
    parser.add_argument(\"--log\", type=str, default='dan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n
    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n\n
    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/dan.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dan/fmow/lr_0_1_aa_v0_densenet121\n\nCUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d \"iwildcam\" --aa \"v0\" --unlabeled-list \"extra_unlabeled\" --lr 0.3 --opt-level O1 \\\n  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric \"F1-macro_all\" \\\n  --log logs/dan/iwildcam/lr_0_3_deterministic\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/dann.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\nimport math\nfrom itertools import cycle\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nimport torchvision.models as models\nfrom torch.utils.tensorboard import SummaryWriter\nfrom timm.loss.cross_entropy import LabelSmoothingCrossEntropy\nimport wilds\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier as Classifier\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    train_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        interpolation=args.interpolation,\n    )\n    val_transform = utils.get_val_transform(\n        img_size=args.img_size,\n        crop_pct=args.crop_pct,\n        interpolation=args.interpolation,\n    )\n    if args.local_rank == 0:\n        print(\"train_transform: \", train_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        utils.get_dataset(args.data, 
args.data_dir, args.unlabeled_list, args.test_list,\n                          train_transform, val_transform, verbose=args.local_rank == 0)\n\n    # create model\n    if args.local_rank == 0:\n        if not args.scratch:\n            print(\"=> using pre-trained model '{}'\".format(args.arch))\n        else:\n            print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,\n                       pool_layer=pool_layer, finetune=not args.scratch)\n    features_dim = model.features_dim\n    domain_discri = DomainDiscriminator(features_dim, hidden_size=1024, sigmoid=False)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n    domain_discri = domain_discri.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.\n    optimizer = torch.optim.SGD(\n        model.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum,\n        weight_decay=args.weight_decay, nesterov=True)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    (model, domain_discri), optimizer = amp.initialize([model, domain_discri], optimizer,\n                                                       opt_level=args.opt_level,\n                                                       keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                                       loss_scale=args.loss_scale\n                                                       )\n\n    # Use cosine annealing learning rate strategy\n    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optimizer,\n        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)\n    )\n\n    # define loss function\n    domain_adv = DomainAdversarialLoss(domain_discri, sigmoid=False)\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... 
= amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n        domain_adv = DDP(domain_adv, delay_allreduce=True)\n\n    # define loss function (criterion)\n    if args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()\n    else:\n        criterion = nn.CrossEntropyLoss().cuda()\n\n    # Data loading code\n    train_labeled_sampler = None\n    train_unlabeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)\n    train_unlabeled_loader = DataLoader(\n        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n            train_unlabeled_sampler.set_epoch(epoch)\n\n        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer, epoch, writer,\n              args)\n\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            prec1 = utils.validate(d, model, epoch, writer, args)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n\ndef train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv,\n          optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', ':3.2f')\n    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')\n    domain_accs = AverageMeter('Domain Acc', ':3.1f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n   
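 # One epoch covers min(len(train_labeled_loader), len(train_unlabeled_loader)) batches:\n    # zip stops at the shortest iterable, and cycling the unlabeled loader keeps a fresh\n    # target batch available at every step.\n   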
 num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))\n\n    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \\\n            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):\n\n        # compute output\n        n_s, n_t = len(input_s), len(input_t)\n        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)\n        output, feature = model(input)\n        output_s, output_t = output.split([n_s, n_t], dim=0)\n        feature_s, feature_t = feature.split([n_s, n_t], dim=0)\n        loss_s = criterion(output_s, target_s.cuda())\n        loss_trans = domain_adv(feature_s, feature_t)\n        loss = loss_s + loss_trans * args.trade_off\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)\n                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)\n                prec1 = utils.reduce_tensor(prec1, args.world_size)\n                domain_acc = domain_adv.module.domain_discriminator_accuracy\n            else:\n                reduced_loss_s = loss_s.data\n                reduced_loss_trans = loss_trans.data\n                domain_acc = domain_adv.domain_discriminator_accuracy\n\n            # to_python_float incurs a host<->device sync\n            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))\n            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))\n            domain_accs.update(to_python_float(domain_acc), input_s.size(0))\n            top1.update(to_python_float(prec1), input_s.size(0))\n            global_step = epoch * num_iterations + i\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                writer.add_scalar('train/top1', to_python_float(prec1), global_step)\n                writer.add_scalar(\"train/loss (s)\", to_python_float(reduced_loss_s), global_step)\n                writer.add_scalar(\"train/loss (trans)\", to_python_float(reduced_loss_trans), global_step)\n                writer.add_figure('train/predictions vs. 
actuals',\n                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,\n                                                           metadata_s, train_labeled_loader.dataset.metadata_map),\n                                  global_step=global_step)\n\n
                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\\t'\n                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\\t'\n
                      'Domain Acc {domain_acc.val:.10f} ({domain_acc.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n
                    epoch, i, num_iterations,\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n
                    batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans,\n                    domain_acc=domain_accs, top1=top1))\n\n\n
if __name__ == '__main__':\n    model_names = sorted(name for name in models.__dict__\n                         if name.islower() and not name.startswith(\"__\")\n                         and callable(models.__dict__[name]))\n\n
    parser = argparse.ArgumentParser(description='DANN')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: fmow)')\n
    parser.add_argument('--unlabeled-list', nargs='+', default=[\"test_unlabeled\", ])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default=\"acc_worst_region\")\n
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',\n                        help='Image patch size (default: 224 224)')\n
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,\n                        metavar='N', help='Input image center crop percent (for validation only)')\n
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n
    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". (default: None)')\n
    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet50)')\n
    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n
    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--bottleneck-dim', default=512, type=int,\n                        help='Dimension of bottleneck')\n
    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n
    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR',\n                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '\n                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')\n
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')\n
    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n
    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n
    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n
    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n
    parser.add_argument('--sync-bn', action='store_true',\n                        help='enabling apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n
    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', type=bool, default=False)\n
    parser.add_argument(\"--log\", type=str, default='dann',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n
    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n\n
    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/dann.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dann/fmow/lr_0_1_aa_v0_densenet121\n\nCUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d \"iwildcam\" --aa \"v0\" --unlabeled-list \"extra_unlabeled\" --lr 1 --opt-level O1 \\\n  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric \"F1-macro_all\" \\\n  --log logs/dann/iwildcam/lr_1_deterministic\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/erm.py",
    "content": "\"\"\"\nAdapted from https://github.com/NVIDIA/apex/tree/master/examples\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nimport torchvision.models as models\nfrom torch.utils.tensorboard import SummaryWriter\nfrom timm.loss.cross_entropy import LabelSmoothingCrossEntropy\nimport wilds\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    train_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        interpolation=args.interpolation,\n    )\n    val_transform = utils.get_val_transform(\n        img_size=args.img_size,\n        crop_pct=args.crop_pct,\n        interpolation=args.interpolation,\n    )\n    if args.local_rank == 0:\n        print(\"train_transform: \", train_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,\n                      
    train_transform, val_transform, verbose=args.local_rank == 0)\n\n    # create model\n    if args.local_rank == 0:\n        if not args.scratch:\n            print(\"=> using pre-trained model '{}'\".format(args.arch))\n        else:\n            print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = Classifier(backbone, args.num_classes, pool_layer=pool_layer, finetune=not args.scratch)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.\n    optimizer = torch.optim.SGD(\n        model.get_parameters(), args.lr, momentum=args.momentum,\n        weight_decay=args.weight_decay, nesterov=True)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    model, optimizer = amp.initialize(model, optimizer,\n                                      opt_level=args.opt_level,\n                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                      loss_scale=args.loss_scale\n                                      )\n\n    # Use cosine annealing learning rate strategy\n    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optimizer,\n        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)\n    )\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... 
= amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n\n    # define loss function (criterion)\n    if args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()\n    else:\n        criterion = nn.CrossEntropyLoss().cuda()\n\n    # Data loading code\n    train_labeled_sampler = None\n    train_unlabeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler)\n    train_unlabeled_loader = DataLoader(\n        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler)\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n            train_unlabeled_sampler.set_epoch(epoch)\n\n        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)\n\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            prec1 = utils.validate(d, model, epoch, writer, args)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n\ndef train(train_loader, model, criterion, optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(train_loader):\n\n        # compute output\n        output, _ = model(input.cuda())\n        loss = criterion(output, target.cuda())\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as 
scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, = accuracy(output.data, target.cuda(), topk=(1,))\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)\n                prec1 = utils.reduce_tensor(prec1, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            # to_python_float incurs a host<->device sync\n            losses.update(to_python_float(reduced_loss), input.size(0))\n            top1.update(to_python_float(prec1), input.size(0))\n            global_step = epoch * len(train_loader) + i\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                writer.add_scalar('train/top1', to_python_float(prec1), global_step)\n                writer.add_scalar(\"train/loss\", to_python_float(reduced_loss), global_step)\n                writer.add_figure('train/predictions vs. actuals',\n                                  utils.plot_classes_preds(input.cpu(), target, output.cpu(), args.class_names,\n                                                           metadata, train_loader.dataset.metadata_map),\n                                  global_step=global_step)\n\n                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n                      'Loss {loss.val:.10f} ({loss.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n                    epoch, i, len(train_loader),\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n                    batch_time=batch_time,\n                    loss=losses, top1=top1))\n\n\nif __name__ == '__main__':\n    model_names = sorted(name for name in models.__dict__\n                         if name.islower() and not name.startswith(\"__\")\n                         and callable(models.__dict__[name]))\n\n    parser = argparse.ArgumentParser(description='Src Only')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: fmow)')\n    parser.add_argument('--unlabeled-list', nargs='+', default=[\"test_unlabeled\", ])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default=\"acc_worst_region\")\n    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',\n                        help='Image patch size (default: None => model default)')\n    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,\n                        
metavar='N', help='Input image center crop percent (for validation only)')\n    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n    parser.add_argument('--aa', type=str, default=None, metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". (default: None)')\n    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR',\n                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '\n                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')\n    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n    parser.add_argument('--sync-bn', action='store_true',\n                        help='enabling apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', type=bool, default=False)\n    parser.add_argument(\"--log\", type=str, default='src_only',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model.\"\n                             \"When phase is 'analysis', only analysis the model.\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
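  {
    "path": "examples/domain_adaptation/wilds_image_classification/sketches/amp_ddp_order.py",
    "content": "\"\"\"\nEditor's sketch, not part of the original repository: a minimal, hedged\nillustration of the initialization order that the comments in erm.py (and the\nother scripts here) insist on. apex's amp.initialize may replace model\nparameters with casted copies, so it must run BEFORE\napex.parallel.DistributedDataParallel attaches its allreduce hooks. The file\npath and function name are hypothetical; the apex calls mirror those used above.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom apex import amp\nfrom apex.parallel import DistributedDataParallel as DDP\n\n\ndef build_amp_ddp_model(model: nn.Module, lr: float, opt_level: str = 'O1', distributed: bool = True):\n    model = model.cuda()\n    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, nesterov=True)\n    # 1) amp first: it may cast parameters (e.g. to fp16 under opt_level O2).\n    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)\n    # 2) DDP second: its hooks then attach to the parameters amp actually uses.\n    if distributed:\n        model = DDP(model, delay_allreduce=True)\n    return model, optimizer\n"
  },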
  {
    "path": "examples/domain_adaptation/wilds_image_classification/erm.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121\n\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d \"iwildcam\" --aa \"v0\" --unlabeled-list \"extra_unlabeled\" --lr 1 --opt-level O1 \\\n  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 -p 500 --metric \"F1-macro_all\" \\\n  --log logs/erm/iwildcam/lr_1_deterministic\n"
  },
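  {
    "path": "examples/domain_adaptation/wilds_image_classification/sketches/lr_schedule.py",
    "content": "\"\"\"\nEditor's sketch, not part of the original repository: the learning-rate recipe\nshared by the scripts above, written out standalone. It illustrates two points:\n(1) the base lr is scaled linearly with the global batch size (reference batch\n256), and (2) torch.optim.lr_scheduler.LambdaLR multiplies the optimizer's base\nlr by the lambda's return value, so the lambda should return a dimensionless\nfactor. All names and numbers below are illustrative assumptions.\n\"\"\"\nimport math\n\nimport torch\nimport torch.nn as nn\n\nepochs, base_lr, min_lr = 60, 0.1, 1e-6\nbatch_size, world_size = 64, 1\n\n# Linear scaling rule, as in the scripts above.\nlr = base_lr * float(batch_size * world_size) / 256.\n\nmodel = nn.Linear(16, 4)\noptimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9,\n                            weight_decay=1e-4, nesterov=True)\n\n# Half-cosine factor decaying from 1 to min_lr / lr over `epochs` epochs.\nlr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n    optimizer,\n    lambda x: max(math.cos(float(x) / epochs * math.pi) * 0.5 + 0.5, min_lr / lr))\n\nfor epoch in range(epochs):\n    # ... train for one epoch, then advance the schedule ...\n    lr_scheduler.step()\n"
  },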
  {
    "path": "examples/domain_adaptation/wilds_image_classification/fixmatch.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\nimport math\nfrom itertools import cycle\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nimport torchvision.models as models\nfrom torch.utils.tensorboard import SummaryWriter\nfrom timm.loss.cross_entropy import LabelSmoothingCrossEntropy\nimport wilds\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\nclass ImageClassifier(Classifier):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=512, **kwargs):\n        bottleneck = nn.Sequential(\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\"\"\"\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        predictions = self.head(f)\n        return predictions\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    weak_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n   
     vflip=args.vflip,\n        color_jitter=None,\n        auto_augment=None,\n        interpolation=args.interpolation,\n    )\n    strong_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        interpolation=args.interpolation,\n    )\n    train_source_transform = strong_transform\n    train_target_transform = MultipleApply([weak_transform, strong_transform])\n    val_transform = utils.get_val_transform(\n        img_size=args.img_size,\n        crop_pct=args.crop_pct,\n        interpolation=args.interpolation,\n    )\n    if args.local_rank == 0:\n        print(\"train_source_transform: \", train_source_transform)\n        print('train_target_transform: ', train_target_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,\n                          train_source_transform, val_transform, verbose=args.local_rank == 0,\n                          transform_train_target=train_target_transform)\n\n    # create model\n    if args.local_rank == 0:\n        if not args.scratch:\n            print(\"=> using pre-trained model '{}'\".format(args.arch))\n        else:\n            print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = ImageClassifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,\n                            pool_layer=pool_layer, finetune=not args.scratch)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.\n    optimizer = torch.optim.SGD(\n        model.get_parameters(), args.lr, momentum=args.momentum,\n        weight_decay=args.weight_decay, nesterov=True)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    model, optimizer = amp.initialize(model, optimizer,\n                                      opt_level=args.opt_level,\n                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                      loss_scale=args.loss_scale\n                                      )\n\n    # Use cosine annealing learning rate strategy\n    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optimizer,\n        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)\n    )\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... 
= amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n\n    # define loss function (criterion)\n    if args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()\n    else:\n        criterion = nn.CrossEntropyLoss().cuda()\n\n    # Data loading code\n    train_labeled_sampler = None\n    train_unlabeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)\n    train_unlabeled_loader = DataLoader(\n        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n            train_unlabeled_sampler.set_epoch(epoch)\n\n        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, train_unlabeled_loader, model, criterion, optimizer, epoch, writer, args)\n\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            prec1 = utils.validate(d, model, epoch, writer, args)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n\ndef train(train_labeled_loader, train_unlabeled_loader, model, criterion, optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', ':3.2f')\n    losses_self_training = AverageMeter('Loss (self training)', ':3.2f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))\n\n    for i, (input_s, target_s, metadata_s), ((input_t, 
input_t_strong), metadata_t) in \\\n            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):\n\n        # compute output\n        n_s, n_t = len(input_s), len(input_t)\n\n        with torch.no_grad():\n            output_t = model(input_t.cuda())\n            confidence, pseudo_labels = F.softmax(output_t, dim=1).max(dim=1)\n            mask = (confidence > args.threshold).float()\n        input = torch.cat([input_s.cuda(), input_t_strong.cuda()], dim=0)\n        output = model(input)\n        output_s, output_t_strong = output.split([n_s, n_t], dim=0)\n\n        loss_s = criterion(output_s, target_s.cuda())\n        loss_self_training = args.trade_off * \\\n                             (F.cross_entropy(output_t_strong, pseudo_labels, reduction='none') * mask).mean()\n        loss = loss_s + loss_self_training\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)\n                reduced_loss_self_training = utils.reduce_tensor(loss_self_training.data, args.world_size)\n                prec1 = utils.reduce_tensor(prec1, args.world_size)\n            else:\n                reduced_loss_s = loss_s.data\n                reduced_loss_self_training = loss_self_training.data\n\n            # to_python_float incurs a host<->device sync\n            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))\n            losses_self_training.update(to_python_float(reduced_loss_self_training), input_s.size(0))\n            top1.update(to_python_float(prec1), input_s.size(0))\n            global_step = epoch * num_iterations + i\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                writer.add_scalar('train/top1', to_python_float(prec1), global_step)\n                writer.add_scalar(\"train/loss (s)\", to_python_float(reduced_loss_s), global_step)\n                writer.add_scalar(\"train/loss (self training)\", to_python_float(reduced_loss_self_training),\n                                  global_step)\n                writer.add_figure('train/predictions vs. 
actuals',\n                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,\n                                                           metadata_s, train_labeled_loader.dataset.metadata_map),\n                                  global_step=global_step)\n\n                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\\t'\n                      'Loss (self training) {loss_self_training.val:.10f} ({loss_self_training.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n                    epoch, i, len(train_labeled_loader),\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n                    batch_time=batch_time, loss_s=losses_s, loss_self_training=losses_self_training,\n                    top1=top1))\n\n\nif __name__ == '__main__':\n    model_names = sorted(name for name in models.__dict__\n                         if name.islower() and not name.startswith(\"__\")\n                         and callable(models.__dict__[name]))\n\n    parser = argparse.ArgumentParser(description='FixMatch')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: fmow)')\n    parser.add_argument('--unlabeled-list', nargs='+', default=[\"test_unlabeled\", ])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default=\"acc_worst_region\")\n    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',\n                        help='Image patch size (default: None => model default)')\n    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,\n                        metavar='N', help='Input image center crop percent (for validation only)')\n    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n    parser.add_argument('--aa', type=str, default=None, metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)')\n    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--bottleneck-dim', default=512, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    parser.add_argument('--threshold', default=0.7, type=float,\n                        help='confidence threshold')\n    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR',\n                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '\n                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')\n    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n    parser.add_argument('--sync-bn', action='store_true',\n                        help='enabling apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', type=bool, default=False)\n    parser.add_argument(\"--log\", type=str, default='fixmatch',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model.\"\n                             \"When phase is 'analysis', only analysis the model.\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
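  {
    "path": "examples/domain_adaptation/wilds_image_classification/sketches/reduce_tensor.py",
    "content": "\"\"\"\nEditor's sketch, not part of the original repository: utils.reduce_tensor is\ncalled throughout these scripts to average logged metrics across processes,\nbut utils.py is not shown in this section. This is the standard apex/imagenet\npattern it presumably follows; treat it as an assumption, not the repo's\nactual implementation.\n\"\"\"\nimport torch\nimport torch.distributed as dist\n\n\ndef reduce_tensor(tensor: torch.Tensor, world_size: int) -> torch.Tensor:\n    # Sum the metric over all processes, then divide to get the mean.\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= world_size\n    return rt\n"
  },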
  {
    "path": "examples/domain_adaptation/wilds_image_classification/fixmatch.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/fixmatch/fmow/lr_0_1_aa_v0_densenet121\n\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d \"iwildcam\" --aa \"v0\" --unlabeled-list \"extra_unlabeled\" \\\n  --lr 0.3 --opt-level O1 --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 12 -b 24 24 -p 500 \\\n  --metric \"F1-macro_all\" --log logs/fixmatch/iwildcam/lr_0_3_deterministic\n"
  },
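  {
    "path": "examples/domain_adaptation/wilds_image_classification/sketches/fixmatch_loss.py",
    "content": "\"\"\"\nEditor's sketch, not part of the original repository: the confidence-masked\nself-training step at the core of fixmatch.py, isolated on a single batch.\nPseudo-labels come from the weakly augmented view (without gradient), and only\npredictions above `threshold` contribute to the loss on the strongly augmented\nview; fixmatch.py additionally scales this term by --trade-off. The helper\nname and toy model below are hypothetical.\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef fixmatch_loss(model: nn.Module, x_weak: torch.Tensor, x_strong: torch.Tensor,\n                  threshold: float = 0.7) -> torch.Tensor:\n    with torch.no_grad():\n        # Pseudo-label from the weak view; keep only confident predictions.\n        confidence, pseudo_labels = F.softmax(model(x_weak), dim=1).max(dim=1)\n        mask = (confidence > threshold).float()\n    # Per-sample CE on the strong view, zeroed where confidence is too low.\n    return (F.cross_entropy(model(x_strong), pseudo_labels, reduction='none') * mask).mean()\n\n\nif __name__ == '__main__':\n    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))\n    x_weak, x_strong = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)\n    print(fixmatch_loss(model, x_weak, x_strong))\n"
  },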
  {
    "path": "examples/domain_adaptation/wilds_image_classification/jan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\nimport math\nfrom itertools import cycle\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nimport torchvision.models as models\nfrom torch.utils.tensorboard import SummaryWriter\nfrom timm.loss.cross_entropy import LabelSmoothingCrossEntropy\nimport wilds\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom tllib.alignment.jan import JointMultipleKernelMaximumMeanDiscrepancy, ImageClassifier as Classifier\nfrom tllib.modules.kernels import GaussianKernel\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    train_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        interpolation=args.interpolation,\n    )\n    val_transform = utils.get_val_transform(\n        img_size=args.img_size,\n        crop_pct=args.crop_pct,\n        interpolation=args.interpolation,\n    )\n    if args.local_rank == 0:\n        print(\"train_transform: \", train_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        
utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,\n                          train_transform, val_transform, verbose=args.local_rank == 0)\n\n    # create model\n    if args.local_rank == 0:\n        if not args.scratch:\n            print(\"=> using pre-trained model '{}'\".format(args.arch))\n        else:\n            print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,\n                       pool_layer=pool_layer, finetune=not args.scratch)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.\n    optimizer = torch.optim.SGD(\n        model.get_parameters(), args.lr, momentum=args.momentum,\n        weight_decay=args.weight_decay, nesterov=True)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    model, optimizer = amp.initialize(model, optimizer,\n                                      opt_level=args.opt_level,\n                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                      loss_scale=args.loss_scale\n                                      )\n\n    # Use cosine annealing learning rate strategy\n    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optimizer,\n        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)\n    )\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... 
= amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n\n    # define loss function (criterion)\n    if args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()\n    else:\n        criterion = nn.CrossEntropyLoss().cuda()\n\n    # Data loading code\n    train_labeled_sampler = None\n    train_unlabeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)\n    train_unlabeled_loader = DataLoader(\n        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    # define loss function\n    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(\n        kernels=(\n            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],\n            (GaussianKernel(sigma=0.92, track_running_stats=False),)\n        ),\n        linear=args.linear\n    )\n\n    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n            train_unlabeled_sampler.set_epoch(epoch)\n\n        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, train_unlabeled_loader, model, criterion, jmmd_loss, optimizer, epoch, writer, args)\n\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            prec1 = utils.validate(d, model, epoch, writer, args)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n\ndef train(train_labeled_loader, train_unlabeled_loader, model, criterion, jmmd_loss, optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', ':3.2f')\n    losses_trans = 
AverageMeter('Loss (transfer)', ':3.2f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))\n\n    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \\\n            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):\n\n        # compute output\n        n_s, n_t = len(input_s), len(input_t)\n        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)\n        output, feature = model(input)\n        output_s, output_t = output.split([n_s, n_t], dim=0)\n        feature_s, feature_t = feature.split([n_s, n_t], dim=0)\n        loss_s = criterion(output_s, target_s.cuda())\n        loss_trans = jmmd_loss(\n            (feature_s, F.softmax(output_s, dim=1)),\n            (feature_t, F.softmax(output_t, dim=1))\n        )\n        loss = loss_s + loss_trans * args.trade_off\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)\n                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)\n                prec1 = utils.reduce_tensor(prec1, args.world_size)\n            else:\n                reduced_loss_s = loss_s.data\n                reduced_loss_trans = loss_trans.data\n\n            # to_python_float incurs a host<->device sync\n            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))\n            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))\n            top1.update(to_python_float(prec1), input_s.size(0))\n            global_step = epoch * num_iterations + i\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                writer.add_scalar('train/top1', to_python_float(prec1), global_step)\n                writer.add_scalar(\"train/loss (s)\", to_python_float(reduced_loss_s), global_step)\n                writer.add_scalar(\"train/loss (trans)\", to_python_float(reduced_loss_trans), global_step)\n                writer.add_figure('train/predictions vs. 
actuals',\n                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,\n                                                           metadata_s, train_labeled_loader.dataset.metadata_map),\n                                  global_step=global_step)\n\n                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\\t'\n                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n                    epoch, i, len(train_labeled_loader),\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n                    batch_time=batch_time,\n                    loss_s=losses_s, loss_trans=losses_trans, top1=top1))\n\n\nif __name__ == '__main__':\n    model_names = sorted(name for name in models.__dict__\n                         if name.islower() and not name.startswith(\"__\")\n                         and callable(models.__dict__[name]))\n    parser = argparse.ArgumentParser(description='JAN')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: fmow)')\n    parser.add_argument('--unlabeled-list', nargs='+', default=[\"test_unlabeled\", ])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default=\"acc_worst_region\")\n    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',\n                        help='Image patch size (default: None => model default)')\n    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,\n                        metavar='N', help='Input image center crop percent (for validation only)')\n    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n    parser.add_argument('--aa', type=str, default=None, metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)')\n    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')\n    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--bottleneck-dim', default=512, type=int,\n                        help='Dimension of bottleneck')\n    parser.add_argument('--linear', default=False, action='store_true',\n                        help='whether to use the linear version')\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR',\n                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '\n                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')\n    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n    parser.add_argument('--sync-bn', action='store_true',\n                        help='enabling apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', type=bool, default=False)\n    parser.add_argument(\"--log\", type=str, default='jan',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model.\"\n                             \"When phase is 'analysis', only analysis the model.\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/jan.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/jan/fmow/lr_0_1_aa_v0_densenet121\n\nCUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d \"iwildcam\" --aa \"v0\" --unlabeled-list \"extra_unlabeled\" --lr 0.3 --opt-level O1 \\\n  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric \"F1-macro_all\" \\\n  --log logs/jan/iwildcam/lr_0_3_deterministic\n"
  },
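  {
    "path": "examples/domain_adaptation/wilds_image_classification/sketches/jmmd_usage.py",
    "content": "\"\"\"\nEditor's sketch, not part of the original repository: a minimal usage example\nof the JMMD criterion exactly as jan.py constructs it, run on random CPU\ntensors. Batch sizes on the two domains must match (jan.py guarantees this via\ndrop_last=True); the feature width and class count below are illustrative\nassumptions.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\nfrom tllib.alignment.jan import JointMultipleKernelMaximumMeanDiscrepancy\nfrom tllib.modules.kernels import GaussianKernel\n\n# A family of Gaussian kernels on the bottleneck features, and one fixed-width\n# kernel on the (softmax) predictions, mirroring jan.py.\njmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(\n    kernels=(\n        [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],\n        (GaussianKernel(sigma=0.92, track_running_stats=False),)\n    ),\n    linear=False\n)\n\nf_s, f_t = torch.randn(24, 256), torch.randn(24, 256)  # bottleneck features\ny_s, y_t = torch.randn(24, 62), torch.randn(24, 62)    # class logits\nloss = jmmd_loss((f_s, F.softmax(y_s, dim=1)), (f_t, F.softmax(y_t, dim=1)))\nprint(loss)\n"
  },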
  {
    "path": "examples/domain_adaptation/wilds_image_classification/mdd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\nimport math\nfrom itertools import cycle\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim\nimport torch.utils.data\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nimport torchvision.models as models\nfrom torch.utils.tensorboard import SummaryWriter\nfrom timm.loss.cross_entropy import LabelSmoothingCrossEntropy\nimport wilds\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp, optimizers\n    from apex.multi_tensor_apply import multi_tensor_applier\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom tllib.alignment.mdd import ClassificationMarginDisparityDiscrepancy \\\n    as MarginDisparityDiscrepancy, ImageClassifier as Classifier\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    best_prec1 = 0\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    train_transform = utils.get_train_transform(\n        img_size=args.img_size,\n        scale=args.scale,\n        ratio=args.ratio,\n        hflip=args.hflip,\n        vflip=args.vflip,\n        color_jitter=args.color_jitter,\n        auto_augment=args.aa,\n        interpolation=args.interpolation,\n    )\n    val_transform = utils.get_val_transform(\n        img_size=args.img_size,\n        crop_pct=args.crop_pct,\n        interpolation=args.interpolation,\n    )\n    if args.local_rank == 0:\n        print(\"train_transform: \", train_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.data_dir, 
args.unlabeled_list, args.test_list,\n                          train_transform, val_transform, verbose=args.local_rank == 0)\n\n    # create model\n    if args.local_rank == 0:\n        if not args.scratch:\n            print(\"=> using pre-trained model '{}'\".format(args.arch))\n        else:\n            print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrain=not args.scratch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,\n                       width=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch)\n    mdd = MarginDisparityDiscrepancy(args.margin)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n\n    model = model.cuda().to(memory_format=memory_format)\n\n    # Scale learning rate based on global batch size\n    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.\n    optimizer = torch.optim.SGD(\n        model.get_parameters(), args.lr, momentum=args.momentum,\n        weight_decay=args.weight_decay, nesterov=True)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    model, optimizer = amp.initialize(model, optimizer,\n                                      opt_level=args.opt_level,\n                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                      loss_scale=args.loss_scale\n                                      )\n\n    # Use cosine annealing learning rate strategy\n    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optimizer,\n        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)\n    )\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... 
= amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n\n    # define loss function (criterion)\n    if args.smoothing:\n        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()\n    else:\n        criterion = nn.CrossEntropyLoss().cuda()\n\n    # Data loading code\n    train_labeled_sampler = None\n    train_unlabeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)\n    train_unlabeled_loader = DataLoader(\n        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n            train_unlabeled_sampler.set_epoch(epoch)\n\n        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, train_unlabeled_loader, model, criterion, mdd, optimizer, epoch, writer, args)\n\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            prec1 = utils.validate(d, model, epoch, writer, args)\n\n        # remember best prec@1 and save checkpoint\n        if args.local_rank == 0:\n            is_best = prec1 > best_prec1\n            best_prec1 = max(prec1, best_prec1)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n\ndef train(train_labeled_loader, train_unlabeled_loader, model, criterion, mdd,\n          optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses_s = AverageMeter('Loss (s)', ':3.2f')\n    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))\n\n    for i, (input_s, target_s, metadata_s), (input_t, 
metadata_t) in \\\n            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):\n\n        # compute output\n        n_s, n_t = len(input_s), len(input_t)\n        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)\n        output, output_adv = model(input)\n        output_s, output_t = output.split([n_s, n_t], dim=0)\n        output_adv_s, output_adv_t = output_adv.split([n_s, n_t], dim=0)\n        loss_s = criterion(output_s, target_s.cuda())\n        loss_trans = -mdd(output_s, output_adv_s, output_t, output_adv_t)\n        loss = loss_s + loss_trans * args.trade_off\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n        if args.distributed:\n            model.module.step()\n        else:\n            model.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)\n                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)\n                prec1 = utils.reduce_tensor(prec1, args.world_size)\n            else:\n                reduced_loss_s = loss_s.data\n                reduced_loss_trans = loss_trans.data\n\n            # to_python_float incurs a host<->device sync\n            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))\n            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))\n            top1.update(to_python_float(prec1), input_s.size(0))\n            global_step = epoch * num_iterations + i\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                writer.add_scalar('train/top1', to_python_float(prec1), global_step)\n                writer.add_scalar(\"train/loss (s)\", to_python_float(reduced_loss_s), global_step)\n                writer.add_scalar(\"train/loss (trans)\", to_python_float(reduced_loss_trans), global_step)\n                writer.add_figure('train/predictions vs. 
actuals',\n                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,\n                                                           metadata_s, train_labeled_loader.dataset.metadata_map),\n                                  global_step=global_step)\n\n                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\\t'\n                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\\t'\n                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n                    epoch, i, len(train_labeled_loader),\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n                    batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans, top1=top1))\n\n\nif __name__ == '__main__':\n    model_names = sorted(name for name in models.__dict__\n                         if name.islower() and not name.startswith(\"__\")\n                         and callable(models.__dict__[name]))\n\n    parser = argparse.ArgumentParser(description='MDD')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: fmow)')\n    parser.add_argument('--unlabeled-list', nargs='+', default=[\"test_unlabeled\", ])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default=\"acc_worst_region\")\n    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',\n                        help='Image patch size (default: 224 224)')\n    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,\n                        metavar='N', help='Input image center crop percent (for validation only)')\n    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',\n                        help='Image resize interpolation type (overrides model)')\n    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',\n                        help='Random resize scale (default: 0.5 1.0)')\n    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',\n                        help='Random resize aspect ratio (default: 0.75 1.33)')\n    parser.add_argument('--hflip', type=float, default=0.5,\n                        help='Horizontal flip training aug probability')\n    parser.add_argument('--vflip', type=float, default=0.,\n                        help='Vertical flip training aug probability')\n    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',\n                        help='Color jitter factor (default: 0.4)')\n    parser.add_argument('--aa', type=str, default=None, metavar='NAME',\n                        help='Use AutoAugment policy. \"v0\" or \"original\". 
(default: None)')\n    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')\n    parser.add_argument('--smoothing', type=float, default=0.1,\n                        help='Label smoothing (default: 0.1)')\n    parser.add_argument('--bottleneck-dim', default=2048, type=int)\n    parser.add_argument('--margin', type=float, default=4., help=\"margin gamma\")\n    parser.add_argument('--trade-off', default=1., type=float,\n                        help='the trade-off hyper-parameter for transfer loss')\n    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR',\n                        help='Initial learning rate.  Will be scaled by <global batch size>/256: '\n                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')\n    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N',\n                        help='mini-batch size per process for source and target domain (default: (64, 64))')\n    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--local_rank\", default=os.getenv('LOCAL_RANK', 0), type=int)\n    parser.add_argument('--sync-bn', action='store_true',\n                        help='enable apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', action='store_true',\n                        help='use channels-last memory format.')\n    parser.add_argument(\"--log\", type=str, default='mdd',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/mdd.sh",
    "content": "CUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d \"fmow\" --aa \"v0\" --arch \"densenet121\" \\\n  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/mdd/fmow/lr_0_1_aa_v0_densenet121\n\nCUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d \"iwildcam\" --aa \"v0\" --unlabeled-list \"extra_unlabeled\" --lr 0.3 --opt-level O1 \\\n  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric \"F1-macro_all\" \\\n  --log logs/mdd/iwildcam/lr_0_3_deterministic\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/requirements.txt",
    "content": "wilds\ntimm\ntensorflow\ntensorboard\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_image_classification/utils.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport time\nimport math\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport sys\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.distributed as dist\nfrom torch.utils.data.distributed import DistributedSampler\nfrom torch.utils.data import DataLoader, ConcatDataset\nfrom torchvision import transforms\nfrom PIL import Image\n\nimport wilds\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT\nfrom timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform\nimport timm\n\nsys.path.append('../../..')\n\nfrom tllib.vision.transforms import Denormalize\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\ndef get_model_names():\n    return timm.list_models()\n\n\ndef get_model(model_name, pretrain=True):\n    # load models from pytorch-image-models\n    backbone = timm.create_model(model_name, pretrained=pretrain)\n    try:\n        backbone.out_features = backbone.get_classifier().in_features\n        backbone.reset_classifier(0, '')\n    except:\n        backbone.out_features = backbone.head.in_features\n        backbone.head = nn.Identity()\n    return backbone\n\n\ndef get_dataset(dataset_name, root, unlabeled_list=(\"test_unlabeled\",), test_list=(\"test\",),\n                transform_train=None, transform_test=None, verbose=True, transform_train_target=None):\n    if transform_train_target is None:\n        transform_train_target = transform_train\n    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True)\n    unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)\n    num_classes = labeled_dataset.n_classes\n    train_labeled_dataset = labeled_dataset.get_subset(\"train\", transform=transform_train)\n\n    train_unlabeled_datasets = [\n        unlabeled_dataset.get_subset(u, transform=transform_train_target)\n        for u in unlabeled_list\n    ]\n    train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)\n    test_datasets = [\n        labeled_dataset.get_subset(t, transform=transform_test)\n        for t in test_list\n    ]\n\n    if dataset_name == \"fmow\":\n        from wilds.datasets.fmow_dataset import categories\n        class_names = categories\n    else:\n        class_names = list(range(num_classes))\n\n    if verbose:\n        print(\"Datasets\")\n        for n, d in zip([\"train\"] + unlabeled_list + test_list,\n                        [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets):\n            print(\"\\t{}:{}\".format(n, len(d)))\n        print(\"\\t#classes:\", num_classes)\n\n    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_classes, class_names\n\n\ndef collate_list(vec):\n    \"\"\"\n    Adapted from https://github.com/p-lambda/wilds\n    If vec is a list of Tensors, it concatenates them all along the first dimension.\n\n    If vec is a list of lists, it joins these lists together, but does not attempt to\n    recursively collate. This allows each element of the list to be, e.g., its own dict.\n\n    If vec is a list of dicts (with the same keys in each dict), it returns a single dict\n    with the same keys. 
For each key, it recursively collates all entries in the list.\n    \"\"\"\n    if not isinstance(vec, list):\n        raise TypeError(\"collate_list must take in a list\")\n    elem = vec[0]\n    if torch.is_tensor(elem):\n        return torch.cat(vec)\n    elif isinstance(elem, list):\n        return [obj for sublist in vec for obj in sublist]\n    elif isinstance(elem, dict):\n        return {k: collate_list([d[k] for d in vec]) for k in elem}\n    else:\n        raise TypeError(\"Elements of the list to collate must be tensors or dicts.\")\n\n\ndef get_train_transform(img_size, scale=None, ratio=None, hflip=0.5, vflip=0.,\n                        color_jitter=0.4, auto_augment=None, interpolation='bilinear'):\n    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range\n    ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range\n    transforms_list = [\n        transforms.RandomResizedCrop(img_size, scale=scale, ratio=ratio, interpolation=_pil_interp(interpolation))]\n    if hflip > 0.:\n        transforms_list += [transforms.RandomHorizontalFlip(p=hflip)]\n    if vflip > 0.:\n        transforms_list += [transforms.RandomVerticalFlip(p=vflip)]\n\n    if auto_augment:\n        assert isinstance(auto_augment, str)\n        if isinstance(img_size, (tuple, list)):\n            img_size_min = min(img_size)\n        else:\n            img_size_min = img_size\n        aa_params = dict(\n            translate_const=int(img_size_min * 0.45),\n            img_mean=tuple([min(255, round(255 * x)) for x in IMAGENET_DEFAULT_MEAN]),\n        )\n        if interpolation and interpolation != 'random':\n            aa_params['interpolation'] = _pil_interp(interpolation)\n        if auto_augment.startswith('rand'):\n            transforms_list += [rand_augment_transform(auto_augment, aa_params)]\n        elif auto_augment.startswith('augmix'):\n            aa_params['translate_pct'] = 0.3\n            transforms_list += [augment_and_mix_transform(auto_augment, aa_params)]\n        else:\n            transforms_list += [auto_augment_transform(auto_augment, aa_params)]\n    elif color_jitter is not None:\n        # color jitter is enabled when not using AA\n        if isinstance(color_jitter, (list, tuple)):\n            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation\n            # or 4 if also augmenting hue\n            assert len(color_jitter) in (3, 4)\n        else:\n            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue\n            color_jitter = (float(color_jitter),) * 3\n        transforms_list += [transforms.ColorJitter(*color_jitter)]\n\n    transforms_list += [\n        transforms.ToTensor(),\n        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n    ]\n    return transforms.Compose(transforms_list)\n\n\ndef get_val_transform(img_size=224, crop_pct=None, interpolation='bilinear'):\n    crop_pct = crop_pct or DEFAULT_CROP_PCT\n\n    if isinstance(img_size, (tuple, list)):\n        assert len(img_size) == 2\n        if img_size[-1] == img_size[-2]:\n            # fall-back to older behaviour so Resize scales to shortest edge if target is square\n            scale_size = int(math.floor(img_size[0] / crop_pct))\n        else:\n            scale_size = tuple([int(x / crop_pct) for x in img_size])\n    else:\n        scale_size = int(math.floor(img_size / crop_pct))\n\n    return transforms.Compose([\n        transforms.Resize(scale_size, _pil_interp(interpolation)),\n  
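      # crop back to img_size: scale_size above equals img_size / crop_pct, so\n        # the center crop keeps roughly a crop_pct fraction of the resized image\n  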
      transforms.CenterCrop(img_size),\n        transforms.ToTensor(),\n        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n    ])\n\n\ndef _pil_interp(method):\n    if method == 'bicubic':\n        return Image.BICUBIC\n    elif method == 'lanczos':\n        return Image.LANCZOS\n    elif method == 'hamming':\n        return Image.HAMMING\n    else:\n        # default bilinear, do we want to allow nearest?\n        return Image.BILINEAR\n\n\ndef validate(val_dataset, model, epoch, writer, args):\n    val_sampler = None\n    if args.distributed:\n        val_sampler = DistributedSampler(val_dataset)\n\n    val_loader = DataLoader(\n        val_dataset, batch_size=args.batch_size[0], shuffle=False,\n        num_workers=args.workers, pin_memory=True, sampler=val_sampler)\n\n    all_y_true = []\n    all_y_pred = []\n    all_metadata = []\n\n    sampled_inputs = []\n    sampled_outputs = []\n    sampled_targets = []\n    sampled_metadata = []\n\n    batch_time = AverageMeter('Time', ':6.3f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(val_loader):\n        # compute output\n        with torch.no_grad():\n            output = model(input.cuda()).cpu()\n\n        all_y_true.append(target)\n        all_y_pred.append(output.argmax(1))\n        all_metadata.append(metadata)\n\n        sampled_inputs.append(input[0:1])\n        sampled_targets.append(target[0:1])\n        sampled_outputs.append(output[0:1])\n        sampled_metadata.append(metadata[0:1])\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if args.local_rank == 0 and i % args.print_freq == 0:\n            progress.display(i)\n\n    if args.local_rank == 0:\n        writer.add_figure(\n            'test/predictions vs. 
actuals',\n            plot_classes_preds(\n                collate_list(sampled_inputs),\n                collate_list(sampled_targets),\n                collate_list(sampled_outputs),\n                args.class_names,\n                collate_list(sampled_metadata),\n                val_dataset.metadata_map,\n                nrows=min(int(len(val_loader) / 4), 50)\n            ),\n            global_step=epoch\n        )\n\n        # evaluate\n        results = val_dataset.eval(\n            collate_list(all_y_pred),\n            collate_list(all_y_true),\n            collate_list(all_metadata)\n        )\n        print(results[1])\n\n        for k, v in results[0].items():\n            if v == 0 or \"Other\" in k:\n                continue\n            writer.add_scalar(\"test/{}\".format(k), v, global_step=epoch)\n\n        return results[0][args.metric]\n\n\ndef reduce_tensor(tensor, world_size):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= world_size\n    return rt\n\n\ndef matplotlib_imshow(img):\n    \"\"\"helper function to show an image\"\"\"\n    img = Denormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(img)\n    img = np.transpose(img.numpy(), (1, 2, 0))\n    plt.imshow(img)\n\n\ndef plot_classes_preds(images, labels, outputs, class_names, metadata, metadata_map, nrows=4):\n    '''\n    Generates a matplotlib Figure that shows, for each image in a batch, the\n    network's top prediction and its probability alongside the actual label,\n    colored green or red depending on whether the prediction was correct.\n    '''\n    # convert output probabilities to predicted class\n    _, preds_tensor = torch.max(outputs, 1)\n    preds = np.squeeze(preds_tensor.numpy())\n    probs = [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, outputs)]\n\n    # plot the images in the batch, along with predicted and true labels\n    fig = plt.figure(figsize=(12, nrows * 4))\n    domains = get_domain_names(metadata, metadata_map)\n    for idx in np.arange(min(nrows * 4, len(images))):\n        ax = fig.add_subplot(nrows, 4, idx + 1, xticks=[], yticks=[])\n        matplotlib_imshow(images[idx])\n        ax.set_title(\"{0}, {1:.1f}%\\n(label: {2}\\ndomain: {3})\".format(\n            class_names[preds[idx]],\n            probs[idx] * 100.0,\n            class_names[labels[idx]],\n            domains[idx],\n        ), color=(\"green\" if preds[idx] == labels[idx].item() else \"red\"))\n    return fig\n\n\ndef get_domain_names(metadata, metadata_map):\n    # metadata_map lookups are dataset-specific; raw domain ids are enough\n    # for the figure titles, so fall back to them here.\n    return get_domain_ids(metadata)\n\n\ndef get_domain_ids(metadata):\n    return [int(m[0]) for m in metadata]\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_ogb_molpcba/README.md",
    "content": "# Unsupervised Domain Adaptation for WILDS (Molecule classification)\n\n## Installation\n\nIt's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results.\n\nThen, you need to run\n\n```\npip install -r requirements.txt\n```\n\nAt last, you need to install torch_sparse following `https://github.com/rusty1s/pytorch_sparse`.\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [OGB-MolPCBA (WILDS)](https://wilds.stanford.edu/datasets/)\n\n## Supported Methods\n\nTODO\n\n## Usage\n\nThe shell files give all the training scripts we use, e.g.\n\n```\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \\\n  --seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic\n```\n\n## Results\n\n### Performance on WILDS-OGB-MolPCBA (GIN-virtual)\n\n| Methods | Val Avg Precision | Test Avg Precision | GPU Memory Usage(GB)|\n| --- | --- | --- | --- |\n| ERM | 29.0 | 28.0 | 17.8 |\n\n### Visualization\n\nWe use tensorboard to record the training process and visualize the outputs of the models.\n\n```\ntensorboard --logdir=logs\n```\n\n<img src=\"./fig/ogb-molpcba_train_loss.png\" width=\"300\"/>\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_ogb_molpcba/erm.py",
    "content": "\"\"\"\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport argparse\nimport shutil\nimport time\nimport pprint\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.data import DataLoader\nfrom torch.utils.tensorboard import SummaryWriter\nimport wilds\n\nimport utils\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\n\n\ndef main(args):\n    logger = CompleteLogger(args.log, args.phase)\n    writer = SummaryWriter(args.log)\n    pprint.pprint(args)\n\n    print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    # Data loading code\n    # There are no well-developed data augmentation techniques for molecular graphs.\n    train_transform = None\n    val_transform = None\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \\\n        utils.get_dataset('ogb-molpcba', args.data_dir, args.unlabeled_list, args.test_list,\n                          train_transform, val_transform, use_unlabeled=args.use_unlabeled, verbose=True)\n\n    # create model\n    print(\"=> creating model '{}'\".format(args.arch))\n    model = utils.get_model(args.arch, args.num_classes)\n    model = model.cuda().to()\n\n    optimizer = torch.optim.Adam(\n        filter(lambda p: p.requires_grad, model.parameters()),\n        lr=args.lr, weight_decay=args.weight_decay\n    )\n\n    # Data loading code\n    train_labeled_sampler = None\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler,\n        collate_fn=train_labeled_dataset.collate)\n\n    # define loss function (criterion)\n    criterion = utils.reduced_bce_logit_loss\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    # start training\n    best_val_metric = 0\n    test_metric = 0\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            print(n)\n            if n == 'val':\n                tmp_val_metric = utils.validate(d, model, epoch, writer, args)\n            elif n == 'test':\n                tmp_test_metric = utils.validate(d, model, epoch, writer, args)\n\n        # remember best mse and save checkpoint\n        is_best = tmp_val_metric > best_val_metric\n        best_val_metric = max(tmp_val_metric, best_val_metric)\n        torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n        if is_best:\n            test_metric = tmp_test_metric\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n    print(\"best val performance: 
{:.3f}\".format(best_val_metric))\n    print(\"test performance: {:.3f}\".format(test_metric))\n    logger.close()\n    writer.close()\n\n\ndef train(train_loader, model, criterion, optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(train_loader):\n\n        # compute output\n        output = model(input.cuda())\n        loss = criterion(output, target.cuda())\n\n        # compute gradient and do optimizer step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            losses.update(loss, input.size(0))\n            global_step = epoch * len(train_loader) + i\n\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            writer.add_scalar('train/loss', loss, global_step)\n\n            print('Epoch: [{0}][{1}/{2}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Speed {3:.3f} ({4:.3f})\\t'\n                  'Loss {loss.val:.10f} ({loss.avg:.4f})\\t'.format(\n                epoch, i, len(train_loader),\n                args.batch_size[0] / batch_time.val,\n                args.batch_size[0] / batch_time.avg,\n                batch_time=batch_time, loss=losses))\n\n\nif __name__ == '__main__':\n    model_names = utils.get_model_names()\n    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='ogb-molpcba', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: ogb-molpcba)')\n    parser.add_argument('--unlabeled-list', nargs='+', default=[])\n    parser.add_argument('--test-list', nargs='+', default=['val', 'test'])\n    parser.add_argument('--metric', default='ap',\n                        help='metric used to evaluate model performance. 
(default: average precision)')\n    parser.add_argument('--use-unlabeled', action='store_true',\n                        help='Whether to use unlabeled data for training.')\n    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='gin_virtual',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: gin_virtual)')\n    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='Learning rate')\n    parser.add_argument('--weight-decay', '--wd', default=0.0, type=float,\n                        metavar='W', help='weight decay (default: 0.0)')\n    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=200, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n    parser.add_argument('--print-freq', '-p', default=50, type=int,\n                        metavar='N', help='print frequency (default: 50)')\n    parser.add_argument('--deterministic', action='store_true')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument('--log', type=str, default='src_only',\n                        help='Where to save logs, checkpoints and debugging images.')\n    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_ogb_molpcba/erm.sh",
    "content": "# ogb-molpcba\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \\\n  --seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_ogb_molpcba/gin.py",
    "content": "\"\"\"\nAdapted from \"https://github.com/p-lambda/wilds\"\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport torch\nfrom torch_geometric.nn import MessagePassing\nfrom torch_geometric.nn import global_mean_pool, global_add_pool\nimport torch.nn.functional as F\n\nfrom ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder\n\n__all__ = ['gin_virtual']\n\n\nclass GINVirtual(torch.nn.Module):\n    \"\"\"\n    Graph Isomorphism Network augmented with virtual node for multi-task binary graph classification.\n\n    Args:\n        num_tasks (int): number of binary label tasks. default to 128 (number of tasks of ogbg-molpcba)\n        num_layers (int): number of message passing layers of GNN\n        emb_dim (int): dimensionality of hidden channels\n        dropout (float): dropout ratio applied to hidden channels\n\n    Inputs:\n        - batched Pytorch Geometric graph object\n\n    Outputs:\n        - prediction (tensor): float torch tensor of shape (num_graphs, num_tasks)\n    \"\"\"\n\n    def __init__(self, num_tasks=128, num_layers=5, emb_dim=300, dropout=0.5):\n        super(GINVirtual, self).__init__()\n\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.emb_dim = emb_dim\n        self.num_tasks = num_tasks\n        if num_tasks is None:\n            self.d_out = self.emb_dim\n        else:\n            self.d_out = self.num_tasks\n\n        if self.num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        # GNN to generate node embeddings\n        self.gnn_node = GINVirtualNode(num_layers, emb_dim, dropout=dropout)\n\n        # Pooling function to generate whole-graph embeddings\n        self.pool = global_mean_pool\n        if num_tasks is None:\n            self.graph_pred_linear = None\n        else:\n            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)\n\n    def forward(self, batched_data):\n        h_node = self.gnn_node(batched_data)\n\n        h_graph = self.pool(h_node, batched_data.batch)\n\n        if self.graph_pred_linear is None:\n            return h_graph\n        else:\n            return self.graph_pred_linear(h_graph)\n\n\nclass GINVirtualNode(torch.nn.Module):\n    \"\"\"\n    Helper function of Graph Isomorphism Network augmented with virtual node for multi-task binary graph classification\n    This will generate node embeddings.\n\n    Args:\n        num_layers (int): number of message passing layers of GNN\n        emb_dim (int): dimensionality of hidden channels\n        dropout (float, optional): dropout ratio applied to hidden channels. 
Default: 0.5\n\n    Inputs:\n        - batched Pytorch Geometric graph object\n    Outputs:\n        - node_embedding (tensor): float torch tensor of shape (num_nodes, emb_dim)\n    \"\"\"\n\n    def __init__(self, num_layers, emb_dim, dropout=0.5):\n        super(GINVirtualNode, self).__init__()\n        self.num_layers = num_layers\n        self.dropout = dropout\n\n        if self.num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        self.atom_encoder = AtomEncoder(emb_dim)\n\n        # set the initial virtual node embedding to 0.\n        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)\n        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)\n\n        # List of GNNs\n        self.convs = torch.nn.ModuleList()\n        # batch norms applied to node embeddings\n        self.batch_norms = torch.nn.ModuleList()\n\n        # List of MLPs to transform virtual node at every layer\n        self.mlp_virtualnode_list = torch.nn.ModuleList()\n\n        for layer in range(num_layers):\n            self.convs.append(GINConv(emb_dim))\n            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))\n\n        for layer in range(num_layers - 1):\n            self.mlp_virtualnode_list.append(\n                torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),\n                                    torch.nn.ReLU(),\n                                    torch.nn.Linear(2 * emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim),\n                                    torch.nn.ReLU()))\n\n    def forward(self, batched_data):\n        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch\n\n        # virtual node embeddings for graphs\n        virtualnode_embedding = self.virtualnode_embedding(\n            torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))\n\n        h_list = [self.atom_encoder(x)]\n        for layer in range(self.num_layers):\n            # add message from virtual nodes to graph nodes\n            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]\n\n            # Message passing among graph nodes\n            h = self.convs[layer](h_list[layer], edge_index, edge_attr)\n\n            h = self.batch_norms[layer](h)\n            if layer == self.num_layers - 1:\n                # remove relu for the last layer\n                h = F.dropout(h, self.dropout, training=self.training)\n            else:\n                h = F.dropout(F.relu(h), self.dropout, training=self.training)\n\n            h_list.append(h)\n\n            # update the virtual nodes\n            if layer < self.num_layers - 1:\n                # add message from graph nodes to virtual nodes\n                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding\n                # transform virtual nodes using MLP\n                virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),\n                                                  self.dropout, training=self.training)\n\n        node_embedding = h_list[-1]\n\n        return node_embedding\n\n\nclass GINConv(MessagePassing):\n    \"\"\"\n    Graph Isomorphism Network message passing.\n\n    Args:\n        emb_dim (int): node embedding dimensionality\n\n    Inputs:\n        - x (tensor): node embedding\n        - edge_index (tensor): edge connectivity information\n        - 
edge_attr (tensor): edge feature\n    Outputs:\n        - prediction (tensor): output node embedding\n    \"\"\"\n\n    def __init__(self, emb_dim):\n        super(GINConv, self).__init__(aggr=\"add\")\n\n        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),\n                                       torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim))\n        self.eps = torch.nn.Parameter(torch.Tensor([0]))\n\n        self.bond_encoder = BondEncoder(emb_dim=emb_dim)\n\n    def forward(self, x, edge_index, edge_attr):\n        edge_embedding = self.bond_encoder(edge_attr)\n        out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))\n\n        return out\n\n    def message(self, x_j, edge_attr):\n        return F.relu(x_j + edge_attr)\n\n    def update(self, aggr_out):\n        return aggr_out\n\n\ndef gin_virtual(num_tasks, dropout=0.5):\n    model = GINVirtual(num_tasks=num_tasks, dropout=dropout)\n    return model\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_ogb_molpcba/requirements.txt",
    "content": "torch_geometric\nwilds\ntensorflow\ntensorboard\nogb"
  },
  {
    "path": "examples/domain_adaptation/wilds_ogb_molpcba/utils.py",
    "content": "\"\"\"\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport time\nimport sys\n\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, ConcatDataset\n\nimport wilds\n\nsys.path.append('../../..')\nimport gin as models\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\ndef reduced_bce_logit_loss(y_pred, y_target):\n    \"\"\"\n    Every item of y_target has n elements which may be labeled by nan.\n    Nan values should not be used while calculating loss.\n    So extract elements which are not nan first, and then calculate loss.\n    \"\"\"\n    loss = nn.BCEWithLogitsLoss(reduction='none').cuda()\n    is_labeled = ~torch.isnan(y_target)\n    y_pred = y_pred[is_labeled].float()\n    y_target = y_target[is_labeled].float()\n    metrics = loss(y_pred, y_target)\n    return metrics.mean()\n\n\ndef get_dataset(dataset_name, root, unlabeled_list=('test_unlabeled',), test_list=('test',),\n                transform_train=None, transform_test=None, use_unlabeled=True, verbose=True):\n    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True)\n    train_labeled_dataset = labeled_dataset.get_subset('train', transform=transform_train)\n\n    if use_unlabeled:\n        unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)\n        train_unlabeled_datasets = [\n            unlabeled_dataset.get_subset(u, transform=transform_train)\n            for u in unlabeled_list\n        ]\n        train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)\n    else:\n        unlabeled_list = []\n        train_unlabeled_datasets = []\n        train_unlabeled_dataset = None\n\n    test_datasets = [\n        labeled_dataset.get_subset(t, transform=transform_test)\n        for t in test_list\n    ]\n\n    if dataset_name == 'ogb-molpcba':\n        num_classes = labeled_dataset.y_size\n    else:\n        num_classes = labeled_dataset.n_classes\n    class_names = list(range(num_classes))\n\n    if verbose:\n        print('Datasets')\n        for n, d in zip(['train'] + unlabeled_list + test_list,\n                        [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets):\n            print('\\t{}:{}'.format(n, len(d)))\n        print('\\t#classes:', num_classes)\n\n    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_classes, class_names\n\n\ndef get_model_names():\n    return sorted(name for name in models.__dict__ if\n                  name.islower() and not name.startswith('__') and callable(models.__dict__[name]))\n\n\ndef get_model(arch, num_classes):\n    if arch in models.__dict__:\n        model = models.__dict__[arch](num_tasks=num_classes)\n    else:\n        raise ValueError('{} is not supported'.format(arch))\n    return model\n\n\ndef collate_list(vec):\n    \"\"\"\n    Adapted from https://github.com/p-lambda/wilds\n    If vec is a list of Tensors, it concatenates them all along the first dimension.\n\n    If vec is a list of lists, it joins these lists together, but does not attempt to\n    recursively collate. This allows each element of the list to be, e.g., its own dict.\n\n    If vec is a list of dicts (with the same keys in each dict), it returns a single dict\n    with the same keys. 
For each key, it recursively collates all entries in the list.\n    \"\"\"\n    if not isinstance(vec, list):\n        raise TypeError(\"collate_list must take in a list\")\n    elem = vec[0]\n    if torch.is_tensor(elem):\n        return torch.cat(vec)\n    elif isinstance(elem, list):\n        return [obj for sublist in vec for obj in sublist]\n    elif isinstance(elem, dict):\n        return {k: collate_list([d[k] for d in vec]) for k in elem}\n    else:\n        raise TypeError(\"Elements of the list to collate must be tensors or dicts.\")\n\n\ndef validate(val_dataset, model, epoch, writer, args):\n    val_sampler = None\n    val_loader = DataLoader(\n        val_dataset, batch_size=args.batch_size[0], shuffle=False,\n        num_workers=args.workers, pin_memory=True, sampler=val_sampler, collate_fn=val_dataset.collate)\n\n    all_y_true = []\n    all_y_pred = []\n    all_metadata = []\n\n    batch_time = AverageMeter('Time', ':6.3f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(val_loader):\n        # compute output\n        with torch.no_grad():\n            output = model(input.cuda()).cpu()\n\n        all_y_true.append(target)\n        all_y_pred.append(output)\n        all_metadata.append(metadata)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        # this example is single-process, so no local_rank gate is needed here\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n    # evaluate\n    results = val_dataset.eval(\n        collate_list(all_y_pred),\n        collate_list(all_y_true),\n        collate_list(all_metadata)\n    )\n    print(results[1])\n\n    for k, v in results[0].items():\n        if v == 0 or \"Other\" in k:\n            continue\n        writer.add_scalar(\"test/{}\".format(k), v, global_step=epoch)\n\n    return results[0][args.metric]\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_poverty/README.md",
    "content": "# Unsupervised Domain Adaptation for WILDS (Image Regression)\n\n## Installation\n\nIt's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results.\n\nYou need to install apex following `https://github.com/NVIDIA/apex`. Then run\n\n```\npip install -r requirements.txt\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [PovertyMap (WILDS)](https://wilds.stanford.edu/datasets/)\n\n## Supported Methods\n\nTODO\n\n## Usage\n\nOur code is based\non [https://github.com/NVIDIA/apex/edit/master/examples/imagenet](https://github.com/NVIDIA/apex/edit/master/examples/imagenet)\n. It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and\nVGG, on the WILDS dataset.  \nCommand-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed\nprecision \"optimization levels\" or `opt_level`s.  \nFor a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).\n\nThe shell files give all the training scripts we use, e.g.\n\n```\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \\\n  --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A\n```\n\n## Results\n\n### Performance on WILDS-PovertyMap (ResNet18-MultiSpectral)\n\n| Method | Val Pearson r | Test Pearson r | Val Worst-U/R Pearson r | Test Worst-U/R Pearson r | GPU Memory Usage(GB) |\n| --- | --- | --- | --- | --- | --- |\n| ERM | 0.80 | 0.80 | 0.54 | 0.50 | 3.5 |\n\n### Distributed training\n\nWe uses `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.\n\n```\nCUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 erm.py /data/wilds --arch 'resnet18_ms' \\\n    --opt-level O1 --deterministic --log logs/erm/poverty --lr 1e-3 --wd 0.0 --epochs 200 --metric r_wg --split_scheme official -b 64 64 --fold A\n```\n\n### Visualization\n\nWe use tensorboard to record the training process and visualize the outputs of the models.\n\n```\ntensorboard --logdir=logs\n```\n\n<img src=\"./fig/poverty_train_loss.png\" width=\"300\"/>"
  },
  {
    "path": "examples/domain_adaptation/wilds_poverty/erm.py",
    "content": "\"\"\"\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport argparse\nimport os\nimport shutil\nimport time\nimport pprint\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nfrom torch.utils.tensorboard import SummaryWriter\n\ntry:\n    from apex.parallel import DistributedDataParallel as DDP\n    from apex.fp16_utils import *\n    from apex import amp\nexcept ImportError:\n    raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to run this example.\")\n\nimport utils\nfrom utils import Regressor\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\n\n\ndef main(args):\n    writer = None\n    if args.local_rank == 0:\n        logger = CompleteLogger(args.log, args.phase)\n        if args.phase == 'train':\n            writer = SummaryWriter(args.log)\n        pprint.pprint(args)\n        print(\"opt_level = {}\".format(args.opt_level))\n        print(\"keep_batchnorm_fp32 = {}\".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))\n        print(\"loss_scale = {}\".format(args.loss_scale), type(args.loss_scale))\n\n        print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    args.distributed = False\n    if 'WORLD_SIZE' in os.environ:\n        args.distributed = int(os.environ['WORLD_SIZE']) > 1\n\n    args.gpu = 0\n    args.world_size = 1\n\n    if args.distributed:\n        args.gpu = args.local_rank\n        torch.cuda.set_device(args.gpu)\n        torch.distributed.init_process_group(backend='nccl',\n                                             init_method='env://')\n        args.world_size = torch.distributed.get_world_size()\n\n    assert torch.backends.cudnn.enabled, \"Amp requires cudnn backend to be enabled.\"\n\n    if args.channels_last:\n        memory_format = torch.channels_last\n    else:\n        memory_format = torch.contiguous_format\n\n    # Data loading code\n    # Images in povertyMap dataset have 8 channels and traditional data augmentation\n    # methods have no effect on performance.\n    train_transform = None\n    val_transform = None\n    if args.local_rank == 0:\n        print(\"train_transform: \", train_transform)\n        print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_channels = \\\n        utils.get_dataset('poverty', args.data_dir, args.unlabeled_list, args.test_list, args.split_scheme,\n                          train_transform, val_transform, use_unlabeled=args.use_unlabeled,\n                          verbose=args.local_rank == 0, fold=args.fold)\n\n    # create model\n    if args.local_rank == 0:\n        print(\"=> creating model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, args.num_channels)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = Regressor(backbone, pool_layer=pool_layer, finetune=False)\n\n    if args.sync_bn:\n        import apex\n        if args.local_rank == 0:\n            print(\"using apex synced BN\")\n        model = apex.parallel.convert_syncbn_model(model)\n    model = 
model.cuda().to(memory_format=memory_format)\n\n    optimizer = torch.optim.Adam(\n        filter(lambda p: p.requires_grad, model.parameters()),\n        lr=args.lr, weight_decay=args.weight_decay)\n\n    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,\n    # for convenient interoperation with argparse.\n    model, optimizer = amp.initialize(model, optimizer,\n                                      opt_level=args.opt_level,\n                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,\n                                      loss_scale=args.loss_scale\n                                      )\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=args.gamma, step_size=args.step_size)\n\n    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.\n    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called\n    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter\n    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.\n    if args.distributed:\n        # By default, apex.parallel.DistributedDataParallel overlaps communication with\n        # computation in the backward pass.\n        # model = DDP(model)\n        # delay_allreduce delays all communication to the end of the backward pass.\n        model = DDP(model, delay_allreduce=True)\n\n    # Data loading code\n    train_labeled_sampler = None\n    if args.distributed:\n        train_labeled_sampler = DistributedSampler(train_labeled_dataset)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler)\n\n    if args.phase == 'test':\n        # resume from the latest checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    # start training\n    best_val_metric = 0\n    test_metric = 0\n    for epoch in range(args.epochs):\n        if args.distributed:\n            train_labeled_sampler.set_epoch(epoch)\n\n        lr_scheduler.step(epoch)\n        if args.local_rank == 0:\n            print(lr_scheduler.get_last_lr())\n            writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, model, optimizer, epoch, writer, args)\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            if args.local_rank == 0:\n                print(n)\n            if n == 'val':\n                tmp_val_metric = utils.validate(d, model, epoch, writer, args)\n            elif n == 'test':\n                tmp_test_metric = utils.validate(d, model, epoch, writer, args)\n\n        # remember best mse and save checkpoint\n        if args.local_rank == 0:\n            is_best = tmp_val_metric > best_val_metric\n            best_val_metric = max(tmp_val_metric, best_val_metric)\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if is_best:\n                test_metric = tmp_test_metric\n                
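# keep a copy of the best-on-val checkpoint; '--phase test' reloads it later\n                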
shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n    print('best val performance: {:.3f}'.format(best_val_metric))\n    print('test performance: {:.3f}'.format(test_metric))\n\n\ndef train(train_loader, model, optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(train_loader):\n\n        # compute output\n        output, _ = model(input.cuda())\n        loss = F.mse_loss(output, target.cuda())\n\n        # compute gradient and do optimizer step\n        optimizer.zero_grad()\n        with amp.scale_loss(loss, optimizer) as scaled_loss:\n            scaled_loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Average loss and accuracy across processes for logging\n            if args.distributed:\n                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)\n            else:\n                reduced_loss = loss.data\n\n            # to_python_float incurs a host<->device sync\n            losses.update(to_python_float(reduced_loss), input.size(0))\n            global_step = epoch * len(train_loader) + i\n\n            torch.cuda.synchronize()\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            if args.local_rank == 0:\n                writer.add_scalar(\"train/loss\", to_python_float(reduced_loss), global_step)\n\n                print('Epoch: [{0}][{1}/{2}]\\t'\n                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                      'Speed {3:.3f} ({4:.3f})\\t'\n                      'Loss {loss.val:.10f} ({loss.avg:.4f})'.format(\n                    epoch, i, len(train_loader),\n                    args.world_size * args.batch_size[0] / batch_time.val,\n                    args.world_size * args.batch_size[0] / batch_time.avg,\n                    batch_time=batch_time,\n                    loss=losses))\n\n\nif __name__ == '__main__':\n    model_names = utils.get_model_names()\n    parser = argparse.ArgumentParser(description='ERM on WILDS-PovertyMap')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('--unlabeled-list', nargs='+', default=[])\n    parser.add_argument('--test-list', nargs='+', default=['val', 'test'])\n    parser.add_argument('--metric', default='r_wg',\n                        help='metric used to evaluate model performance. '\n                             '(default: worst-U/R Pearson r)')\n    parser.add_argument('--split-scheme', type=str,\n                        help='Identifies how the train/val/test split is constructed. '\n                             'Choices are dataset-specific.')\n    parser.add_argument('--fold', type=str, default='A', choices=['A', 'B', 'C', 'D', 'E'],\n                        help='Fold for poverty dataset. 
Poverty has 5 different cross validation folds,'\n                             'each splitting the countries differently.')\n    parser.add_argument('--use-unlabeled', action='store_true',\n                        help='Whether to use unlabeled data for training.')\n    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18_ms',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: resnet18_ms)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor.')\n    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='Learning rate')\n    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,\n                        metavar='W', help='weight decay (default: 1e-4)')\n    parser.add_argument('--gamma', type=float, default=0.96, help='parameter for StepLR scheduler')\n    parser.add_argument('--step-size', type=int, default=1, help='parameter for StepLR scheduler')\n    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (64, 64))')\n    parser.add_argument('--print-freq', '-p', default=50, type=int,\n                        metavar='N', help='print frequency (default: 50)')\n    parser.add_argument('--deterministic', action='store_true')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training.')\n    parser.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0), type=int)\n    parser.add_argument('--sync-bn', action='store_true',\n                        help='enabling apex sync BN.')\n    parser.add_argument('--opt-level', type=str)\n    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)\n    parser.add_argument('--loss-scale', type=str, default=None)\n    parser.add_argument('--channels-last', type=bool, default=False)\n    parser.add_argument('--log', type=str, default='src_only',\n                        help='Where to save logs, checkpoints and debugging images.')\n    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_poverty/erm.sh",
    "content": "# official split scheme\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \\\n  --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold B \\\n  --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_B\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold C \\\n  --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_C\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold D \\\n  --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_D\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold E \\\n  --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_E\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_poverty/requirements.txt",
    "content": "wilds\ntensorflow\ntensorboard"
  },
  {
    "path": "examples/domain_adaptation/wilds_poverty/resnet_ms.py",
    "content": "\"\"\"\nModified based on torchvision.models.resnet\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport torch.nn as nn\nfrom torchvision import models\nfrom torchvision.models.resnet import BasicBlock, Bottleneck\nimport copy\n\n__all__ = ['resnet18_ms', 'resnet34_ms', 'resnet50_ms', 'resnet101_ms', 'resnet152_ms']\n\n\nclass ResNetMS(models.ResNet):\n    \"\"\"\n    ResNet with input channels parameter, without fully connected layer.\n    \"\"\"\n\n    def __init__(self, in_channels, *args, **kwargs):\n        super(ResNetMS, self).__init__(*args, **kwargs)\n        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self._out_features = self.fc.in_features\n        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        # x = self.avgpool(x)\n        # x = torch.flatten(x, 1)\n        # x = self.fc(x)\n        return x\n\n    @property\n    def out_features(self) -> int:\n        \"\"\"The dimension of output features\"\"\"\n        return self._out_features\n\n    def copy_head(self) -> nn.Module:\n        \"\"\"Copy the origin fully connected layer\"\"\"\n        return copy.deepcopy(self.fc)\n\n\ndef resnet18_ms(num_channels=3):\n    model = ResNetMS(num_channels, BasicBlock, [2, 2, 2, 2])\n    return model\n\n\ndef resnet34_ms(num_channels=3):\n    model = ResNetMS(num_channels, BasicBlock, [3, 4, 6, 3])\n    return model\n\n\ndef resnet50_ms(num_channels=3):\n    model = ResNetMS(num_channels, Bottleneck, [3, 4, 6, 3])\n    return model\n\n\ndef resnet101_ms(num_channels=3):\n    model = ResNetMS(num_channels, Bottleneck, [3, 4, 23, 3])\n    return model\n\n\ndef resnet152_ms(num_channels=3):\n    model = ResNetMS(num_channels, Bottleneck, [3, 8, 36, 3])\n    return model\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_poverty/utils.py",
    "content": "\"\"\"\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport time\nimport sys\n\nfrom typing import Tuple, Optional, List, Dict\nimport torch\nimport torch.nn as nn\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader, ConcatDataset\nfrom torch.utils.data.distributed import DistributedSampler\n\nimport wilds\nimport resnet_ms as models\n\nsys.path.append('../../..')\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\nclass Regressor(nn.Module):\n    \"\"\"A generic Regressor class for domain adaptation.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data\n        bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1\n        head (torch.nn.Module, optional): Any regressor head. Use :class:`torch.nn.Linear` by default\n        finetune (bool): Whether finetune the regressor or train from scratch. Default: True\n\n    .. note::\n        Different regressors are used in different domain adaptation algorithms to achieve better accuracy\n        respectively, and we provide a suggested `Regressor` for different algorithms.\n        Remember they are not the core of algorithms. You can implement your own `Regressor` and combine it with\n        the domain adaptation algorithm in this algorithm library.\n\n    .. note::\n        The learning rate of this regressor is set 10 times to that of the feature extractor for better accuracy\n        by default. If you have other optimization strategies, please over-ride :meth:`~Regressor.get_parameters`.\n\n    Inputs:\n        - x (tensor): input data fed to `backbone`\n\n    Outputs:\n        - predictions: regressor's predictions\n        - features: features after `bottleneck` layer and before `head` layer\n\n    Shape:\n        - Inputs: (minibatch, *) where * means, any number of additional dimensions\n        - predictions: (minibatch, `num_values`)\n        - features: (minibatch, `features_dim`)\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, bottleneck: Optional[nn.Module] = None, bottleneck_dim: Optional[int] = -1,\n                 head: Optional[nn.Module] = None, finetune=True, pool_layer=None):\n        super(Regressor, self).__init__()\n        self.backbone = backbone\n        if pool_layer is None:\n            self.pool_layer = nn.Sequential(\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            )\n        else:\n            self.pool_layer = pool_layer\n        if bottleneck is None:\n            self.bottleneck = nn.Identity()\n            self._features_dim = backbone.out_features\n        else:\n            self.bottleneck = bottleneck\n            assert bottleneck_dim > 0\n            self._features_dim = bottleneck_dim\n\n        if head is None:\n            self.head = nn.Linear(self._features_dim, 1)\n        else:\n            self.head = head\n        self.finetune = finetune\n\n    @property\n    def features_dim(self) -> int:\n        \"\"\"The dimension of features before the final `head` layer\"\"\"\n        return self._features_dim\n\n    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\"\"\"\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        predictions = self.head(f)\n        if self.training:\n            return predictions, f\n        else:\n    
        return predictions\n\n    def get_parameters(self, base_lr=1.0) -> List[Dict]:\n        \"\"\"A parameter list which decides optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr},\n        ]\n\n        return params\n\n\ndef get_dataset(dataset_name, root, unlabeled_list=(\"test_unlabeled\",), test_list=(\"test\",),\n                split_scheme='official', transform_train=None, transform_test=None, use_unlabeled=True,\n                verbose=True, **kwargs):\n    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, split_scheme=split_scheme, **kwargs)\n    train_labeled_dataset = labeled_dataset.get_subset(\"train\", transform=transform_train)\n\n    if use_unlabeled:\n        unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)\n        train_unlabeled_datasets = [\n            unlabeled_dataset.get_subset(u, transform=transform_train)\n            for u in unlabeled_list\n        ]\n        train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)\n    else:\n        unlabeled_list = []\n        train_unlabeled_datasets = []\n        train_unlabeled_dataset = None\n\n    test_datasets = [\n        labeled_dataset.get_subset(t, transform=transform_test)\n        for t in test_list\n    ]\n\n    num_channels = labeled_dataset.get_input(0).size()[0]\n\n    if verbose:\n        print(\"Datasets\")\n        for n, d in zip([\"train\"] + unlabeled_list + test_list,\n                        [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets):\n            print(\"\\t{}:{}\".format(n, len(d)))\n\n    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_channels\n\n\ndef get_model_names():\n    return sorted(name for name in models.__dict__ if\n                  name.islower() and not name.startswith('__') and callable(models.__dict__[name]))\n\n\ndef get_model(arch, num_channels):\n    if arch in models.__dict__:\n        model = models.__dict__[arch](num_channels=num_channels)\n    else:\n        raise ValueError('{} is not supported'.format(arch))\n    return model\n\n\ndef collate_list(vec):\n    \"\"\"\n    Adapted from https://github.com/p-lambda/wilds\n    If vec is a list of Tensors, it concatenates them all along the first dimension.\n\n    If vec is a list of lists, it joins these lists together, but does not attempt to\n    recursively collate. This allows each element of the list to be, e.g., its own dict.\n\n    If vec is a list of dicts (with the same keys in each dict), it returns a single dict\n    with the same keys. 
For each key, it recursively collates all entries in the list.\n    \"\"\"\n    if not isinstance(vec, list):\n        raise TypeError(\"collate_list must take in a list\")\n    elem = vec[0]\n    if torch.is_tensor(elem):\n        return torch.cat(vec)\n    elif isinstance(elem, list):\n        return [obj for sublist in vec for obj in sublist]\n    elif isinstance(elem, dict):\n        return {k: collate_list([d[k] for d in vec]) for k in elem}\n    else:\n        raise TypeError(\"Elements of the list to collate must be tensors or dicts.\")\n\n\ndef reduce_tensor(tensor, world_size):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.reduce_op.SUM)\n    rt /= world_size\n    return rt\n\n\ndef validate(val_dataset, model, epoch, writer, args):\n    val_sampler = None\n    if args.distributed:\n        val_sampler = DistributedSampler(val_dataset)\n\n    val_loader = DataLoader(\n        val_dataset, batch_size=args.batch_size[0], shuffle=False,\n        num_workers=args.workers, pin_memory=True, sampler=val_sampler)\n\n    all_y_true = []\n    all_y_pred = []\n    all_metadata = []\n\n    batch_time = AverageMeter('Time', ':6.3f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(val_loader):\n        # compute output\n        with torch.no_grad():\n            output = model(input.cuda()).cpu()\n\n        all_y_true.append(target)\n        all_y_pred.append(output)\n        all_metadata.append(metadata)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if args.local_rank == 0 and i % args.print_freq == 0:\n            progress.display(i)\n\n    if args.local_rank == 0:\n\n        # evaluate\n        results = val_dataset.eval(\n            collate_list(all_y_pred),\n            collate_list(all_y_true),\n            collate_list(all_metadata)\n        )\n        print(results[1])\n\n        for k, v in results[0].items():\n            if v == 0 or \"Other\" in k:\n                continue\n            writer.add_scalar(\"test/{}\".format(k), v, global_step=epoch)\n\n        return results[0][args.metric]\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_text/README.md",
    "content": "# Unsupervised Domain Adaptation for WILDS (Text Classification)\n\n## Installation\n\nIt's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results.\n\nYou need to run\n\n```\npip install -r requirements.txt\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [CivilComments (WILDS)](https://wilds.stanford.edu/datasets/)\n- [Amazon (WILDS)](https://wilds.stanford.edu/datasets/)\n\n## Supported Methods\n\nTODO\n\n## Usage\n\nThe shell files give all the training scripts we use, e.g.\n\n```\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d \"civilcomments\" --unlabeled-list \"extra_unlabeled\" \\\n  --uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric \"acc_wg\" \\\n  --seed 0 --deterministic --log logs/erm/civilcomments\n```\n\n## Results\n\n### Performance on WILDS-CivilComments (DistilBert)\n\n| Methods | Val Avg Acc | Val Worst-Group Acc | Test Avg Acc | Test Worst-Group Acc | GPU Memory Usage(GB)|\n| --- | --- | --- | --- | --- | --- |\n| ERM | 89.2 | 67.7 | 88.9 | 68.5 | 6.4 |\n\n### Performance on WILDS-Amazon (DistilBert)\n\n| Methods | Val Avg Acc | Test Avg Acc | Val 10% Acc | Test 10% Acc | GPU Memory Usage(GB)|\n| --- | --- | --- | --- | --- | --- |\n| ERM | 72.6 | 71.6 | 54.7 | 53.8 | 12.8 |\n\n### Visualization\n\nWe use tensorboard to record the training process and visualize the outputs of the models.\n\n```\ntensorboard --logdir=logs\n```\n\n#### WILDS-CivilComments\n\n<img src=\"./fig/civilcomments_train_loss.png\" width=\"300\"/>\n\n#### WILDS-Amazon\n\n<img src=\"./fig/amazon_train_loss.png\" width=\"300\"/>"
  },
  {
    "path": "examples/domain_adaptation/wilds_text/erm.py",
    "content": "\"\"\"\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport argparse\nimport shutil\nimport time\nimport pprint\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.sampler import WeightedRandomSampler\nfrom torch.utils.tensorboard import SummaryWriter\nfrom transformers import AdamW, get_linear_schedule_with_warmup\n\nimport wilds\nfrom wilds.common.grouper import CombinatorialGrouper\n\nimport utils\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.meter import AverageMeter\nfrom tllib.utils.metric import accuracy\n\n\ndef main(args):\n    logger = CompleteLogger(args.log, args.phase)\n    writer = SummaryWriter(args.log)\n    pprint.pprint(args)\n\n    print(\"\\nCUDNN VERSION: {}\\n\".format(torch.backends.cudnn.version()))\n\n    cudnn.benchmark = True\n    if args.deterministic:\n        cudnn.benchmark = False\n        cudnn.deterministic = True\n        torch.manual_seed(args.seed)\n        torch.set_printoptions(precision=10)\n\n    # Data loading code\n    train_transform = utils.get_transform(args.arch, args.max_token_length)\n    val_transform = utils.get_transform(args.arch, args.max_token_length)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_labeled_dataset, train_unlabeled_dataset, test_datasets, labeled_dataset, args.num_classes, args.class_names = \\\n        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,\n                          train_transform, val_transform, use_unlabeled=args.use_unlabeled, verbose=True)\n\n    # create model\n    print(\"=> using model '{}'\".format(args.arch))\n    model = utils.get_model(args.arch, args.num_classes)\n    model = model.cuda().to()\n\n    # Data loading code\n    train_labeled_sampler = None\n    if args.uniform_over_groups:\n        train_grouper = CombinatorialGrouper(dataset=labeled_dataset, groupby_fields=args.groupby_fields)\n        groups, group_counts = train_grouper.metadata_to_group(train_labeled_dataset.metadata_array, return_counts=True)\n        group_weights = 1 / group_counts\n        weights = group_weights[groups]\n        train_labeled_sampler = WeightedRandomSampler(weights, len(train_labeled_dataset), replacement=True)\n\n    train_labeled_loader = DataLoader(\n        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),\n        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler\n    )\n\n    no_decay = ['bias', 'LayerNorm.weight']\n    decay_params = []\n    no_decay_params = []\n    for names, params in model.named_parameters():\n        if any(nd in names for nd in no_decay):\n            no_decay_params.append(params)\n        else:\n            decay_params.append(params)\n    params = [\n        {'params': decay_params, 'weight_decay': args.weight_decay},\n        {'params': no_decay_params, 'weight_decay': 0.0}\n    ]\n    optimizer = AdamW(params, lr=args.lr)\n\n    lr_scheduler = get_linear_schedule_with_warmup(optimizer,\n                                                   num_training_steps=len(train_labeled_loader) * args.epochs,\n                                                   num_warmup_steps=0)\n    lr_scheduler.step_every_batch = True\n    lr_scheduler.use_metric = False\n\n    # define loss function (criterion)\n    criterion = nn.CrossEntropyLoss().cuda()\n\n    if args.phase == 
'test':\n        # resume from the best checkpoint\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n        for n, d in zip(args.test_list, test_datasets):\n            print(n)\n            utils.validate(d, model, -1, writer, args)\n        return\n\n    best_val_metric = 0\n    test_metric = 0\n    for epoch in range(args.epochs):\n        lr_scheduler.step(epoch)\n        print(lr_scheduler.get_last_lr())\n        writer.add_scalar(\"train/lr\", lr_scheduler.get_last_lr()[-1], epoch)\n        # train for one epoch\n        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)\n        # evaluate on validation set\n        for n, d in zip(args.test_list, test_datasets):\n            print(n)\n            if n == 'val':\n                tmp_val_metric = utils.validate(d, model, epoch, writer, args)\n            elif n == 'test':\n                tmp_test_metric = utils.validate(d, model, epoch, writer, args)\n\n        # remember best prec@1 and save checkpoint\n        is_best = tmp_val_metric > best_val_metric\n        best_val_metric = max(tmp_val_metric, best_val_metric)\n        torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n        if is_best:\n            test_metric = tmp_test_metric\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n\n    print('best val performance: {:.3f}'.format(best_val_metric))\n    print('test performance: {:.3f}'.format(test_metric))\n    logger.close()\n    writer.close()\n\n\ndef train(train_loader, model, criterion, optimizer, epoch, writer, args):\n    batch_time = AverageMeter('Time', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    top1 = AverageMeter('Top 1', ':3.1f')\n\n    # switch to train mode\n    model.train()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(train_loader):\n\n        # compute output\n        output = model(input.cuda())\n        loss = criterion(output, target.cuda())\n\n        # compute gradient and do optimizer step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if i % args.print_freq == 0:\n            # Every print_freq iterations, check the loss, accuracy, and speed.\n            # For best performance, it doesn't make sense to print these metrics every\n            # iteration, since they incur an allreduce and some host<->device syncs.\n\n            # Measure accuracy\n            prec1, = accuracy(output.data, target.cuda(), topk=(1,))\n\n            losses.update(loss.item(), input.size(0))\n            top1.update(prec1.item(), input.size(0))\n            global_step = epoch * len(train_loader) + i\n\n            batch_time.update((time.time() - end) / args.print_freq)\n            end = time.time()\n\n            writer.add_scalar(\"train/loss\", loss.item(), global_step)\n\n            print('Epoch: [{0}][{1}/{2}]\\t'\n                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n                  'Speed {3:.3f} ({4:.3f})\\t'\n                  'Loss {loss.val:.10f} ({loss.avg:.4f})\\t'\n                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n                epoch, i, len(train_loader),\n                args.batch_size[0] / batch_time.val,\n                args.batch_size[0] / batch_time.avg,\n                batch_time=batch_time,\n                loss=losses, top1=top1))\n\n\nif __name__ == '__main__':\n    model_names = utils.get_model_names()\n
    parser = argparse.ArgumentParser(description='ERM for WILDS (Text Classification)')\n    # Dataset parameters\n    parser.add_argument('data_dir', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='civilcomments', choices=wilds.supported_datasets,\n                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +\n                             ' (default: civilcomments)')\n    parser.add_argument('--unlabeled-list', nargs='+', default=[])\n    parser.add_argument('--test-list', nargs='+', default=[\"val\", \"test\"])\n    parser.add_argument('--metric', default='acc_wg',\n                        help='metric used to evaluate model performance. (default: worst group accuracy)')\n    parser.add_argument('--uniform-over-groups', action='store_true',\n                        help='sample examples such that batches are uniform over groups')\n    parser.add_argument('--groupby-fields', nargs='+',\n                        help='Group data by the given fields. Items which have the same '\n                             'values in those fields are grouped together.')\n    parser.add_argument('--use-unlabeled', action='store_true',\n                        help='Whether to use unlabeled data for training.')\n    # model parameters\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='distilbert-base-uncased',\n                        choices=model_names,\n                        help='model architecture: ' +\n                             ' | '.join(model_names) +\n                             ' (default: distilbert-base-uncased)')\n    parser.add_argument('--max-token-length', type=int, default=300,\n                        help='The maximum token length of a sequence.')\n    # Learning rate schedule parameters\n    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,\n                        metavar='LR', help='Learning rate.')\n    parser.add_argument('--weight-decay', '--wd', default=0.01, type=float,\n                        metavar='W', help='weight decay (default: 0.01)')\n    # training parameters\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=5, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch-size', default=(16, 16), type=int, nargs='+',\n                        metavar='N', help='mini-batch size per process for source'\n                                          ' and target domain (default: (16, 16))')\n    parser.add_argument('--print-freq', '-p', default=200, type=int,\n                        metavar='N', help='print frequency (default: 200)')\n    parser.add_argument('--deterministic', action='store_true')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training.')\n    parser.add_argument('--log', type=str, default='src_only',\n                        help='Where to save logs, checkpoints and debugging images.')\n    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_text/erm.sh",
    "content": "# civilcomments\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d \"civilcomments\" --unlabeled-list \"extra_unlabeled\" \\\n  --uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric \"acc_wg\" \\\n  --seed 0 --deterministic --log logs/erm/civilcomments\n\n# amazon\nCUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d \"amazon\" --max-token-length 512 \\\n  --lr 1e-5 -b 24 24 --epochs 3 --metric \"10th_percentile_acc\" --seed 0 --deterministic --log logs/erm/amazon\n"
  },
  {
    "path": "examples/domain_adaptation/wilds_text/requirements.txt",
    "content": "wilds\ntensorflow\ntensorboard\ntransformers"
  },
  {
    "path": "examples/domain_adaptation/wilds_text/utils.py",
    "content": "\"\"\"\n@author: Jiaxin Li\n@contact: thulijx@gmail.com\n\"\"\"\nimport time\nimport sys\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DataLoader, ConcatDataset\nfrom transformers import DistilBertTokenizerFast\nfrom transformers import DistilBertForSequenceClassification\n\nimport wilds\n\nsys.path.append('../../..')\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\nclass DistilBertClassifier(DistilBertForSequenceClassification):\n    \"\"\"\n    Adapted from https://github.com/p-lambda/wilds\n    \"\"\"\n\n    def __call__(self, x):\n        input_ids = x[:, :, 0]\n        attention_mask = x[:, :, 1]\n        outputs = super().__call__(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n        )[0]\n        return outputs\n\n\ndef get_transform(arch, max_token_length):\n    \"\"\"\n    Adapted from https://github.com/p-lambda/wilds\n    \"\"\"\n    if arch == 'distilbert-base-uncased':\n        tokenizer = DistilBertTokenizerFast.from_pretrained(arch)\n    else:\n        raise ValueError(\"Model: {arch} not recognized\".format(arch))\n\n    def transform(text):\n        tokens = tokenizer(text, padding='max_length', truncation=True,\n                           max_length=max_token_length, return_tensors='pt')\n        if arch == 'bert_base_uncased':\n            x = torch.stack(\n                (\n                    tokens[\"input_ids\"],\n                    tokens[\"attention_mask\"],\n                    tokens[\"token_type_ids\"],\n                ),\n                dim=2,\n            )\n        elif arch == 'distilbert-base-uncased':\n            x = torch.stack((tokens[\"input_ids\"], tokens[\"attention_mask\"]), dim=2)\n        x = torch.squeeze(x, dim=0)  # First shape dim is always 1\n        return x\n\n    return transform\n\n\ndef get_dataset(dataset_name, root, unlabeled_list=('extra_unlabeled',), test_list=('test',),\n                transform_train=None, transform_test=None, use_unlabeled=True, verbose=True):\n    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True)\n    train_labeled_dataset = labeled_dataset.get_subset('train', transform=transform_train)\n\n    if use_unlabeled:\n        unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)\n        train_unlabeled_datasets = [\n            unlabeled_dataset.get_subset(u, transform=transform_train)\n            for u in unlabeled_list\n        ]\n        train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)\n    else:\n        unlabeled_list = []\n        train_unlabeled_datasets = []\n        train_unlabeled_dataset = None\n\n    test_datasets = [\n        labeled_dataset.get_subset(t, transform=transform_test)\n        for t in test_list\n    ]\n\n    num_classes = labeled_dataset.n_classes\n    class_names = list(range(num_classes))\n\n    if verbose:\n        print('Datasets')\n        for n, d in zip(['train'] + unlabeled_list + test_list,\n                        [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets):\n            print('\\t{}:{}'.format(n, len(d)))\n        print('\\t#classes:', num_classes)\n\n    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, labeled_dataset, num_classes, class_names\n\n\ndef get_model_names():\n    return ['distilbert-base-uncased']\n\n\ndef get_model(arch, num_classes):\n    if arch == 'distilbert-base-uncased':\n        model = 
def reduce_tensor(tensor, world_size):\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.reduce_op.SUM)\n    rt /= world_size\n    return rt\n\n\ndef collate_list(vec):\n    \"\"\"\n    Adapted from https://github.com/p-lambda/wilds\n    If vec is a list of Tensors, it concatenates them all along the first dimension.\n\n    If vec is a list of lists, it joins these lists together, but does not attempt to\n    recursively collate. This allows each element of the list to be, e.g., its own dict.\n\n    If vec is a list of dicts (with the same keys in each dict), it returns a single dict\n    with the same keys. For each key, it recursively collates all entries in the list.\n    \"\"\"\n    if not isinstance(vec, list):\n        raise TypeError(\"collate_list must take in a list\")\n    elem = vec[0]\n    if torch.is_tensor(elem):\n        return torch.cat(vec)\n    elif isinstance(elem, list):\n        return [obj for sublist in vec for obj in sublist]\n    elif isinstance(elem, dict):\n        return {k: collate_list([d[k] for d in vec]) for k in elem}\n    else:\n        raise TypeError(\"Elements of the list to collate must be tensors or dicts.\")\n\n\ndef validate(val_dataset, model, epoch, writer, args):\n    val_sampler = None\n    val_loader = DataLoader(\n        val_dataset, batch_size=args.batch_size[0], shuffle=False,\n        num_workers=args.workers, pin_memory=True, sampler=val_sampler)\n\n    all_y_true = []\n    all_y_pred = []\n    all_metadata = []\n\n    batch_time = AverageMeter('Time', ':6.3f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    end = time.time()\n\n    for i, (input, target, metadata) in enumerate(val_loader):\n        # compute output\n        with torch.no_grad():\n            output = model(input.cuda()).cpu()\n\n        all_y_true.append(target)\n        all_y_pred.append(output.argmax(1))\n        all_metadata.append(metadata)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n    # evaluate\n    results = val_dataset.eval(\n        collate_list(all_y_pred),\n        collate_list(all_y_true),\n        collate_list(all_metadata)\n    )\n    print(results[1])\n\n    for k, v in results[0].items():\n        if v == 0 or \"Other\" in k:\n            continue\n        writer.add_scalar(\"test/{}\".format(k), v, global_step=epoch)\n\n    return results[0][args.metric]\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/README.md",
    "content": "# Domain Generalization for Image Classification\n\n## Installation\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).\nYou also need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)\n- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)\n- [DomainNet](http://ai.bu.edu/M3SDA/)\n- [PACS](https://domaingeneralization.github.io/#data)\n\n## Supported Methods\n\n- [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)\n- [Domain Generalization with MixStyle (MixStyle, 2021 ICLR)](https://arxiv.org/abs/2104.02008)\n- [Learning to Generalize: Meta-Learning for Domain Generalization (MLDG, 2018 AAAI)](https://arxiv.org/pdf/1710.03463.pdf)\n- [Invariant Risk Minimization (IRM)](https://arxiv.org/abs/1907.02893)\n- [Out-of-Distribution Generalization via Risk Extrapolation (VREx, 2021 ICML)](https://arxiv.org/abs/2003.00688)\n- [Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (GroupDRO)](https://arxiv.org/abs/1911.08731)\n- [Deep CORAL: Correlation Alignment for Deep Domain Adaptation (Deep Coral, 2016 ECCV)](https://arxiv.org/abs/1607.01719)\n\n## Usage\n\nThe shell files give the script to reproduce the benchmark with specified hyper-parameters.\nFor example, if you want to train IRM on Office-Home, use the following script\n\n```shell script\n# Train with IRM on Office-Home Ar Cl Rw -> Pr task using ResNet 50.\n# Assume you have put the datasets under the path `data/office-home`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr\n```\nNote that ``-s`` specifies the source domain, ``-t`` specifies the target domain,\nand ``--log`` specifies where to store results.\n\n## Experiment and Results\nFollowing [DomainBed](https://github.com/facebookresearch/DomainBed), we select hyper-parameters based on\nthe model's performance on `training-domain validation set` (first rule in DomainBed).\nConcretely, we save model with the highest accuracy on `training-domain validation set` and then \nload this checkpoint to test on the target domain.\n\nHere are some differences between our implementation and DomainBed. For the model, \nwe do not freeze `BatchNorm2d` layers and do not insert additional `Dropout` layer except for `PACS` dataset. 
\n**Notations**\n- ``ERM`` refers to the model trained with data from the source domains.\n- ``Avg`` is the accuracy reported by `TLlib`.\n\n### PACS accuracy on ResNet-50\n\n| Methods  | avg  | A    | C    | P    | S    |\n|----------|------|------|------|------|------|\n| ERM      | 86.4 | 88.5 | 78.4 | 97.2 | 81.4 |\n| IBN      | 87.8 | 88.2 | 84.5 | 97.1 | 81.4 |\n| MixStyle | 87.4 | 87.8 | 82.3 | 95.0 | 84.5 |\n| MLDG     | 87.2 | 88.2 | 81.4 | 96.6 | 82.5 |\n| IRM      | 86.9 | 88.0 | 82.5 | 98.0 | 79.0 |\n| VREx     | 87.0 | 87.2 | 82.3 | 97.4 | 81.0 |\n| GroupDRO | 87.3 | 88.9 | 81.7 | 97.8 | 80.8 |\n| CORAL    | 86.4 | 89.1 | 80.0 | 97.4 | 79.1 |\n\n### Office-Home accuracy on ResNet-50\n\n| Methods  | avg  | A    | C    | P    | R    |\n|----------|------|------|------|------|------|\n| ERM      | 70.8 | 68.3 | 55.9 | 78.9 | 80.0 |\n| IBN      | 69.9 | 67.4 | 55.2 | 77.3 | 79.6 |\n| MixStyle | 71.7 | 66.8 | 58.1 | 78.0 | 79.9 |\n| MLDG     | 70.3 | 65.9 | 57.6 | 78.2 | 79.6 |\n| IRM      | 70.3 | 66.7 | 54.8 | 78.6 | 80.9 |\n| VREx     | 70.2 | 66.9 | 54.9 | 78.2 | 80.9 |\n| GroupDRO | 70.0 | 66.7 | 55.2 | 78.8 | 79.9 |\n| CORAL    | 70.9 | 68.3 | 55.4 | 78.8 | 81.0 |\n\n## Citation\nIf you use these methods in your research, please consider citing.\n\n```\n@inproceedings{IBN-Net,\n    author = {Xingang Pan and Ping Luo and Jianping Shi and Xiaoou Tang},\n    title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net},\n    booktitle = {ECCV},\n    year = {2018}\n}\n\n@inproceedings{mixstyle,\n    title={Domain Generalization with MixStyle},\n    author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},\n    booktitle={ICLR},\n    year={2021}\n}\n\n@inproceedings{MLDG,\n    title={Learning to Generalize: Meta-Learning for Domain Generalization},\n    author={Li, Da and Yang, Yongxin and Song, Yi-Zhe and Hospedales, Timothy},\n    booktitle={AAAI},\n    year={2018}\n}\n\n@misc{IRM,\n    title={Invariant Risk Minimization},\n    author={Martin Arjovsky and Léon Bottou and Ishaan Gulrajani and David Lopez-Paz},\n    year={2020},\n    eprint={1907.02893},\n    archivePrefix={arXiv},\n    primaryClass={stat.ML}\n}\n\n@inproceedings{VREx,\n    title={Out-of-Distribution Generalization via Risk Extrapolation (REx)},\n    author={David Krueger and Ethan Caballero and Joern-Henrik Jacobsen and Amy Zhang and Jonathan Binas and Dinghuai Zhang and Remi Le Priol and Aaron Courville},\n    year={2021},\n    booktitle={ICML},\n}\n\n@inproceedings{GroupDRO,\n    title={Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization},\n    author={Shiori Sagawa and Pang Wei Koh and Tatsunori B. Hashimoto and Percy Liang},\n    year={2020},\n    booktitle={ICLR}\n}\n\n@inproceedings{deep_coral,\n    title={Deep coral: Correlation alignment for deep domain adaptation},\n    author={Sun, Baochen and Saenko, Kate},\n    booktitle={ECCV},\n    year={2016},\n}\n```"
  },
  {
    "path": "examples/domain_generalization/image_classification/coral.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.alignment.coral import CorrelationAlignmentLoss\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=True, random_gray_scale=True)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,\n                                                   split='train', download=True, transform=train_transform,\n                                                   seed=args.seed)\n    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=args.n_domains_per_batch)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,\n                              sampler=sampler, drop_last=True)\n    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',\n                                       download=True, transform=val_transform, seed=args.seed)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',\n                                        download=True, transform=val_transform, seed=args.seed)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"train_dataset_size: \", len(train_dataset))\n    print('val_dataset_size: ', len(val_dataset))\n    print(\"test_dataset_size: \", len(test_dataset))\n    train_iter = ForeverDataIterator(train_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, 
dropout_p=args.dropout_p,\n                                       finetune=args.finetune, pool_layer=pool_layer).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)\n\n    # define loss function\n    correlation_alignment_loss = CorrelationAlignmentLoss().to(device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)\n        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)\n        print(len(source_feature), len(target_feature))\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_val_acc1 = 0.\n    best_test_acc1 = 0.\n    for epoch in range(args.epochs):\n\n        print(lr_scheduler.get_last_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, lr_scheduler, correlation_alignment_loss, args.n_domains_per_batch,\n              epoch, args)\n\n        # evaluate on validation set\n        print(\"Evaluate on validation set...\")\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_val_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_val_acc1 = max(acc1, best_val_acc1)\n\n        # evaluate on test set\n        print(\"Evaluate on test set...\")\n        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test acc on test set = {}\".format(acc1))\n    print(\"oracle acc on test set = {}\".format(best_test_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,\n          correlation_alignment_loss: CorrelationAlignmentLoss, n_domains_per_batch: int, epoch: int,\n          args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    losses_ce = AverageMeter('CELoss', ':3.2f')\n    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_all, labels_all, _ = next(train_iter)\n        x_all = x_all.to(device)\n        labels_all = labels_all.to(device)\n\n        # compute output\n        y_all, f_all = model(x_all)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # separate into different domains\n        y_all = y_all.chunk(n_domains_per_batch, dim=0)\n        f_all = f_all.chunk(n_domains_per_batch, dim=0)\n        labels_all = labels_all.chunk(n_domains_per_batch, dim=0)\n\n
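        # CORAL penalty: align the second-order statistics of the features between every\n        # pair of domains in the mini-batch; with n domains per batch there are\n        # n * (n - 1) / 2 such pairs, and loss_penalty is averaged over them below.\n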
        loss_ce = 0\n        loss_penalty = 0\n        cls_acc = 0\n        for domain_i in range(n_domains_per_batch):\n            # cls loss\n            y_i, labels_i = y_all[domain_i], labels_all[domain_i]\n            loss_ce += F.cross_entropy(y_i, labels_i)\n            # update acc\n            cls_acc += accuracy(y_i, labels_i)[0] / n_domains_per_batch\n            # correlation alignment loss\n            for domain_j in range(domain_i + 1, n_domains_per_batch):\n                f_i = f_all[domain_i]\n                f_j = f_all[domain_j]\n                loss_penalty += correlation_alignment_loss(f_i, f_j)\n\n        # normalize loss\n        loss_ce /= n_domains_per_batch\n        loss_penalty /= n_domains_per_batch * (n_domains_per_batch - 1) / 2\n\n        loss = loss_ce + loss_penalty * args.trade_off\n\n        losses.update(loss.item(), x_all.size(0))\n        losses_ce.update(loss_ce.item(), x_all.size(0))\n        losses_penalty.update(loss_penalty.item(), x_all.size(0))\n        cls_accs.update(cls_acc.item(), x_all.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='CORAL for Domain Generalization')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: PACS)')\n    parser.add_argument('-s', '--sources', nargs='+', default=None,\n                        help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', default=None,\n                        help='target domain(s)')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true', 
help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use 10x smaller lr for the backbone')\n    parser.add_argument('--freeze-bn', action='store_true', help='whether to freeze all bn layers')\n    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')\n    # training parameters\n    parser.add_argument('--trade-off', default=1, type=float,\n                        help='the trade-off hyper-parameter for the correlation alignment loss')\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--n-domains-per-batch', default=3, type=int,\n                        help='number of domains in each mini-batch')\n    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training.')\n    parser.add_argument(\"--log\", type=str, default='coral',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/coral.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, PACS\nCUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_P\nCUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_A\nCUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_C\nCUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_S\n\n# ResNet50, Office-Home\nCUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/coral/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/coral/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/coral/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/coral/OfficeHome_Ar\n\n# ResNet50, DomainNet\nCUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_s\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/erm.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=True, random_gray_scale=True)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,\n                                                   split='train', download=True, transform=train_transform,\n                                                   seed=args.seed)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,\n                              shuffle=True, num_workers=args.workers, drop_last=True)\n    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',\n                                       download=True, transform=val_transform, seed=args.seed)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',\n                                        download=True, transform=val_transform, seed=args.seed)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    print(\"train_dataset_size: \", len(train_dataset))\n    print('val_dataset_size: ', len(val_dataset))\n    print(\"test_dataset_size: \", len(test_dataset))\n    train_iter = ForeverDataIterator(train_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,\n                                       finetune=args.finetune, pool_layer=pool_layer).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = 
SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)\n        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)\n        print(len(source_feature), len(target_feature))\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training; model selection uses the held-out validation split of the source domains\n    best_val_acc1 = 0.\n    best_test_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        print(\"Evaluate on validation set...\")\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_val_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_val_acc1 = max(acc1, best_val_acc1)\n\n        # evaluate on test set, tracking the best test accuracy over all epochs (an oracle reference)\n        print(\"Evaluate on test set...\")\n        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test acc on test set = {}\".format(acc1))\n    print(\"oracle acc on test set = {}\".format(best_test_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR, epoch: int,\n          args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels, _ = next(train_iter)\n        x = x.to(device)\n        labels = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - 
end)\n\n        # compute output\n        y, _ = model(x)\n\n        loss = F.cross_entropy(y, labels)\n\n        cls_acc = accuracy(y, labels)[0]\n        losses.update(loss.item(), x.size(0))\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Baseline for Domain Generalization')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: PACS)')\n    parser.add_argument('-s', '--sources', nargs='+', default=None,\n                        help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', default=None,\n                        help='target domain(s)')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--freeze-bn', action='store_true', help='whether to freeze all bn layers')\n    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=0, type=int,\n                        help='seed for initializing training
')\n    parser.add_argument(\"--log\", type=str, default='baseline',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/erm.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, PACS\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_P\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_A\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_C\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_S\n\n# ResNet50, Office-Home\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/erm/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/erm/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/erm/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/erm/OfficeHome_Ar\n\n# ResNet50, DomainNet\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_s\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/groupdro.py",
    "content": "\"\"\"\nAdapted from https://github.com/facebookresearch/DomainBed\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.reweight.groupdro import AutomaticUpdateDomainWeightModule\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=True, random_gray_scale=True)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,\n                                                   split='train', download=True, transform=train_transform,\n                                                   seed=args.seed)\n    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, args.n_domains_per_batch)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,\n                              sampler=sampler, drop_last=True)\n    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',\n                                       download=True, transform=val_transform, seed=args.seed)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',\n                                        download=True, transform=val_transform, seed=args.seed)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"train_dataset_size: \", len(train_dataset))\n    print('val_dataset_size: ', len(val_dataset))\n    print(\"test_dataset_size: \", len(test_dataset))\n    train_iter = ForeverDataIterator(train_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, 
num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,\n                                       finetune=args.finetune, pool_layer=pool_layer).to(device)\n    num_all_domains = len(train_dataset.datasets)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)\n    domain_weight_module = AutomaticUpdateDomainWeightModule(num_all_domains, args.eta, device)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)\n        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)\n        print(len(source_feature), len(target_feature))\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_val_acc1 = 0.\n    best_test_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, lr_scheduler, domain_weight_module, args.n_domains_per_batch, epoch,\n              args)\n\n        # evaluate on validation set\n        print(\"Evaluate on validation set...\")\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_val_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_val_acc1 = max(acc1, best_val_acc1)\n\n        # evaluate on test set, tracking the best test accuracy over all epochs (an oracle reference)\n        print(\"Evaluate on test set...\")\n        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test acc on test set = {}\".format(acc1))\n    print(\"oracle acc on test set = {}\".format(best_test_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,\n          domain_weight_module: AutomaticUpdateDomainWeightModule, n_domains_per_batch: int, epoch: int,\n          args: argparse.Namespace):\n
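    # NOTE: GroupDRO minimizes the worst-case weighted loss over the training domains.\n    # After the per-domain losses are computed below, the domain weights are updated with an\n    # exponentiated-gradient style step, roughly w_k <- w_k * exp(eta * loss_k) followed by\n    # renormalization, so domains that currently incur larger losses receive larger weights;\n    # see tllib.reweight.groupdro for the exact update rule.\n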
    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_all, labels_all, domain_labels = next(train_iter)\n        x_all = x_all.to(device)\n        labels_all = labels_all.to(device)\n        domain_labels = domain_labels.to(device)\n\n        # get the indices of the domains sampled in this mini-batch\n        domain_labels = domain_labels.chunk(n_domains_per_batch, dim=0)\n        sampled_domain_idxes = [domain_labels[i][0].item() for i in range(n_domains_per_batch)]\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        loss_per_domain = torch.zeros(n_domains_per_batch).to(device)\n        cls_acc = 0\n        for domain_id, (x_per_domain, labels_per_domain) in enumerate(\n                zip(x_all.chunk(n_domains_per_batch, dim=0), labels_all.chunk(n_domains_per_batch, dim=0))):\n            y_per_domain, _ = model(x_per_domain)\n            loss_per_domain[domain_id] = F.cross_entropy(y_per_domain, labels_per_domain)\n            cls_acc += accuracy(y_per_domain, labels_per_domain)[0] / n_domains_per_batch\n\n        # update domain weight\n        domain_weight_module.update(loss_per_domain, sampled_domain_idxes)\n        domain_weight = domain_weight_module.get_domain_weight(sampled_domain_idxes)\n\n        # weighted cls loss\n        loss = (loss_per_domain * domain_weight).sum()\n\n        losses.update(loss.item(), x_all.size(0))\n        cls_accs.update(cls_acc.item(), x_all.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='GroupDRO for Domain Generalization')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: PACS)')\n    parser.add_argument('-s', '--sources', nargs='+', default=None,\n                        help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', default=None,\n                        help='target domain(s)')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--freeze-bn', action='store_true', help='whether to 
freeze all bn layers')\n    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')\n    # training parameters\n    parser.add_argument('--eta', default=1e-2, type=float,\n                        help='the eta hyperparameter (step size of the domain weight update)')\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--n-domains-per-batch', default=3, type=int,\n                        help='number of domains in each mini-batch')\n    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='groupdro',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/groupdro.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, PACS\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_P\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_A\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_C\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_S\n\n# ResNet50, Office-Home\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Ar\n\n# ResNet50, DomainNet\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_s\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/ibn.sh",
    "content": "#!/usr/bin/env bash\n# IBN_ResNet50_b, PACS\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s A C S -t P -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_P\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P C S -t A -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_A\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A S -t C -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_C\nCUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A C -t S -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_S\n\n# IBN_ResNet50_b, Office-Home\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Ar\n\n# IBN_ResNet50_b, DomainNet\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_s\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/irm.py",
    "content": "\"\"\"\nAdapted from https://github.com/facebookresearch/DomainBed\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\nimport torch.autograd as autograd\n\nimport utils\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\nclass InvariancePenaltyLoss(nn.Module):\n    r\"\"\"Invariance Penalty Loss from `Invariant Risk Minimization <https://arxiv.org/pdf/1907.02893.pdf>`_.\n    We adopt implementation from `DomainBed <https://github.com/facebookresearch/DomainBed>`_. Given classifier\n    output :math:`y` and ground truth :math:`labels`, we split :math:`y` into two parts :math:`y_1, y_2`, corresponding\n    labels are :math:`labels_1, labels_2`. Next we calculate cross entropy loss with respect to a dummy classifier\n    :math:`w`, resulting in :math:`grad_1, grad_2` . Invariance penalty is then :math:`grad_1*grad_2`.\n\n    Inputs:\n        - y: predictions from model\n        - labels: ground truth\n\n    Shape:\n        - y: :math:`(N, C)` where C means the number of classes.\n        - labels: :math:`(N, )` where N mean mini-batch size\n    \"\"\"\n\n    def __init__(self):\n        super(InvariancePenaltyLoss, self).__init__()\n        self.scale = torch.tensor(1.).requires_grad_()\n\n    def forward(self, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:\n        loss_1 = F.cross_entropy(y[::2] * self.scale, labels[::2])\n        loss_2 = F.cross_entropy(y[1::2] * self.scale, labels[1::2])\n        grad_1 = autograd.grad(loss_1, [self.scale], create_graph=True)[0]\n        grad_2 = autograd.grad(loss_2, [self.scale], create_graph=True)[0]\n        penalty = torch.sum(grad_1 * grad_2)\n        return penalty\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! 
'\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=True, random_gray_scale=True)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,\n                                                   split='train', download=True, transform=train_transform,\n                                                   seed=args.seed)\n    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=args.n_domains_per_batch)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,\n                              sampler=sampler, drop_last=True)\n    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',\n                                       download=True, transform=val_transform, seed=args.seed)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',\n                                        download=True, transform=val_transform, seed=args.seed)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"train_dataset_size: \", len(train_dataset))\n    print('val_dataset_size: ', len(val_dataset))\n    print(\"test_dataset_size: \", len(test_dataset))\n    train_iter = ForeverDataIterator(train_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,\n                                       finetune=args.finetune, pool_layer=pool_layer).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)\n\n    # define loss function\n    invariance_penalty_loss = InvariancePenaltyLoss().to(device)\n\n    # for simplicity, require the annealing phase to end exactly at an epoch boundary\n    assert args.anneal_iters % args.iters_per_epoch == 0\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)\n        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)\n        print(len(source_feature), 
len(target_feature))\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_val_acc1 = 0.\n    best_test_acc1 = 0.\n    for epoch in range(args.epochs):\n        if epoch * args.iters_per_epoch == args.anneal_iters:\n            # reset optimizer to avoid sharp jump in gradient magnitudes\n            optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum,\n                            weight_decay=args.wd, nesterov=True)\n            lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch - args.anneal_iters)\n\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, lr_scheduler, invariance_penalty_loss, args.n_domains_per_batch, epoch,\n              args)\n\n        # evaluate on validation set\n        print(\"Evaluate on validation set...\")\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_val_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_val_acc1 = max(acc1, best_val_acc1)\n\n        # evaluate on test set, tracking the best test accuracy over all epochs (an oracle reference)\n        print(\"Evaluate on test set...\")\n        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test acc on test set = {}\".format(acc1))\n    print(\"oracle acc on test set = {}\".format(best_test_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,\n          invariance_penalty_loss: InvariancePenaltyLoss, n_domains_per_batch: int, epoch: int,\n          args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    losses_ce = AverageMeter('CELoss', ':3.2f')\n    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_all, labels_all, _ = next(train_iter)\n        x_all = x_all.to(device)\n        labels_all = labels_all.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_all, _ = model(x_all)\n\n        # cls loss\n        loss_ce = F.cross_entropy(y_all, labels_all)\n        # penalty loss\n
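        # NOTE: the IRMv1 penalty approximates the squared norm of the gradient of the\n        # per-domain risk with respect to a dummy classifier scale w at w = 1;\n        # InvariancePenaltyLoss follows DomainBed and estimates it as the product of two\n        # gradients computed on disjoint halves of each domain's batch, which avoids the\n        # upward bias of squaring a single noisy gradient estimate.\n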
        loss_penalty = 0\n        for y_per_domain, labels_per_domain in zip(y_all.chunk(n_domains_per_batch, dim=0),\n                                                   labels_all.chunk(n_domains_per_batch, dim=0)):\n            # normalize the penalty by the number of domains\n            loss_penalty += invariance_penalty_loss(y_per_domain, labels_per_domain) / n_domains_per_batch\n\n        global_iter = epoch * args.iters_per_epoch + i\n        if global_iter >= args.anneal_iters:\n            trade_off = args.trade_off\n        else:\n            trade_off = 1\n        loss = loss_ce + loss_penalty * trade_off\n        cls_acc = accuracy(y_all, labels_all)[0]\n\n        losses.update(loss.item(), x_all.size(0))\n        losses_ce.update(loss_ce.item(), x_all.size(0))\n        losses_penalty.update(loss_penalty.item(), x_all.size(0))\n        cls_accs.update(cls_acc.item(), x_all.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='IRM for Domain Generalization')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: PACS)')\n    parser.add_argument('-s', '--sources', nargs='+', default=None,\n                        help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', default=None,\n                        help='target domain(s)')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--freeze-bn', action='store_true', help='whether to freeze all bn layers')\n    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')\n    # training parameters\n    parser.add_argument('--trade-off', default=1, type=float,\n                        help='the trade-off hyperparameter for the IRM penalty')\n    parser.add_argument('--anneal-iters', default=500, type=int,\n                        help='annealing iterations (the trade-off is set to 1 during these iterations)')\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--n-domains-per-batch', default=3, type=int,\n                        help='number of domains in each mini-batch')\n    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,\n
                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='irm',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/irm.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, PACS\nCUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_P\nCUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_A\nCUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_C\nCUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_S\n\n# ResNet50, Office-Home\nCUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/irm/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/irm/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/irm/OfficeHome_Ar\n\n# ResNet50, DomainNet\nCUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_s\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/mixstyle.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nimport tllib.normalization.mixstyle.resnet as models\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=True, random_gray_scale=True)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,\n                                                   split='train', download=True, transform=train_transform,\n                                                   seed=args.seed)\n    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=2)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,\n                              sampler=sampler, drop_last=True)\n    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',\n                                       download=True, transform=val_transform, seed=args.seed)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',\n                                        download=True, transform=val_transform, seed=args.seed)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    print(\"train_dataset_size: \", len(train_dataset))\n    print('val_dataset_size: ', len(val_dataset))\n    print(\"test_dataset_size: \", len(test_dataset))\n    train_iter = ForeverDataIterator(train_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = models.__dict__[args.arch](mix_layers=args.mix_layers, mix_p=args.mix_p, mix_alpha=args.mix_alpha,\n                                          pretrained=True)\n    pool_layer = nn.Identity() if args.no_pool else 
None\n    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,\n                                       finetune=args.finetune, pool_layer=pool_layer).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)\n        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)\n        print(len(source_feature), len(target_feature))\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure of distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_val_acc1 = 0.\n    best_test_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        print(\"Evaluate on validation set...\")\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_val_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_val_acc1 = max(acc1, best_val_acc1)\n\n        # evaluate on test set, tracking the best test accuracy over all epochs (an oracle reference)\n        print(\"Evaluate on test set...\")\n        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test acc on test set = {}\".format(acc1))\n    print(\"oracle acc on test set = {}\".format(best_test_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, optimizer,\n          lr_scheduler: CosineAnnealingLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n
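    # NOTE: model.train() also activates the MixStyle modules inserted into the backbone\n    # above; during training, MixStyle randomly mixes per-channel feature means and\n    # standard deviations between instances in the mini-batch (mixing weights drawn from\n    # Beta(mix_alpha, mix_alpha), applied with probability mix_p), and it is a no-op in\n    # eval mode. See tllib.normalization.mixstyle for the exact implementation.\n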
    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels, _ = next(train_iter)\n        x = x.to(device)\n        labels = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y, _ = model(x)\n\n        cls_loss = F.cross_entropy(y, labels)\n        loss = cls_loss\n\n        cls_acc = accuracy(y, labels)[0]\n\n        losses.update(loss.item(), x.size(0))\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description='MixStyle for Domain Generalization')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: PACS)')\n    parser.add_argument('-s', '--sources', nargs='+', default=None,\n                        help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', default=None,\n                        help='target domain(s)')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--mix-layers', nargs='+', help='layers to apply MixStyle to')\n    parser.add_argument('--mix-p', default=0.5, type=float, help='probability of applying MixStyle')\n    parser.add_argument('--mix-alpha', default=0.1, type=float, help='alpha parameter of the Beta distribution')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--freeze-bn', action='store_true', help='whether to freeze all bn layers')\n    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n
                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='mixstyle',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/mixstyle.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, PACS\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s A C S -t P -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_P\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P C S -t A -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_A\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A S -t C -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_C\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A C -t S -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_S\n\n# ResNet50, Office-Home\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Ar\n\n# ResNet50, DomainNet\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_s\n"
  },
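`mixstyle.sh` enumerates the standard leave-one-domain-out protocol by hand: each domain takes one turn as the target while the remaining domains serve as sources. A few lines of Python can generate an equivalent command matrix (ordering may differ); this is purely illustrative, with the domain abbreviations, paths and flags copied from the script above, and nothing here is a `tllib` API:

```python
# Generate the leave-one-domain-out commands from mixstyle.sh (illustrative).
datasets = {
    'PACS': ('data/PACS', ['P', 'A', 'C', 'S'],
             '--mix-layers layer1 layer2 layer3 --freeze-bn'),
    'OfficeHome': ('data/office-home', ['Ar', 'Cl', 'Pr', 'Rw'],
                   '--mix-layers layer1 layer2'),
}

for name, (root, domains, extra) in datasets.items():
    for target in domains:
        # every other domain becomes a source for this run
        sources = ' '.join(d for d in domains if d != target)
        print(f"CUDA_VISIBLE_DEVICES=0 python mixstyle.py {root} -d {name} "
              f"-s {sources} -t {target} -a resnet50 {extra} --seed 0 "
              f"--log logs/mixstyle/{name}_{target}")
```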
  {
    "path": "examples/domain_generalization/image_classification/mldg.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\nimport higher\n\nimport utils\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=True, random_gray_scale=True)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,\n                                                   split='train', download=True, transform=train_transform,\n                                                   seed=args.seed)\n    n_domains_per_batch = args.n_support_domains + args.n_query_domains\n    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=n_domains_per_batch)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,\n                              sampler=sampler, drop_last=True)\n    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',\n                                       download=True, transform=val_transform, seed=args.seed)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',\n                                        download=True, transform=val_transform, seed=args.seed)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"train_dataset_size: \", len(train_dataset))\n    print('val_dataset_size: ', len(val_dataset))\n    print(\"test_dataset_size: \", len(test_dataset))\n    train_iter = ForeverDataIterator(train_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, 
freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,\n                                       finetune=args.finetune, pool_layer=pool_layer).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)\n        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)\n        print(len(source_feature), len(target_feature))\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_val_acc1 = 0.\n    best_test_acc1 = 0.\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, lr_scheduler, epoch, n_domains_per_batch, args)\n\n        # evaluate on validation set\n        print(\"Evaluate on validation set...\")\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_val_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_val_acc1 = max(acc1, best_val_acc1)\n\n        # evaluate on test set\n        print(\"Evaluate on test set...\")\n        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test acc on test set = {}\".format(acc1))\n    print(\"oracle acc on test set = {}\".format(best_test_acc1))\n    logger.close()\n\n\ndef random_split(x_list, labels_list, n_domains_per_batch, n_support_domains):\n    assert n_support_domains < n_domains_per_batch\n\n    support_domain_idxes = random.sample(range(n_domains_per_batch), n_support_domains)\n    support_domain_list = [(x_list[idx], labels_list[idx]) for idx in range(n_domains_per_batch) if\n                           idx in support_domain_idxes]\n    query_domain_list = [(x_list[idx], labels_list[idx]) for idx in range(n_domains_per_batch) if\n                         idx not in support_domain_idxes]\n\n    return 
support_domain_list, query_domain_list\n\n\ndef train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR, epoch: int,\n          n_domains_per_batch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels, _ = next(train_iter)\n        x = x.to(device)\n        labels = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # split into support domain and query domain\n        x_list = x.chunk(n_domains_per_batch, dim=0)\n        labels_list = labels.chunk(n_domains_per_batch, dim=0)\n        support_domain_list, query_domain_list = random_split(x_list, labels_list, n_domains_per_batch,\n                                                              args.n_support_domains)\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        with higher.innerloop_ctx(model, optimizer, copy_initial_weights=False) as (inner_model, inner_optimizer):\n            # perform inner optimization\n            for _ in range(args.inner_iters):\n                loss_inner = 0\n                for (x_s, labels_s) in support_domain_list:\n                    y_s, _ = inner_model(x_s)\n                    # normalize loss by support domain num\n                    loss_inner += F.cross_entropy(y_s, labels_s) / args.n_support_domains\n\n                inner_optimizer.step(loss_inner)\n\n            # calculate outer loss\n            loss_outer = 0\n            cls_acc = 0\n\n            # loss on support domains\n            for (x_s, labels_s) in support_domain_list:\n                y_s, _ = model(x_s)\n                # normalize loss by support domain num\n                loss_outer += F.cross_entropy(y_s, labels_s) / args.n_support_domains\n\n            # loss on query domains\n            for (x_q, labels_q) in query_domain_list:\n                y_q, _ = inner_model(x_q)\n                # normalize loss by query domain num\n                loss_outer += F.cross_entropy(y_q, labels_q) * args.trade_off / args.n_query_domains\n                cls_acc += accuracy(y_q, labels_q)[0] / args.n_query_domains\n\n        # update statistics\n        losses.update(loss_outer.item(), args.batch_size)\n        cls_accs.update(cls_acc.item(), args.batch_size)\n\n        # compute gradient and do SGD step\n        loss_outer.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Meta Learning for Domain Generalization')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: PACS)')\n    parser.add_argument('-s', 
'--sources', nargs='+', default=None,\n                        help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', default=None,\n                        help='target domain(s)')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use 10x smaller lr for backbone')\n    parser.add_argument('--freeze-bn', action='store_true', help='whether to freeze all bn layers')\n    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')\n    # training parameters\n    parser.add_argument('--n-support-domains', type=int, default=1,\n                        help='Number of support domains sampled in each iteration')\n    parser.add_argument('--n-query-domains', type=int, default=2,\n                        help='Number of query domains in each iteration')\n    parser.add_argument('--trade-off', type=float, default=1,\n                        help='hyper-parameter beta')\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('--inner-iters', default=1, type=int,\n                        help='Number of iterations in inner loop')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='mldg',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
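The heart of `mldg.py` is the bi-level update done with the `higher` library: one or more inner SGD steps on the support domains inside `higher.innerloop_ctx`, then an outer loss that adds the un-adapted support loss to the query loss of the adapted copy, so that gradients flow through the inner step. The stripped-down sketch below shows just that control flow on a toy regression problem; the model, data and learning rate are placeholders, not the script's values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher

# Toy stand-ins for one support batch and one query batch.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x_support, y_support = torch.randn(8, 10), torch.randn(8, 1)
x_query, y_query = torch.randn(8, 10), torch.randn(8, 1)

optimizer.zero_grad()
with higher.innerloop_ctx(model, optimizer, copy_initial_weights=False) as (fmodel, diffopt):
    # Inner step: adapt a differentiable copy of the model on the support data.
    inner_loss = F.mse_loss(fmodel(x_support), y_support)
    diffopt.step(inner_loss)

    # Outer loss: un-adapted support loss plus adapted query loss,
    # mirroring loss_outer in mldg.py (the trade-off weight is omitted here).
    outer_loss = (F.mse_loss(model(x_support), y_support)
                  + F.mse_loss(fmodel(x_query), y_query))

# Backward outside the context still works: the graph built by the inner
# step survives, so second-order gradients reach model.parameters().
outer_loss.backward()
optimizer.step()
```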
  {
    "path": "examples/domain_generalization/image_classification/mldg.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, PACS\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_P\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_A\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_C\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_S\n\n# ResNet50, Office-Home\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Ar\n\n# ResNet50, DomainNet\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_s\n"
  },
  {
    "path": "examples/domain_generalization/image_classification/requirements.txt",
    "content": "timm\nwilds\nhigher"
  },
  {
    "path": "examples/domain_generalization/image_classification/utils.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport copy\nimport random\nimport sys\nimport time\nimport timm\nimport tqdm\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms as T\nimport torch.nn.functional as F\nimport numpy as np\nfrom torch.utils.data import Sampler, Subset, ConcatDataset\n\nsys.path.append('../../..')\nfrom tllib.modules import Classifier as ClassifierBase\nimport tllib.vision.datasets as datasets\nimport tllib.vision.models as models\nimport tllib.normalization.ibn as ibn_models\nfrom tllib.vision.transforms import ResizeImage\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\ndef get_model_names():\n    return sorted(name for name in models.__dict__ if\n                  name.islower() and not name.startswith(\"__\") and callable(models.__dict__[name])) + \\\n           sorted(name for name in ibn_models.__dict__ if\n                  name.islower() and not name.startswith(\"__\") and callable(ibn_models.__dict__[name])) + \\\n           timm.list_models()\n\n\ndef get_model(model_name):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=True)\n    elif model_name in ibn_models.__dict__:\n        # load models (with ibn) from tllib.normalization.ibn\n        backbone = ibn_models.__dict__[model_name](pretrained=True)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=True)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n    return backbone\n\n\ndef get_dataset_names():\n    return sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n\nclass ConcatDatasetWithDomainLabel(ConcatDataset):\n    \"\"\"ConcatDataset with domain label\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(ConcatDatasetWithDomainLabel, self).__init__(*args, **kwargs)\n        self.index_to_domain_id = {}\n        domain_id = 0\n        start = 0\n        for end in self.cumulative_sizes:\n            for idx in range(start, end):\n                self.index_to_domain_id[idx] = domain_id\n            start = end\n            domain_id += 1\n\n    def __getitem__(self, index):\n        img, target = super(ConcatDatasetWithDomainLabel, self).__getitem__(index)\n        domain_id = self.index_to_domain_id[index]\n        return img, target, domain_id\n\n\ndef get_dataset(dataset_name, root, task_list, split='train', download=True, transform=None, seed=0):\n    assert split in ['train', 'val', 'test']\n    # load datasets from tllib.vision.datasets\n    # currently only PACS, OfficeHome and DomainNet are supported\n    supported_dataset = ['PACS', 'OfficeHome', 'DomainNet']\n    assert dataset_name in supported_dataset\n\n    dataset = datasets.__dict__[dataset_name]\n\n    train_split_list = []\n    val_split_list = []\n    test_split_list = []\n    # we follow DomainBed and split each dataset randomly into two parts, with 80% samples and 20% samples\n    # respectively, the former (larger) will be used as training set, and the latter will be used as validation set.\n    split_ratio = 0.8\n    num_classes = 0\n\n    # 
under the domain generalization setting, we use all samples in the target domain as the test set\n    for task in task_list:\n        if dataset_name == 'PACS':\n            all_split = dataset(root=root, task=task, split='all', download=download, transform=transform)\n            num_classes = all_split.num_classes\n        elif dataset_name == 'OfficeHome':\n            all_split = dataset(root=root, task=task, download=download, transform=transform)\n            num_classes = all_split.num_classes\n        elif dataset_name == 'DomainNet':\n            train_split = dataset(root=root, task=task, split='train', download=download, transform=transform)\n            test_split = dataset(root=root, task=task, split='test', download=download, transform=transform)\n            num_classes = train_split.num_classes\n            all_split = ConcatDataset([train_split, test_split])\n\n        train_split, val_split = split_dataset(all_split, int(len(all_split) * split_ratio), seed)\n\n        train_split_list.append(train_split)\n        val_split_list.append(val_split)\n        test_split_list.append(all_split)\n\n    train_dataset = ConcatDatasetWithDomainLabel(train_split_list)\n    val_dataset = ConcatDatasetWithDomainLabel(val_split_list)\n    test_dataset = ConcatDatasetWithDomainLabel(test_split_list)\n\n    dataset_dict = {\n        'train': train_dataset,\n        'val': val_dataset,\n        'test': test_dataset\n    }\n    return dataset_dict[split], num_classes\n\n\ndef split_dataset(dataset, n, seed=0):\n    \"\"\"\n    Return a pair of datasets corresponding to a random split of the given\n    dataset, with n data points in the first dataset and the rest in the second,\n    using the given random seed\n    \"\"\"\n    assert n <= len(dataset)\n    idxes = list(range(len(dataset)))\n    np.random.RandomState(seed).shuffle(idxes)\n    subset_1 = idxes[:n]\n    subset_2 = idxes[n:]\n    return Subset(dataset, subset_1), Subset(dataset, subset_2)\n\n\ndef validate(val_loader, model, args, device) -> float:\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    top1 = AverageMeter('Acc@1', ':6.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, top1],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target, _) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            loss = F.cross_entropy(output, target)\n\n            # measure accuracy and record loss\n            acc1 = accuracy(output, target)[0]\n            losses.update(loss.item(), images.size(0))\n            top1.update(acc1.item(), images.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1))\n\n    return top1.avg\n\n\ndef get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=True,\n                        random_gray_scale=True):\n    \"\"\"\n    resizing mode:\n        - default: random resized crop with scale factor (0.7, 1.0) and size 224;\n        - cen.crop: take the center crop of 224;\n        - res.|cen.crop: resize the image to 256 and take the center crop of size 
224;\n        - res: resize the image to 224;\n        - res2x: resize the image to 448;\n        - res.|crop: resize the image to 256 and take a random crop of size 224;\n        - res.sma|crop: resize the image keeping its aspect ratio such that the\n            smaller side is 256, then take a random crop of size 224;\n        - inc.crop: 'inception crop' from (Szegedy et al., 2015);\n        - cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.RandomResizedCrop(224, scale=(0.7, 1.0))\n    elif resizing == 'cen.crop':\n        transform = T.CenterCrop(224)\n    elif resizing == 'res.|cen.crop':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224)\n        ])\n    elif resizing == 'res':\n        transform = ResizeImage(224)\n    elif resizing == 'res2x':\n        transform = ResizeImage(448)\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize((256, 256)),\n            T.RandomCrop(224)\n        ])\n    elif resizing == \"res.sma|crop\":\n        transform = T.Compose([\n            T.Resize(256),\n            T.RandomCrop(224)\n        ])\n    elif resizing == 'inc.crop':\n        transform = T.RandomResizedCrop(224)\n    elif resizing == 'cif.crop':\n        transform = T.Compose([\n            T.Resize((224, 224)),\n            T.Pad(28),\n            T.RandomCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        transforms.append(T.RandomHorizontalFlip())\n    if random_color_jitter:\n        transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))\n    if random_gray_scale:\n        transforms.append(T.RandomGrayscale())\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n    return T.Compose(transforms)\n\n\ndef get_val_transform(resizing='default'):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 224;\n        - res2x: resize the image to 448;\n        - res.|cen.crop: resize the image to 256 and take the center crop of size 224;\n    \"\"\"\n    if resizing == 'default':\n        transform = ResizeImage(224)\n    elif resizing == 'res2x':\n        transform = ResizeImage(448)\n    elif resizing == 'res.|cen.crop':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    return T.Compose([\n        transform,\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n\n\ndef collect_feature(data_loader, feature_extractor: nn.Module, device: torch.device,\n                    max_num_features=None) -> torch.Tensor:\n    \"\"\"\n    Fetch data from `data_loader`, and then use `feature_extractor` to collect features. 
This function is\n    specific for domain generalization because each element in data_loader is a tuple\n    (images, labels, domain_labels).\n\n    Args:\n        data_loader (torch.utils.data.DataLoader): Data loader.\n        feature_extractor (torch.nn.Module): A feature extractor.\n        device (torch.device): The device to compute features on.\n        max_num_features (int): The max number of mini-batches to collect features from\n\n    Returns:\n        Features in shape (min(len(data_loader.dataset), max_num_features * mini-batch size), :math:`|\\mathcal{F}|`).\n    \"\"\"\n    feature_extractor.eval()\n    all_features = []\n    with torch.no_grad():\n        for i, (images, target, domain_labels) in enumerate(tqdm.tqdm(data_loader)):\n            if max_num_features is not None and i >= max_num_features:\n                break\n            images = images.to(device)\n            feature = feature_extractor(images).cpu()\n            all_features.append(feature)\n    return torch.cat(all_features, dim=0)\n\n\nclass ImageClassifier(ClassifierBase):\n    \"\"\"ImageClassifier specific for reproducing results of `DomainBed <https://github.com/facebookresearch/DomainBed>`_.\n    You are free to freeze all `BatchNorm2d` layers and insert one additional `Dropout` layer; this can achieve better\n    results for some datasets like PACS but may be worse for others.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract features from data\n        num_classes (int): Number of classes\n        freeze_bn (bool, optional): whether to freeze all `BatchNorm2d` layers. Default: False\n        dropout_p (float, optional): dropout ratio for the additional `Dropout` layer, which is only used when `freeze_bn` is True. Default: 0.1\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, freeze_bn=False, dropout_p=0.1, **kwargs):\n        super(ImageClassifier, self).__init__(backbone, num_classes, **kwargs)\n        self.freeze_bn = freeze_bn\n        if freeze_bn:\n            self.feature_dropout = nn.Dropout(p=dropout_p)\n\n    def forward(self, x: torch.Tensor):\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        if self.freeze_bn:\n            f = self.feature_dropout(f)\n        predictions = self.head(f)\n        if self.training:\n            return predictions, f\n        else:\n            return predictions\n\n    def train(self, mode=True):\n        super(ImageClassifier, self).train(mode)\n        if self.freeze_bn:\n            for m in self.modules():\n                if isinstance(m, nn.BatchNorm2d):\n                    m.eval()\n\n\nclass RandomDomainSampler(Sampler):\n    r\"\"\"Randomly sample :math:`N` domains, then randomly select :math:`K` samples in each domain to form a mini-batch of\n    size :math:`N\\times K`.\n\n    Args:\n        data_source (ConcatDataset): dataset that contains data from multiple domains\n        batch_size (int): mini-batch size (:math:`N\\times K` here)\n        n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here)\n    \"\"\"\n\n    def __init__(self, data_source: ConcatDataset, batch_size: int, n_domains_per_batch: int):\n        super(Sampler, self).__init__()\n        self.n_domains_in_dataset = len(data_source.cumulative_sizes)\n        self.n_domains_per_batch = n_domains_per_batch\n        assert self.n_domains_in_dataset >= self.n_domains_per_batch\n\n        self.sample_idxes_per_domain = []\n        start = 0\n        for end in data_source.cumulative_sizes:\n            idxes = [idx 
for idx in range(start, end)]\n            self.sample_idxes_per_domain.append(idxes)\n            start = end\n\n        assert batch_size % n_domains_per_batch == 0\n        self.batch_size_per_domain = batch_size // n_domains_per_batch\n        self.length = len(list(self.__iter__()))\n\n    def __iter__(self):\n        sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain)\n        domain_idxes = [idx for idx in range(self.n_domains_in_dataset)]\n        final_idxes = []\n        stop_flag = False\n        while not stop_flag:\n            selected_domains = random.sample(domain_idxes, self.n_domains_per_batch)\n\n            for domain in selected_domains:\n                sample_idxes = sample_idxes_per_domain[domain]\n                if len(sample_idxes) < self.batch_size_per_domain:\n                    selected_idxes = np.random.choice(sample_idxes, self.batch_size_per_domain, replace=True)\n                else:\n                    selected_idxes = random.sample(sample_idxes, self.batch_size_per_domain)\n                final_idxes.extend(selected_idxes)\n\n                for idx in selected_idxes:\n                    if idx in sample_idxes_per_domain[domain]:\n                        sample_idxes_per_domain[domain].remove(idx)\n\n                remaining_size = len(sample_idxes_per_domain[domain])\n                if remaining_size < self.batch_size_per_domain:\n                    stop_flag = True\n\n        return iter(final_idxes)\n\n    def __len__(self):\n        return self.length\n"
  },
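`RandomDomainSampler` in `utils.py` is what lets the training scripts split each batch by domain with `x.chunk(n_domains_per_batch, dim=0)`: every mini-batch holds exactly N domains with K consecutive samples each. The snippet below checks that invariant on toy data; it assumes it is run next to `utils.py`, and the domain sizes are arbitrary:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from utils import RandomDomainSampler  # the sampler defined above

# Three toy "domains" of different sizes; each label stores its domain id.
domains = [TensorDataset(torch.randn(n, 3), torch.full((n,), i, dtype=torch.long))
           for i, n in enumerate([50, 40, 30])]
dataset = ConcatDataset(domains)

batch_size, n_domains_per_batch = 12, 3   # K = 4 samples per domain per batch
sampler = RandomDomainSampler(dataset, batch_size, n_domains_per_batch)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)

_, domain_ids = next(iter(loader))
# Each contiguous chunk of the batch is domain-pure, which is exactly
# what chunk(n_domains_per_batch, dim=0) relies on in the training loops.
for chunk in domain_ids.chunk(n_domains_per_batch, dim=0):
    assert chunk.unique().numel() == 1
print(domain_ids.tolist())
```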
  {
    "path": "examples/domain_generalization/image_classification/vrex.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import CosineAnnealingLR\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.analysis import tsne, a_distance\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                                random_color_jitter=True, random_gray_scale=True)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,\n                                                   split='train', download=True, transform=train_transform,\n                                                   seed=args.seed)\n    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, args.n_domains_per_batch)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,\n                              sampler=sampler, drop_last=True)\n    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',\n                                       download=True, transform=val_transform, seed=args.seed)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',\n                                        download=True, transform=val_transform, seed=args.seed)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"train_dataset_size: \", len(train_dataset))\n    print('val_dataset_size: ', len(val_dataset))\n    print(\"test_dataset_size: \", len(test_dataset))\n    train_iter = ForeverDataIterator(train_loader)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,\n                                       finetune=args.finetune, 
pool_layer=pool_layer).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)\n\n    # for simplicity, require annealing to end exactly at an epoch boundary\n    assert args.anneal_iters % args.iters_per_epoch == 0\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # extract features from both domains\n        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)\n        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)\n        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)\n        print(len(source_feature), len(target_feature))\n        # plot t-SNE\n        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')\n        tsne.visualize(source_feature, target_feature, tSNE_filename)\n        print(\"Saving t-SNE to\", tSNE_filename)\n        # calculate A-distance, which is a measure for distribution discrepancy\n        A_distance = a_distance.calculate(source_feature, target_feature, device)\n        print(\"A-distance =\", A_distance)\n        return\n\n    if args.phase == 'test':\n        acc1 = utils.validate(test_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_val_acc1 = 0.\n    best_test_acc1 = 0.\n    for epoch in range(args.epochs):\n        if epoch * args.iters_per_epoch == args.anneal_iters:\n            # reset optimizer to avoid sharp jump in gradient magnitudes\n            optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum,\n                            weight_decay=args.wd, nesterov=True)\n            lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch - args.anneal_iters)\n\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, lr_scheduler, args.n_domains_per_batch, epoch, args)\n\n        # evaluate on validation set\n        print(\"Evaluate on validation set...\")\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_val_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_val_acc1 = max(acc1, best_val_acc1)\n\n        # evaluate on test set\n        print(\"Evaluate on test set...\")\n        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))\n\n    # evaluate on test set\n    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    acc1 = utils.validate(test_loader, classifier, args, device)\n    print(\"test acc on test set = {}\".format(acc1))\n    print(\"oracle acc on test set = {}\".format(best_test_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,\n          n_domains_per_batch: int, epoch: int, args: argparse.Namespace):\n 
   batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    losses_ce = AverageMeter('CELoss', ':3.2f')\n    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x_all, labels_all, _ = next(train_iter)\n        x_all = x_all.to(device)\n        labels_all = labels_all.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_all, _ = model(x_all)\n\n        loss_ce_per_domain = torch.zeros(n_domains_per_batch).to(device)\n        for domain_id, (y_per_domain, labels_per_domain) in enumerate(\n                zip(y_all.chunk(n_domains_per_batch, dim=0), labels_all.chunk(n_domains_per_batch, dim=0))):\n            loss_ce_per_domain[domain_id] = F.cross_entropy(y_per_domain, labels_per_domain)\n\n        # cls loss\n        loss_ce = loss_ce_per_domain.mean()\n        # penalty loss\n        loss_penalty = ((loss_ce_per_domain - loss_ce) ** 2).mean()\n\n        global_iter = epoch * args.iters_per_epoch + i\n        if global_iter >= args.anneal_iters:\n            trade_off = args.trade_off\n        else:\n            trade_off = 1\n\n        loss = loss_ce + loss_penalty * trade_off\n        cls_acc = accuracy(y_all, labels_all)[0]\n\n        losses.update(loss.item(), x_all.size(0))\n        losses_ce.update(loss_ce.item(), x_all.size(0))\n        losses_penalty.update(loss_penalty.item(), x_all.size(0))\n        cls_accs.update(cls_acc.item(), x_all.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='VREx for Domain Generalization')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +\n                             ' (default: PACS)')\n    parser.add_argument('-s', '--sources', nargs='+', default=None,\n                        help='source domain(s)')\n    parser.add_argument('-t', '--targets', nargs='+', default=None,\n                        help='target domain(s)')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    parser.add_argument('--val-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    
parser.add_argument('--finetune', action='store_true', help='whether to use 10x smaller lr for backbone')\n    parser.add_argument('--freeze-bn', action='store_true', help='whether to freeze all bn layers')\n    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')\n    # training parameters\n    parser.add_argument('--trade-off', default=3, type=float,\n                        help='the trade-off hyper-parameter for the VREx penalty')\n    parser.add_argument('--anneal-iters', default=500, type=int,\n                        help='anneal iterations (trade-off is set to 1 during these iterations)')\n    parser.add_argument('-b', '--batch-size', default=36, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 36)')\n    parser.add_argument('--n-domains-per-batch', default=3, type=int,\n                        help='number of domains in each mini-batch')\n    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", type=str, default='vrex',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
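The VREx objective implemented in `vrex.py` is simply the mean per-domain risk plus a penalty on the variance of those risks, with the penalty weight annealed from 1 to `--trade-off` after `--anneal-iters` steps. Stripped of the training loop, the penalty is a two-liner; the per-domain losses below are made up for illustration:

```python
import torch

# Made-up per-domain cross-entropy risks for one mini-batch (3 domains).
risk_per_domain = torch.tensor([0.9, 1.4, 0.7])

risk_mean = risk_per_domain.mean()
# VREx penalty: (biased) variance of the per-domain risks, as in vrex.py.
penalty = ((risk_per_domain - risk_mean) ** 2).mean()

trade_off = 3.0  # the default --trade-off, used once annealing has ended
loss = risk_mean + trade_off * penalty
print(float(loss))  # pushes the model towards equal risk across domains
```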
  {
    "path": "examples/domain_generalization/image_classification/vrex.sh",
    "content": "#!/usr/bin/env bash\n# ResNet50, PACS\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_P\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_A\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_C\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_S\n\n# ResNet50, Office-Home\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Pr\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Rw\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Cl\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Ar\n\n# ResNet50, DomainNet\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_c\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_i\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_p\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_q\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_r\nCUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_s\n"
  },
  {
    "path": "examples/domain_generalization/re_identification/README.md",
    "content": "# Domain Generalization for Person Re-Identification\n\n## Installation\n\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You\nalso need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [Market1501](http://zheng-lab.cecs.anu.edu.au/Project/project_reid.html)\n- [DukeMTMC](https://exposing.ai/duke_mtmc/)\n- [MSMT17](https://arxiv.org/pdf/1711.08565.pdf)\n\n## Supported Methods\n\nSupported methods include:\n\n- [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)\n- [Domain Generalization with MixStyle (MixStyle, 2021 ICLR)](https://arxiv.org/abs/2104.02008)\n\n## Usage\n\nThe shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to\ntrain MixStyle on Market1501 -> DukeMTMC task, use the following script\n\n```shell script\n# Train MixStyle on Market1501 -> DukeMTMC task using ResNet 50.\n# Assume you have put the datasets under the path `data/market1501` and `data/dukemtmc`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t DukeMTMC -a resnet50 \\\n--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2Duke\n```\n\n### Experiment and Results\n\nIn our experiments, we adopt modified resnet architecture from [MMT](https://arxiv.org/pdf/2001.01526.pdf>). For a fair\ncomparison, we use standard cross entropy loss and triplet loss in all methods.\n\n**Notations**\n\n- ``Avg`` means the mAP (mean average precision) reported by `TLlib`.\n\n### Cross dataset mAP on ResNet-50\n\n| Methods  | Avg  | Market2Duke | Duke2Market | Market2MSMT | MSMT2Market | Duke2MSMT | MSMT2Duke |\n|----------|------|-------------|-------------|-------------|-------------|-----------|-----------|\n| Baseline | 23.5 | 25.6        | 29.6        | 6.3         | 31.7        | 10.1      | 37.8      |\n| IBN      | 27.0 | 31.5        | 33.3        | 10.4        | 33.6        | 13.7      | 40.0      |\n| MixStyle | 25.5 | 27.2        | 31.6        | 8.2         | 33.9        | 12.4      | 39.9      |\n\n## Citation\n\nIf you use these methods in your research, please consider citing.\n\n```\n@inproceedings{IBN-Net,  \n    author = {Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang},  \n    title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net},  \n    booktitle = {ECCV},  \n    year = {2018}  \n}\n\n@inproceedings{mixstyle,\n    title={Domain Generalization with MixStyle},\n    author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},\n    booktitle={ICLR},\n    year={2021}\n}\n```"
  },
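`baseline.py` below optimizes `CrossEntropyLossWithLabelSmooth` plus a soft triplet loss, which is what the README's "standard cross entropy loss and triplet loss" refers to. Label-smoothed cross entropy is easy to state explicitly; this sketch follows the common epsilon-smoothing formulation (epsilon=0.1 is a conventional default, not necessarily the value used by `tllib`):

```python
import torch
import torch.nn.functional as F


def label_smooth_ce(logits: torch.Tensor, targets: torch.Tensor, epsilon: float = 0.1) -> torch.Tensor:
    """Cross entropy against (1 - eps) * one-hot + eps * uniform target distribution."""
    num_classes = logits.size(1)
    log_probs = F.log_softmax(logits, dim=1)
    # uniform mass eps / C everywhere, extra (1 - eps) mass on the true class
    smooth = torch.full_like(log_probs, epsilon / num_classes)
    smooth.scatter_(1, targets.unsqueeze(1), 1.0 - epsilon + epsilon / num_classes)
    return (-smooth * log_probs).sum(dim=1).mean()


# Example: 4 samples, 5 identities.
logits = torch.randn(4, 5)
targets = torch.tensor([0, 2, 4, 1])
print(label_smooth_ce(logits, targets).item())
```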
  {
    "path": "examples/domain_generalization/re_identification/baseline.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import DataParallel\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss\nfrom tllib.vision.models.reid.identifier import ReIdentifier\nimport tllib.vision.datasets.reid as datasets\nfrom tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset\nfrom tllib.utils.scheduler import WarmupMultiStepLR\nfrom tllib.utils.metric.reid import validate, visualize_ranked_results\nfrom tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,\n                                                random_horizontal_flip=True,\n                                                random_color_jitter=False,\n                                                random_gray_scale=False)\n    val_transform = utils.get_val_transform(args.height, args.width)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    working_dir = osp.dirname(osp.abspath(__file__))\n    root = osp.join(working_dir, args.root)\n\n    # source dataset\n    source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower()))\n    sampler = RandomMultipleGallerySampler(source_dataset.train, args.num_instances)\n    train_loader = DataLoader(\n        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),\n                                   root=source_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # target dataset\n    target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower()))\n    test_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),\n                                   
root=target_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # create model\n    num_classes = source_dataset.num_train_pids\n    backbone = utils.get_model(args.arch)\n    pool_layer = nn.Identity() if args.no_pool else None\n    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)\n    model = DataParallel(model)\n\n    # define optimizer and learning rate scheduler\n    optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,\n                     weight_decay=args.weight_decay)\n    lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1,\n                                     warmup_steps=args.warmup_steps)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # plot t-SNE\n        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,\n                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)\n        # visualize ranked results\n        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,\n                                 rerank=args.rerank)\n        return\n\n    if args.phase == 'test':\n        print(\"Test on source domain:\")\n        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        print(\"Test on target domain:\")\n        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        return\n\n    # define loss function\n    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)\n    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)\n\n    # start training\n    best_val_mAP = 0.\n    best_test_mAP = 0.\n    for epoch in range(args.epochs):\n        # print learning rate\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(train_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args)\n\n        # update learning rate\n        lr_scheduler.step()\n\n        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):\n\n            # evaluate on validation set\n            print(\"Validation on source domain...\")\n            _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device,\n                                  cmc_flag=True)\n\n            # remember best mAP and save checkpoint\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if val_mAP > best_val_mAP:\n                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n            best_val_mAP = max(val_mAP, best_val_mAP)\n\n            # evaluate on test set\n            print(\"Test on target domain...\")\n            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                        
           cmc_flag=True, rerank=args.rerank)\n            best_test_mAP = max(test_mAP, best_test_mAP)\n\n    # evaluate on test set\n    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    print(\"Test on target domain:\")\n    _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                           cmc_flag=True, rerank=args.rerank)\n    print(\"test mAP on target = {}\".format(test_mAP))\n    print(\"oracle mAP on target = {}\".format(best_test_mAP))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, criterion_ce: CrossEntropyLossWithLabelSmooth,\n          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_ce = AverageMeter('CeLoss', ':3.2f')\n    losses_triplet = AverageMeter('TripletLoss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        x, _, labels, _ = next(train_iter)\n        x = x.to(device)\n        labels = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y, f = model(x)\n\n        # cross entropy loss\n        loss_ce = criterion_ce(y, labels)\n        # triplet loss\n        loss_triplet = criterion_triplet(f, f, labels)\n        loss = loss_ce + loss_triplet * args.trade_off\n\n        cls_acc = accuracy(y, labels)[0]\n        losses_ce.update(loss_ce.item(), x.size(0))\n        losses_triplet.update(loss_triplet.item(), x.size(0))\n        losses.update(loss.item(), x.size(0))\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do an optimizer step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description=\"Baseline for Domain Generalizable ReID\")\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-s', '--source', type=str, help='source domain')\n    parser.add_argument('-t', '--target', type=str, help='target domain')\n    parser.add_argument('--train-resizing', type=str, default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: reid_resnet50)')\n    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')\n    parser.add_argument('--finetune', action='store_true', help='whether to 
use 10x smaller lr for backbone')\n    parser.add_argument('--rate', type=float, default=0.2)\n    # training parameters\n    parser.add_argument('--trade-off', type=float, default=1,\n                        help='trade-off hyper-parameter between cross entropy loss and triplet loss')\n    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')\n    parser.add_argument('-j', '--workers', type=int, default=4)\n    parser.add_argument('-b', '--batch-size', type=int, default=16)\n    parser.add_argument('--height', type=int, default=256, help=\"input height\")\n    parser.add_argument('--width', type=int, default=128, help=\"input width\")\n    parser.add_argument('--num-instances', type=int, default=4,\n                        help=\"each minibatch consists of \"\n                             \"(batch_size // num_instances) identities, and \"\n                             \"each identity has num_instances instances, \"\n                             \"default: 4\")\n    parser.add_argument('--lr', type=float, default=0.00035,\n                        help=\"initial learning rate\")\n    parser.add_argument('--weight-decay', type=float, default=5e-4)\n    parser.add_argument('--epochs', type=int, default=80)\n    parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps')\n    parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70],\n                        help='milestones for the learning rate decay')\n    parser.add_argument('--eval-step', type=int, default=40)\n    parser.add_argument('--iters-per-epoch', type=int, default=400)\n    parser.add_argument('--print-freq', type=int, default=40)\n    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')\n    parser.add_argument('--rerank', action='store_true', help=\"whether to use re-ranking in evaluation\")\n    parser.add_argument(\"--log\", type=str, default='baseline',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
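  {
    "path": "examples/domain_generalization/re_identification/loss_sketch.py",
    "content": "\"\"\"Illustrative sketch (standalone, not imported by baseline.py): the training\nobjective in baseline.py combines cross entropy with label smoothing and a triplet\nloss as loss = loss_ce + trade_off * loss_triplet. The batch-hard triplet loss below\nis a generic stand-in for tllib's SoftTripletLoss; names and details are assumptions,\nnot the library implementation.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\n\ndef label_smooth_ce(logits, targets, eps=0.1):\n    # cross entropy against a uniformly smoothed target distribution\n    n = logits.size(1)\n    log_p = F.log_softmax(logits, dim=1)\n    smooth = torch.full_like(log_p, eps / n)\n    smooth.scatter_(1, targets.unsqueeze(1), 1 - eps + eps / n)\n    return (-smooth * log_p).sum(dim=1).mean()\n\n\ndef batch_hard_triplet(features, labels, margin=0.0):\n    # pairwise euclidean distances within the batch\n    dist = torch.cdist(features, features)\n    same = labels.unsqueeze(0) == labels.unsqueeze(1)\n    # hardest positive: farthest sample with the same identity\n    d_pos = dist.masked_fill(~same, 0.).max(dim=1).values\n    # hardest negative: closest sample with a different identity\n    d_neg = dist.masked_fill(same, float('inf')).min(dim=1).values\n    if margin > 0:\n        return F.relu(d_pos - d_neg + margin).mean()\n    # soft-margin variant when margin == 0, as with the default --margin 0.0\n    return F.softplus(d_pos - d_neg).mean()\n\n\nif __name__ == '__main__':\n    logits = torch.randn(8, 10)\n    feats = F.normalize(torch.randn(8, 128), dim=1)\n    labels = torch.randint(0, 4, (8,))\n    trade_off = 1.0\n    loss = label_smooth_ce(logits, labels) + trade_off * batch_hard_triplet(feats, labels)\n    print(loss.item())\n"
  },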
  {
    "path": "examples/domain_generalization/re_identification/baseline.sh",
    "content": "#!/usr/bin/env bash\n# Market1501 -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \\\n--finetune --seed 0 --log logs/baseline/Market2Duke\n\n# Duke -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a reid_resnet50 \\\n--finetune --seed 0 --log logs/baseline/Duke2Market\n\n# Market1501 -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a reid_resnet50 \\\n--finetune --seed 0 --log logs/baseline/Market2MSMT\n\n# MSMT -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a reid_resnet50 \\\n--finetune --seed 0 --log logs/baseline/MSMT2Market\n\n# Duke -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a reid_resnet50 \\\n--finetune --seed 0 --log logs/baseline/Duke2MSMT\n\n# MSMT -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a reid_resnet50 \\\n--finetune --seed 0 --log logs/baseline/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_generalization/re_identification/ibn.sh",
    "content": "#!/usr/bin/env bash\n# Market1501 -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_a \\\n--finetune --seed 0 --log logs/ibn/Market2Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_b \\\n--finetune --seed 0 --log logs/ibn/Market2Duke\n\n# Duke -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_a \\\n--finetune --seed 0 --log logs/ibn/Duke2Market\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_b \\\n--finetune --seed 0 --log logs/ibn/Duke2Market\n\n# Market1501 -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_a \\\n--finetune --seed 0 --log logs/ibn/Market2MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_b \\\n--finetune --seed 0 --log logs/ibn/Market2MSMT\n\n# MSMT -> Market1501\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_a \\\n--finetune --seed 0 --log logs/ibn/MSMT2Market\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_b \\\n--finetune --seed 0 --log logs/ibn/MSMT2Market\n\n# Duke -> MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_a \\\n--finetune --seed 0 --log logs/ibn/Duke2MSMT\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_b \\\n--finetune --seed 0 --log logs/ibn/Duke2MSMT\n\n# MSMT -> Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_a \\\n--finetune --seed 0 --log logs/ibn/MSMT2Duke\nCUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_b \\\n--finetune --seed 0 --log logs/ibn/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_generalization/re_identification/mixstyle.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os.path as osp\n\nimport numpy as np\nimport torch\nfrom torch.nn import DataParallel\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import Adam\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.normalization.mixstyle.sampler import RandomDomainMultiInstanceSampler\nimport tllib.normalization.mixstyle.resnet as models\nfrom tllib.vision.models.reid.identifier import ReIdentifier\nfrom tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss\nimport tllib.vision.datasets.reid as datasets\nfrom tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset\nfrom tllib.vision.models.reid.resnet import ReidResNet\nfrom tllib.utils.scheduler import WarmupMultiStepLR\nfrom tllib.utils.metric.reid import validate, visualize_ranked_results\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,\n                                                random_horizontal_flip=True,\n                                                random_color_jitter=False,\n                                                random_gray_scale=False)\n    val_transform = utils.get_val_transform(args.height, args.width)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    working_dir = osp.dirname(osp.abspath(__file__))\n    root = osp.join(working_dir, args.root)\n\n    # source dataset\n    source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower()))\n    sampler = RandomDomainMultiInstanceSampler(source_dataset.train, batch_size=args.batch_size, n_domains_per_batch=2,\n                                               num_instances=args.num_instances)\n    train_loader = DataLoader(\n        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),\n        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),\n                                   root=source_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # target dataset\n    target_dataset = 
datasets.__dict__[args.target](root=osp.join(root, args.target.lower()))\n    test_loader = DataLoader(\n        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),\n                                   root=target_dataset.images_dir,\n                                   transform=val_transform),\n        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)\n\n    # create model\n    num_classes = source_dataset.num_train_pids\n    backbone = models.__dict__[args.arch](mix_layers=args.mix_layers, mix_p=args.mix_p, mix_alpha=args.mix_alpha,\n                                          resnet_class=ReidResNet, pretrained=True)\n    model = ReIdentifier(backbone, num_classes, finetune=args.finetune).to(device)\n    model = DataParallel(model)\n\n    # define optimizer and learning rate scheduler\n    optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,\n                     weight_decay=args.weight_decay)\n    lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1,\n                                     warmup_steps=args.warmup_steps)\n\n    # resume from the best checkpoint\n    if args.phase != 'train':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        model.load_state_dict(checkpoint)\n\n    # analyze the model\n    if args.phase == 'analysis':\n        # plot t-SNE\n        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,\n                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)\n        # visualize ranked results\n        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,\n                                 rerank=args.rerank)\n        return\n\n    if args.phase == 'test':\n        print(\"Test on source domain:\")\n        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        print(\"Test on target domain:\")\n        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,\n                 rerank=args.rerank)\n        return\n\n    # define loss function\n    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)\n    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)\n\n    # start training\n    best_val_mAP = 0.\n    best_test_mAP = 0.\n    for epoch in range(args.epochs):\n        # print learning rate\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(train_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args)\n\n        # update learning rate\n        lr_scheduler.step()\n\n        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):\n\n            # evaluate on validation set\n            print(\"Validation on source domain...\")\n            _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device,\n                                  cmc_flag=True)\n\n            # remember best mAP and save checkpoint\n            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))\n            if val_mAP > best_val_mAP:\n                
shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n            best_val_mAP = max(val_mAP, best_val_mAP)\n\n            # evaluate on test set\n            print(\"Test on target domain...\")\n            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                                   cmc_flag=True, rerank=args.rerank)\n            best_test_mAP = max(test_mAP, best_test_mAP)\n\n    # evaluate on test set\n    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))\n    print(\"Test on target domain:\")\n    _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,\n                           cmc_flag=True, rerank=args.rerank)\n    print(\"test mAP on target = {}\".format(test_mAP))\n    print(\"oracle mAP on target = {}\".format(best_test_mAP))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model, criterion_ce: CrossEntropyLossWithLabelSmooth,\n          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses_ce = AverageMeter('CeLoss', ':3.2f')\n    losses_triplet = AverageMeter('TripletLoss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n\n    for i in range(args.iters_per_epoch):\n        x, _, labels, _ = next(train_iter)\n        x = x.to(device)\n        labels = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y, f = model(x)\n\n        # cross entropy loss\n        loss_ce = criterion_ce(y, labels)\n        # triplet loss\n        loss_triplet = criterion_triplet(f, f, labels)\n        loss = loss_ce + loss_triplet * args.trade_off\n\n        cls_acc = accuracy(y, labels)[0]\n        losses_ce.update(loss_ce.item(), x.size(0))\n        losses_triplet.update(loss_triplet.item(), x.size(0))\n        losses.update(loss.item(), x.size(0))\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do an optimizer step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    architecture_names = sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    )\n    dataset_names = sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n    parser = argparse.ArgumentParser(description=\"MixStyle for Domain Generalizable ReID\")\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-s', '--source', type=str, help='source domain')\n    parser.add_argument('-t', '--target', type=str, help='target domain')\n    parser.add_argument('--train-resizing', type=str, 
default='default')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=architecture_names,\n                        help='backbone architecture: ' +\n                             ' | '.join(architecture_names) +\n                             ' (default: resnet50)')\n    parser.add_argument('--finetune', action='store_true', help='whether to use 10x smaller lr for backbone')\n    parser.add_argument('--rate', type=float, default=0.2)\n    parser.add_argument('--mix-layers', nargs='+', help='layers to apply MixStyle')\n    parser.add_argument('--mix-p', default=0.5, type=float, help='probability to apply MixStyle')\n    parser.add_argument('--mix-alpha', default=0.1, type=float, help='parameter alpha for beta distribution')\n    # training parameters\n    parser.add_argument('--trade-off', type=float, default=1,\n                        help='trade-off hyper-parameter between cross entropy loss and triplet loss')\n    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')\n    parser.add_argument('-j', '--workers', type=int, default=4)\n    parser.add_argument('-b', '--batch-size', type=int, default=16)\n    parser.add_argument('--height', type=int, default=256, help=\"input height\")\n    parser.add_argument('--width', type=int, default=128, help=\"input width\")\n    parser.add_argument('--num-instances', type=int, default=4,\n                        help=\"each minibatch consists of \"\n                             \"(batch_size // num_instances) identities, and \"\n                             \"each identity has num_instances instances, \"\n                             \"default: 4\")\n    parser.add_argument('--lr', type=float, default=0.00035,\n                        help=\"initial learning rate\")\n    parser.add_argument('--weight-decay', type=float, default=5e-4)\n    parser.add_argument('--epochs', type=int, default=80)\n    parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps')\n    parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70],\n                        help='milestones for the learning rate decay')\n    parser.add_argument('--eval-step', type=int, default=40)\n    parser.add_argument('--iters-per-epoch', type=int, default=400)\n    parser.add_argument('--print-freq', type=int, default=40)\n    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')\n    parser.add_argument('--rerank', action='store_true', help=\"whether to use re-ranking in evaluation\")\n    parser.add_argument(\"--log\", type=str, default='mixstyle',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test', 'analysis'],\n                        help=\"When phase is 'test', only test the model. \"\n                             \"When phase is 'analysis', only analyze the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
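  {
    "path": "examples/domain_generalization/re_identification/mixstyle_sketch.py",
    "content": "\"\"\"Illustrative sketch (standalone, not imported by mixstyle.py): the MixStyle\noperation that --mix-layers enables inside the backbone mixes the channel-wise\nmean and standard deviation of each instance with those of a randomly chosen\ninstance, with mixing weights drawn from Beta(alpha, alpha). This is a minimal\nre-implementation for reference, not tllib's version.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\n\nclass MixStyleSketch(nn.Module):\n    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):\n        super().__init__()\n        self.p = p  # probability of applying MixStyle, cf. --mix-p\n        self.beta = torch.distributions.Beta(alpha, alpha)  # cf. --mix-alpha\n        self.eps = eps\n\n    def forward(self, x):  # x: (B, C, H, W)\n        if not self.training or torch.rand(1).item() > self.p:\n            return x\n        b = x.size(0)\n        mu = x.mean(dim=[2, 3], keepdim=True).detach()\n        sig = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt().detach()\n        x_norm = (x - mu) / sig  # strip instance-specific style\n        lam = self.beta.sample((b, 1, 1, 1)).to(x.device)\n        perm = torch.randperm(b, device=x.device)\n        mu_mix = lam * mu + (1 - lam) * mu[perm]\n        sig_mix = lam * sig + (1 - lam) * sig[perm]\n        return x_norm * sig_mix + mu_mix  # re-style with mixed statistics\n\n\nif __name__ == '__main__':\n    layer = MixStyleSketch().train()\n    print(layer(torch.randn(8, 64, 32, 16)).shape)\n"
  },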
  {
    "path": "examples/domain_generalization/re_identification/mixstyle.sh",
    "content": "#!/usr/bin/env bash\n# Market1501 -> Duke\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t DukeMTMC -a resnet50 \\\n--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2Duke\n\n# Duke -> Market1501\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t Market1501 -a resnet50 \\\n--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2Market\n\n# Market1501 -> MSMT\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t MSMT17 -a resnet50 \\\n--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2MSMT\n\n# MSMT -> Market1501\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t Market1501 -a resnet50 \\\n--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Market\n\n# Duke -> MSMT\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t MSMT17 -a resnet50 \\\n--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2MSMT\n\n# MSMT -> Duke\nCUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t DukeMTMC -a resnet50 \\\n--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Duke\n"
  },
  {
    "path": "examples/domain_generalization/re_identification/requirements.txt",
    "content": "timm\nopencv-python"
  },
  {
    "path": "examples/domain_generalization/re_identification/utils.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport sys\nimport timm\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms as T\n\nsys.path.append('../../..')\nfrom tllib.utils.metric.reid import extract_reid_feature\nfrom tllib.utils.analysis import tsne\nimport tllib.vision.models.reid as models\nimport tllib.normalization.ibn as ibn_models\n\n\ndef get_model_names():\n    return sorted(name for name in models.__dict__ if\n                  name.islower() and not name.startswith(\"__\") and callable(models.__dict__[name])) + \\\n           sorted(name for name in ibn_models.__dict__ if\n                  name.islower() and not name.startswith(\"__\") and callable(ibn_models.__dict__[name])) + \\\n           timm.list_models()\n\n\ndef get_model(model_name):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=True)\n    elif model_name in ibn_models.__dict__:\n        # load models (with ibn) from tllib.normalization.ibn\n        backbone = ibn_models.__dict__[model_name](pretrained=True)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=True)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n    return backbone\n\n\ndef get_train_transform(height, width, resizing='default', random_horizontal_flip=True, random_color_jitter=False,\n                        random_gray_scale=False):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to (height, width), zero-pad it by 10 on each size, the take a random crop of\n            (height, width)\n        - res: resize the image to(height, width)\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            T.Resize((height, width), interpolation=3),\n            T.Pad(10),\n            T.RandomCrop((height, width))\n        ])\n    elif resizing == 'res':\n        transform = T.Resize((height, width), interpolation=3)\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        transforms.append(T.RandomHorizontalFlip())\n    if random_color_jitter:\n        transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))\n    if random_gray_scale:\n        transforms.append(T.RandomGrayscale())\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n    return T.Compose(transforms)\n\n\ndef get_val_transform(height, width):\n    return T.Compose([\n        T.Resize((height, width), interpolation=3),\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n\n\ndef visualize_tsne(source_loader, target_loader, model, filename, device, n_data_points_per_domain=3000):\n    \"\"\"Visualize features from different domains using t-SNE. 
Since there can be a very large number of samples in each\n    domain, only `n_data_points_per_domain` samples are randomly selected from each domain.\n    \"\"\"\n    source_feature_dict = extract_reid_feature(source_loader, model, device, normalize=True)\n    source_feature = torch.stack(list(source_feature_dict.values())).cpu()\n    source_feature = source_feature[torch.randperm(len(source_feature))]\n    source_feature = source_feature[:n_data_points_per_domain]\n\n    target_feature_dict = extract_reid_feature(target_loader, model, device, normalize=True)\n    target_feature = torch.stack(list(target_feature_dict.values())).cpu()\n    target_feature = target_feature[torch.randperm(len(target_feature))]\n    target_feature = target_feature[:n_data_points_per_domain]\n\n    tsne.visualize(source_feature, target_feature, filename, source_color='cornflowerblue', target_color='darkorange')\n    print('t-SNE finished; figure saved to {}'.format(filename))\n"
  },
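  {
    "path": "examples/domain_generalization/re_identification/timm_backbone_sketch.py",
    "content": "\"\"\"Illustrative sketch (standalone): how get_model in utils.py adapts a\npytorch-image-models backbone into a headless feature extractor, i.e. the\nclassifier is removed and the feature dimension is recorded so that\nReIdentifier can attach its own head. Only documented timm APIs are used;\npretrained=False merely avoids a download in this illustration.\n\"\"\"\nimport timm\nimport torch\n\n# create the backbone without a classification head\nbackbone = timm.create_model('resnet50', pretrained=False, num_classes=0)\nbackbone.out_features = backbone.num_features  # feature dimension, e.g. 2048\n\nx = torch.randn(2, 3, 256, 128)  # ReID-style input: height 256, width 128\nfeatures = backbone(x)  # globally pooled features of shape (2, 2048)\nprint(features.shape, backbone.out_features)\n"
  },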
  {
    "path": "examples/model_selection/README.md",
    "content": "# Model Selection\n\n## Installation\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).\nYou need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\n- [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)\n- [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/)\n- [CIFAR10](http://www.cs.utoronto.ca/~kriz/cifar.html)\n- [CIFAR100](http://www.cs.utoronto.ca/~kriz/cifar.html)\n- [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html)\n- [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/)\n- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)\n- [SUN397](https://vision.princeton.edu/projects/2010/SUN/)\n\n## Supported Methods\n\nSupported methods include:\n\n- [An Information-theoretic Approach to Transferability in Task Transfer Learning (H-Score, ICIP 2019)](http://yangli-feasibility.com/home/media/icip-19.pdf)\n\n- [LEEP: A New Measure to Evaluate Transferability of Learned Representations (LEEP, ICML 2020)](http://proceedings.mlr.press/v119/nguyen20b/nguyen20b.pdf)\n\n- [Log Maximum Evidence in `LogME: Practical Assessment of Pre-trained Models for Transfer Learning (LogME, ICML 2021)](https://arxiv.org/pdf/2102.11005.pdf)\n\n- [Negative Conditional Entropy in `Transferability and Hardness of Supervised Classification Tasks (NCE, ICCV 2019)](https://arxiv.org/pdf/1908.08142v1.pdf)\n    \n## Experiment and Results\n\n### Model Ranking on image classification tasks\n\nThe shell files give the scripts to ranking pre-trained models on a given dataset. For example, if you want to use LogME to calculate the transfer performance of ResNet50(ImageNet pre-trained) on Aircraft, use the following script\n\n```shell script\n# Using LogME to ranking pre-trained ResNet50 on Aircraft\n# Assume you have put the datasets under the path `data/cub200`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features\n```\n\nWe use LEEP, NCE HScore and LogME to compute scores by applying 10 pre-trained models to different datasets. 
The correlation ([Weighted Kendall Tau](https://vigna.di.unimi.it/ftp/papers/WeightedTau.pdf)/Pearson correlation) between scores and fine-tuned accuracies\nis presented.\n\n#### Model Ranking Benchmark on Aircraft\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE    |\n|--------------|---------------|--------|--------|-------|--------|\n| GoogleNet    |          82.7 | 28.37  | -4.310 | 0.934 | -4.248 |\n| Inception V3 |          88.8 | 43.89  | -4.202 | 0.953 | -4.170 |\n| ResNet50     |          86.6 | 46.23  | -4.215 | 0.946 | -4.201 |\n| ResNet101    |          85.6 | 46.13  | -4.230 | 0.948 | -4.222 |\n| ResNet152    |          85.3 | 46.25  | -4.230 | 0.950 | -4.229 |\n| DenseNet121  |          85.4 | 31.53  | -4.228 | 0.938 | -4.215 |\n| DenseNet169  |          84.5 | 41.81  | -4.245 | 0.943 | -4.270 |\n| Densenet201  |          84.6 | 46.01  | -4.206 | 0.942 | -4.189 |\n| MobileNet V2 |          82.8 | 34.43  | -4.198 | 0.941 | -4.208 |\n| MNasNet      |          72.8 | 35.28  | -4.192 | 0.948 | -4.195 |\n| Pearson Corr |             - |  0.688 |  0.127 | 0.582 | 0.173 |\n| Weighted Tau |             - |  0.664 | -0.264 | 0.595 |  0.002 |\n\n#### Model Ranking Benchmark on Caltech101\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE    |\n|--------------|---------------|--------|--------|-------|--------|\n| GoogleNet    |          91.7 | 75.88  | -1.462 | 1.228 | -0.665 |\n| Inception V3 |          94.3 | 93.73  | -1.119 | 1.387 | -0.560 |\n| ResNet50     |          91.8 | 91.65  | -1.020 | 1.262 | -0.616 |\n| ResNet101    |          93.1 | 92.54  | -0.899 | 1.305 | -0.603 |\n| ResNet152    |          93.2 | 92.91  | -0.875 | 1.324 | -0.605 |\n| DenseNet121  |          91.9 | 75.02  | -0.979 | 1.172 | -0.609 |\n| DenseNet169  |          92.5 | 86.37  | -0.864 | 1.212 | -0.580 |\n| Densenet201  |          93.4 | 89.90  | -0.914 | 1.228 | -0.590 |\n| MobileNet V2 |          89.1 | 75.82  | -1.115 | 1.150 | -0.693 |\n| MNasNet      |          91.5 | 77.00  | -1.043 | 1.178 | -0.690 |\n| Pearson Corr |             - |  0.748 |  0.324 | 0.794 |  0.843 |\n| Weighted Tau |             - |  0.721 |  0.127 | 0.697 |  0.810 |\n\n#### Model Ranking Benchmark on CIFAR10\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE    |\n|--------------|---------------|--------|--------|-------|--------|\n| GoogleNet    |         96.2  | 5.911  | -1.385 | 0.293 | -1.139 |\n| Inception V3 |         97.5  | 6.363  | -1.259 | 0.349 | -1.060 |\n| ResNet50     |         96.8  | 6.567  | -1.010 | 0.388 | -1.007 |\n| ResNet101    |         97.7  | 6.901  | -0.829 | 0.463 | -0.838 |\n| ResNet152    |         97.9  | 6.945  | -0.838 | 0.469 | -0.851 |\n| DenseNet121  |         97.2  | 6.210  | -1.035 | 0.302 | -1.006 |\n| DenseNet169  |         97.4  | 6.547  | -0.934 | 0.343 | -0.946 |\n| Densenet201  |         97.4  | 6.706  | -0.888 | 0.369 | -0.866 |\n| MobileNet V2 |         95.7  | 5.928  | -1.100 | 0.291 | -1.089 |\n| MNasNet      |         96.8  | 6.018  | -1.066 | 0.304 | -1.086 |\n| Pearson Corr |             - |  0.839 |  0.604 | 0.733 |  0.786 |\n| Weighted Tau |             - |  0.800 |  0.638 | 0.785 |  0.714 |\n\n#### Model Ranking Benchmark on CIFAR100\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE    |\n|--------------|---------------|--------|--------|-------|--------|\n| GoogleNet    |         83.2  | 29.33  | -3.234 | 1.037 | -2.751 |\n| Inception V3 |         86.6  | 36.47  | -2.995 | 1.070 | -2.615 |\n| ResNet50     |         
84.5  | 40.20  | -2.612 | 1.099 | -2.516 |\n| ResNet101    |         87.0  | 43.80  | -2.365 | 1.130 | -2.285 |\n| ResNet152    |         87.6  | 44.19  | -2.410 | 1.133 | -2.369 |\n| DenseNet121  |         84.8  | 32.13  | -2.665 | 1.029 | -2.504 |\n| DenseNet169  |         85.0  | 37.51  | -2.494 | 1.051 | -2.418 |\n| Densenet201  |         86.0  | 39.75  | -2.470 | 1.061 | -2.305 |\n| MobileNet V2 |         80.8  | 30.36  | -2.800 | 1.039 | -2.653 |\n| MNasNet      |         83.9  | 32.05  | -2.732 | 1.051 | -2.643 |\n| Pearson Corr | -             | 0.815  | 0.513  | 0.698 | 0.705  |\n| Weighted Tau | -             | 0.775  | 0.659  | 0.790 | 0.654  |\n\n#### Model Ranking Benchmark on DTD\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE   |\n|--------------|---------------|--------|--------|-------|-------|\n| GoogleNet    |          73.6 | 34.61  | -2.333 | 0.682 | 0.682 |\n| Inception V3 |          77.2 | 57.17  | -2.135 | 0.691 | 0.691 |\n| ResNet50     |          75.2 | 78.26  | -1.985 | 0.695 | 0.695 |\n| ResNet101    |          76.2 | 117.23 | -1.974 | 0.689 | 0.689 |\n| ResNet152    |          75.4 | 32.30  | -1.924 | 0.698 | 0.698 |\n| DenseNet121  |          74.9 | 35.23  | -2.001 | 0.670 | 0.670 |\n| DenseNet169  |          74.8 | 43.36  | -1.817 | 0.686 | 0.686 |\n| Densenet201  |          74.5 | 45.96  | -1.926 | 0.689 | 0.689 |\n| MobileNet V2 |          72.9 | 37.99  | -2.098 | 0.664 | 0.664 |\n| MNasNet      |          72.8 | 38.03  | -2.033 | 0.679 | 0.679 |\n| Pearson Corr | -             | 0.532  | 0.217  | 0.617 | 0.471 |\n| Weighted Tau | -             | 0.416  | -0.004 | 0.550 | 0.083 |\n\n#### Model Ranking Benchmark on OxfordIIITPets\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE    |\n|--------------|---------------|--------|--------|-------|--------|\n| GoogleNet    |         91.9  | 28.02  | -1.064 | 0.854 | -0.815 |\n| Inception V3 |         93.5  | 33.29  | -0.888 | 1.119 | -0.711 |\n| ResNet50     |         92.5  | 32.55  | -0.805 | 0.952 | -0.721 |\n| ResNet101    |         94.0  | 32.76  | -0.769 | 0.985 | -0.717 |\n| ResNet152    |         94.5  | 32.86  | -0.732 | 1.009 | -0.679 |\n| DenseNet121  |         92.9  | 27.09  | -0.837 | 0.797 | -0.753 |\n| DenseNet169  |         93.1  | 30.09  | -0.779 | 0.829 | -0.699 |\n| Densenet201  |         92.8  | 31.25  | -0.810 | 0.860 | -0.716 |\n| MobileNet V2 |         90.5  | 27.83  | -0.902 | 0.765 | -0.822 |\n| MNasNet      |         89.4  | 27.95  | -0.854 | 0.785 | -0.812 |\n| Pearson Corr | -             | 0.427  | -0.127 | 0.589 | 0.501  |\n| Weighted Tau | -             | 0.425  | -0.143 | 0.502 | 0.119  |\n\n#### Model Ranking Benchmark on StanfordCars\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE    |\n|--------------|---------------|--------|--------|-------|--------|\n| GoogleNet    |         91.0  | 41.47  | -4.612 | 1.246 | -4.312 |\n| Inception V3 |         92.3  | 73.68  | -4.268 | 1.259 | -4.110 |\n| ResNet50     |         91.7  | 72.94  | -4.366 | 1.253 | -4.221 |\n| ResNet101    |         91.7  | 73.98  | -4.281 | 1.255 | -4.218 |\n| ResNet152    |         92.0  | 76.17  | -4.215 | 1.260 | -4.142 |\n| DenseNet121  |         91.5  | 45.82  | -4.437 | 1.249 | -4.271 |\n| DenseNet169  |         91.5  | 63.40  | -4.286 | 1.252 | -4.175 |\n| Densenet201  |         91.0  | 70.50  | -4.319 | 1.251 | -4.151 |\n| MobileNet V2 |         91.0  | 51.12  | -4.463 | 1.250 | -4.306 |\n| MNasNet      |         88.5  | 51.91  | -4.423 | 1.254 | -4.338 
|\n| Pearson Corr | -             | 0.503  | 0.433  | 0.274 | 0.695  |\n| Weighted Tau | -             | 0.638  | 0.703  | 0.654 | 0.750  |\n\n#### Model Ranking Benchmark on SUN397\n\n| Model        | Finetuned Acc | HScore | LEEP   | LogME | NCE    |\n|--------------|---------------|--------|--------|-------|--------|\n| GoogleNet    |         62.0  | 71.35  | -3.744 | 1.621 | -3.055 |\n| Inception V3 |         65.7  | 114.21 | -3.372 | 1.648 | -2.844 |\n| ResNet50     |         64.7  | 110.39 | -3.198 | 1.638 | -2.894 |\n| ResNet101    |         64.8  | 113.63 | -3.103 | 1.642 | -2.837 |\n| ResNet152    |         66.0  | 116.51 | -3.056 | 1.646 | -2.822 |\n| DenseNet121  |         62.3  | 72.16  | -3.311 | 1.614 | -2.945 |\n| DenseNet169  |         63.0  | 95.80  | -3.165 | 1.623 | -2.903 |\n| Densenet201  |         64.7  | 103.09 | -3.205 | 1.624 | -2.896 |\n| MobileNet V2 |         60.5  | 75.90  | -3.338 | 1.617 | -2.968 |\n| MNasNet      |         60.7  | 80.91  | -3.234 | 1.625 | -2.933 |\n| Pearson Corr | -             | 0.913  | 0.428  | 0.824 | 0.782  |\n| Weighted Tau | -             | 0.918  | 0.581  | 0.748 | 0.873  |\n\n## Citation\nIf you use these methods in your research, please consider citing.\n\n```\n@inproceedings{bao_information-theoretic_2019,\n\ttitle = {An Information-Theoretic Approach to Transferability in Task Transfer Learning},\n\tbooktitle = {ICIP},\n\tauthor = {Bao, Yajie and Li, Yang and Huang, Shao-Lun and Zhang, Lin and Zheng, Lizhong and Zamir, Amir and Guibas, Leonidas},\n\tyear = {2019}\n}\n\n@inproceedings{nguyen_leep:_2020,\n\ttitle = {LEEP: A New Measure to Evaluate Transferability of Learned Representations},\n\tbooktitle = {ICML},\n\tauthor = {Nguyen, Cuong and Hassner, Tal and Seeger, Matthias and Archambeau, Cedric},\n\tyear = {2020}\n}\n\n@inproceedings{you_logme:_2021,\n\ttitle = {LogME: Practical Assessment of Pre-trained Models for Transfer Learning},\n\tbooktitle = {ICML},\n\tauthor = {You, Kaichao and Liu, Yong and Wang, Jianmin and Long, Mingsheng},\n\tyear = {2021}\n}\n\n@inproceedings{tran_transferability_2019,\n\ttitle = {Transferability and hardness of supervised classification tasks},\n\tbooktitle = {ICCV},\n\tauthor = {Tran, Anh T. and Nguyen, Cuong V. and Hassner, Tal},\n\tyear = {2019}\n}\n\n```"
  },
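  {
    "path": "examples/model_selection/correlation_sketch.py",
    "content": "\"\"\"Illustrative sketch (standalone): how the Pearson correlation and weighted\nKendall tau rows in the tables of README.md can be computed from transferability\nscores and fine-tuned accuracies. The data below is the Aircraft HScore column;\nthe exact rank weighting behind the reported numbers is an assumption, so the\noutput may differ slightly from the table.\n\"\"\"\nfrom scipy.stats import pearsonr, weightedtau\n\nfinetuned_acc = [82.7, 88.8, 86.6, 85.6, 85.3, 85.4, 84.5, 84.6, 82.8, 72.8]\nhscore = [28.37, 43.89, 46.23, 46.13, 46.25, 31.53, 41.81, 46.01, 34.43, 35.28]\n\npearson, _ = pearsonr(hscore, finetuned_acc)\nwtau, _ = weightedtau(hscore, finetuned_acc)  # hyperbolic rank weighting by default\nprint(f'Pearson {pearson:.3f}, weighted tau {wtau:.3f}')\n"
  },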
  {
    "path": "examples/model_selection/hscore.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\n\nimport os\nimport sys\nimport argparse\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nsys.path.append('../..')\nfrom tllib.ranking import h_score\n\nsys.path.append('.')\nimport utils\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = utils.Logger(args.data, args.arch, 'results_hscore')\n    print(args)\n    print(f'Calc Transferabilities of {args.arch} on {args.data}')\n\n    try:\n        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))\n        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))\n        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))\n        print('Loaded extracted features')\n    except:\n        print('Conducting feature extraction')\n        data_transform = utils.get_transform(resizing=args.resizing)\n        print(\"data_transform: \", data_transform)\n        model = utils.get_model(args.arch, args.pretrained).to(device)\n        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,\n                                                       args.num_samples_per_classes)\n        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,\n                                  pin_memory=True)\n        print(f'Using {len(score_dataset)} samples for ranking')\n        features, predictions, targets = utils.forwarding_dataset(score_loader, model,\n                                                                  layer=eval(f'model.{args.layer}'), device=device)\n        if args.save_features:\n            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)\n            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)\n            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)\n\n    print('Conducting transferability calculation')\n    result = h_score(features, targets)\n\n    logger.write(\n        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\\n')\n    print(f'Results saved in {logger.get_result_dir()}')\n    logger.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Ranking pre-trained models with HScore')\n\n    # dataset\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per classes.')\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N', help='mini-batch size (default: 48)')\n    parser.add_argument('--resizing', default='res.', type=str)\n\n    # model\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        
help='model to be ranked: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('-l', '--layer', default='fc',\n                        help='the layer before which features are extracted')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument(\"--save_features\", action='store_true',\n                        help=\"whether to save extracted features\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
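  {
    "path": "examples/model_selection/hscore_sketch.py",
    "content": "\"\"\"Illustrative sketch (standalone): the quantity hscore.py computes,\nH(f) = tr(cov(f)^{-1} cov(E[f | y])) from Bao et al. (ICIP 2019), estimated\nhere with a pseudo-inverse for numerical stability. tllib's h_score may\ndiffer in details such as covariance regularization.\n\"\"\"\nimport numpy as np\n\n\ndef h_score_sketch(features: np.ndarray, targets: np.ndarray) -> float:\n    # features: (N, D) extracted features, targets: (N,) integer labels\n    cov_f = np.cov(features, rowvar=False)\n    # replace each feature with its class-conditional mean E[f | y]\n    g = np.zeros_like(features, dtype=np.float64)\n    for y in np.unique(targets):\n        g[targets == y] = features[targets == y].mean(axis=0)\n    cov_g = np.cov(g, rowvar=False)\n    return float(np.trace(np.linalg.pinv(cov_f) @ cov_g))\n\n\nif __name__ == '__main__':\n    rng = np.random.default_rng(0)\n    feats = rng.normal(size=(200, 16))\n    labels = rng.integers(0, 5, size=200)\n    print(h_score_sketch(feats, labels))\n"
  },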
  {
    "path": "examples/model_selection/hscore.sh",
    "content": "#!/usr/bin/env bash\n\n# Ranking Pre-trained Model\n# ======================================================================================================================\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# FGVCAircraft\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d 
Aircraft -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Caltech101\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Oxford-IIIT\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a 
densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a mnasnet1_0 -l classifier[-1] --save_features\n"
  },
  {
    "path": "examples/model_selection/leep.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\n\nimport os\nimport sys\nimport argparse\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nsys.path.append('../..')\nfrom tllib.ranking import log_expected_empirical_prediction as leep\n\nsys.path.append('.')\nimport utils\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = utils.Logger(args.data, args.arch, 'results_leep')\n    print(args)\n    print(f'Calc Transferabilities of {args.arch} on {args.data}')\n\n    try:\n        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))\n        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))\n        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))\n        print('Loaded extracted features')\n    except:\n        print('Conducting feature extraction')\n        data_transform = utils.get_transform(resizing=args.resizing)\n        print(\"data_transform: \", data_transform)\n        model = utils.get_model(args.arch, args.pretrained).to(device)\n        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,\n                                                       args.num_samples_per_classes)\n        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,\n                                  pin_memory=True)\n        print(f'Using {len(score_dataset)} samples for ranking')\n        features, predictions, targets = utils.forwarding_dataset(score_loader, model,\n                                                                  layer=eval(f'model.{args.layer}'), device=device)\n        if args.save_features:\n            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)\n            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)\n            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)\n\n    print('Conducting transferability calculation')\n    result = leep(predictions, targets)\n\n    logger.write(\n        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\\n')\n    logger.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(\n        description='Ranking pre-trained models with LEEP (Log Expected Empirical Prediction)')\n\n    # dataset\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per classes.')\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N', help='mini-batch size (default: 48)')\n    parser.add_argument('--resizing', default='res.', type=str)\n\n    # model\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n              
          help='model to be ranked: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('-l', '--layer', default='fc',\n                        help='before which layer features are extracted')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument(\"--save_features\", action='store_true',\n                        help=\"whether to save extracted features\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
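For readers who want the score itself without the `tllib` dependency, here is a minimal numpy sketch of LEEP (Nguyen et al., ICML 2020). It takes the same inputs that `leep.py` passes to `log_expected_empirical_prediction` (the source model's softmax outputs and the integer target labels); this is an illustrative re-derivation, not the `tllib` implementation:

```python
import numpy as np

def leep_score(predictions: np.ndarray, targets: np.ndarray) -> float:
    """LEEP = average log expected empirical prediction (sketch).

    predictions: (N, C_s) softmax outputs of the source model.
    targets:     (N,) integer target labels in [0, C_t).
    """
    n = predictions.shape[0]
    num_target_classes = int(targets.max()) + 1
    # empirical joint distribution P(y, z) of target label y and source class z
    joint = np.zeros((num_target_classes, predictions.shape[1]))
    for y in range(num_target_classes):
        joint[y] = predictions[targets == y].sum(axis=0) / n
    # empirical conditional P(y | z); the marginal P(z) is positive for softmax outputs
    p_y_given_z = joint / joint.sum(axis=0, keepdims=True)
    # expected empirical prediction of each sample's own label, then average the logs
    eep = (predictions @ p_y_given_z.T)[np.arange(n), targets]
    return float(np.log(eep).mean())
```

Since `eep` is a convex combination of probabilities, the score is always negative; models whose score is closer to zero are ranked as more transferable.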
  {
    "path": "examples/model_selection/leep.sh",
    "content": "#!/usr/bin/env bash\n\n# Ranking Pre-trained Model\n# ======================================================================================================================\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# FGVCAircraft\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a densenet201 -l classifier 
--save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Caltech101\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Oxford-IIIT\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py 
./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a mnasnet1_0 -l classifier[-1] --save_features\n"
  },
  {
    "path": "examples/model_selection/logme.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\n\nimport os\nimport sys\nimport argparse\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nsys.path.append('../..')\nfrom tllib.ranking import log_maximum_evidence as logme\n\nsys.path.append('.')\nimport utils\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = utils.Logger(args.data, args.arch, 'results_logme')\n    print(args)\n    print(f'Calc Transferabilities of {args.arch} on {args.data}')\n\n    try:\n        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))\n        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))\n        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))\n        print('Loaded extracted features')\n    except:\n        print('Conducting feature extraction')\n        data_transform = utils.get_transform(resizing=args.resizing)\n        print(\"data_transform: \", data_transform)\n        model = utils.get_model(args.arch, args.pretrained).to(device)\n        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,\n                                                       args.num_samples_per_classes)\n        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,\n                                  pin_memory=True)\n        print(f'Using {len(score_dataset)} samples for ranking')\n        features, predictions, targets = utils.forwarding_dataset(score_loader, model,\n                                                                  layer=eval(f'model.{args.layer}'), device=device)\n        if args.save_features:\n            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)\n            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)\n            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)\n\n    print('Conducting transferability calculation')\n    result = logme(features, targets)\n\n    logger.write(\n        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\\n')\n    logger.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Ranking pre-trained models with LogME (Log Maximum Evidence)')\n\n    # dataset\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per classes.')\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N', help='mini-batch size (default: 48)')\n    parser.add_argument('--resizing', default='res.', type=str)\n\n    # model\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='model to be 
ranked: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('-l', '--layer', default='fc',\n                        help='before which layer features are extracted')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument(\"--save_features\", action='store_true',\n                        help=\"whether to save extracted features\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
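LogME is less immediate than LEEP, so a compact sketch of the evidence maximization may help. For each class, a Bayesian linear model is fitted to the one-hot target (prior precision `alpha`, noise precision `beta`); after an SVD of the feature matrix the evidence has a closed form, and `alpha` and `beta` are found by fixed-point iteration (You et al., ICML 2021). The following is a simplified illustration without the paper's numerical refinements, not the `tllib` implementation:

```python
import numpy as np

def logme_score(features: np.ndarray, targets: np.ndarray, max_iter: int = 100) -> float:
    """LogME: average per-class log maximum evidence (simplified sketch).

    features: (N, D) features extracted before the classification layer.
    targets:  (N,) integer target labels.
    """
    n, d = features.shape
    u, s, _ = np.linalg.svd(features.astype(np.float64), full_matrices=False)
    sigma = s ** 2                     # spectrum of F^T F restricted to span(F)
    r = len(sigma)
    eps = 1e-12
    evidences = []
    for k in range(int(targets.max()) + 1):
        y = (targets == k).astype(np.float64)
        if not y.any():                # skip classes absent from the scoring subset
            continue
        z2 = (u.T @ y) ** 2            # squared coordinates of y inside span(F)
        res_out = max(float(y @ y - z2.sum()), 0.0)  # residual outside span(F)
        alpha, beta = 1.0, 1.0
        for _ in range(max_iter):
            t = alpha / beta
            gamma = (sigma / (sigma + t)).sum()               # effective dimension
            m2 = (sigma * z2 / (sigma + t) ** 2).sum()        # ||posterior mean||^2
            res2 = (z2 * t ** 2 / (sigma + t) ** 2).sum() + res_out  # ||y - F m||^2
            alpha_new, beta_new = gamma / (m2 + eps), (n - gamma) / (res2 + eps)
            converged = abs(alpha_new - alpha) < 1e-3 * alpha and abs(beta_new - beta) < 1e-3 * beta
            alpha, beta = alpha_new, beta_new
            if converged:
                break
        t = alpha / beta
        m2 = (sigma * z2 / (sigma + t) ** 2).sum()
        res2 = (z2 * t ** 2 / (sigma + t) ** 2).sum() + res_out
        log_det = np.log(sigma + t).sum() + r * np.log(beta) + (d - r) * np.log(alpha)
        evidence = 0.5 * (d * np.log(alpha) + n * np.log(beta) - n * np.log(2 * np.pi)
                          - beta * res2 - alpha * m2 - log_det)
        evidences.append(evidence / n)
    return float(np.mean(evidences))
```

As in `logme.py`, the inputs are the extracted `features` and integer `targets`; the per-class log evidence is normalized by N and averaged over classes, and higher scores indicate better-suited pre-trained models.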
  {
    "path": "examples/model_selection/logme.sh",
    "content": "#!/usr/bin/env bash\n\n# Ranking Pre-trained Model\n# ======================================================================================================================\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# FGVCAircraft\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a densenet201 -l 
classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Caltech101\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Oxford-IIIT\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet121 -l classifier 
--save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python logme.py ./data/SUN397 -d SUN397 -a mnasnet1_0 -l classifier[-1] --save_features\n"
  },
  {
    "path": "examples/model_selection/nce.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\n\nimport os\nimport sys\nimport argparse\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nsys.path.append('../..')\nfrom tllib.ranking import negative_conditional_entropy as nce\n\nsys.path.append('.')\nimport utils\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args):\n    logger = utils.Logger(args.data, args.arch, 'results_nce')\n    print(args)\n    print(f'Calc Transferabilities of {args.arch} on {args.data}')\n\n    try:\n        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))\n        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))\n        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))\n        print('Loaded extracted features')\n    except:\n        print('Conducting feature extraction')\n        data_transform = utils.get_transform(resizing=args.resizing)\n        print(\"data_transform: \", data_transform)\n        model = utils.get_model(args.arch, args.pretrained).to(device)\n        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,\n                                                       args.num_samples_per_classes)\n        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,\n                                  pin_memory=True)\n        print(f'Using {len(score_dataset)} samples for ranking')\n        features, predictions, targets = utils.forwarding_dataset(score_loader, model,\n                                                                  layer=eval(f'model.{args.layer}'), device=device)\n        if args.save_features:\n            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)\n            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)\n            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)\n\n    print('Conducting transferability calculation')\n    result = nce(np.argmax(predictions, axis=1), targets)\n\n    logger.write(\n        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\\n')\n    logger.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Ranking pre-trained models with NCE (Negative Conditional Entropy)')\n\n    # dataset\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per classes.')\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N', help='mini-batch size (default: 48)')\n    parser.add_argument('--resizing', default='res.', type=str)\n\n    # model\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                   
     help='model to be ranked: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('-l', '--layer', default='fc',\n                        help='before which layer features are extracted')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument(\"--save_features\", action='store_true',\n                        help=\"whether to save extracted features\")\n\n    args = parser.parse_args()\n    main(args)\n"
  },
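NCE is the simplest of the three scores: with hard source pseudo-labels z = argmax(predictions) and target labels y, it is the negative conditional entropy -H(Y|Z) of the empirical joint distribution (Tran et al., ICCV 2019). A minimal numpy sketch, again illustrative rather than the `tllib` implementation:

```python
import numpy as np

def nce_score(source_labels: np.ndarray, targets: np.ndarray) -> float:
    """NCE: negative conditional entropy -H(Y | Z) (sketch).

    source_labels: (N,) hard pseudo-labels of the source model (argmax of predictions).
    targets:       (N,) integer target labels.
    """
    n = len(targets)
    c_s = int(source_labels.max()) + 1
    c_t = int(targets.max()) + 1
    # empirical joint distribution P(z, y)
    joint = np.zeros((c_s, c_t))
    np.add.at(joint, (source_labels, targets), 1.0 / n)
    marginal_z = np.broadcast_to(joint.sum(axis=1, keepdims=True), joint.shape)  # P(z)
    mask = joint > 0
    # -H(Y|Z) = sum over (z, y) of P(z, y) * log(P(z, y) / P(z))
    return float((joint[mask] * np.log(joint[mask] / marginal_z[mask])).sum())
```

The score is at most zero and approaches zero when the source pseudo-label nearly determines the target label, which is why higher NCE suggests an easier transfer.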
  {
    "path": "examples/model_selection/nce.sh",
    "content": "#!/usr/bin/env bash\n\n# Ranking Pre-trained Model\n# ======================================================================================================================\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# FGVCAircraft\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 
python nce.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Caltech101\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# Oxford-IIIT\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier 
--save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a densenet201 -l classifier --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a mobilenet_v2 -l classifier[-1] --save_features\nCUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a mnasnet1_0 -l classifier[-1] --save_features\n"
  },
  {
    "path": "examples/model_selection/requirements.txt",
    "content": "timm\nnumba\n"
  },
  {
    "path": "examples/model_selection/utils.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\nimport random\nimport sys, os\n\nimport torch\nimport timm\nfrom torch.utils.data import Subset\nimport torchvision.transforms as T\nimport torch.nn.functional as F\nimport torchvision.models as models\n\nsys.path.append('../../..')\nimport tllib.vision.datasets as datasets\n\n\nclass Logger(object):\n    \"\"\"Writes stream output to external text file.\n\n    Args:\n        filename (str): the file to write stream output\n        stream: the stream to read from. Default: sys.stdout\n    \"\"\"\n\n    def __init__(self, data_name, model_name, metric_name, stream=sys.stdout):\n        self.terminal = stream\n        self.save_dir = os.path.join(data_name, model_name)  # save intermediate features/outputs\n        self.result_dir = os.path.join(data_name, f'{metric_name}.txt')  # save ranking results\n        os.makedirs(self.save_dir, exist_ok=True)\n        self.log = open(self.result_dir, 'a')\n\n    def write(self, message):\n        self.terminal.write(message)\n        self.log.write(message)\n        self.flush()\n\n    def get_save_dir(self):\n        return self.save_dir\n\n    def get_result_dir(self):\n        return self.result_dir\n\n    def flush(self):\n        self.terminal.flush()\n        self.log.flush()\n\n    def close(self):\n        self.terminal.close()\n        self.log.close()\n\n\ndef get_model_names():\n    return sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    ) + timm.list_models()\n\n\ndef forwarding_dataset(score_loader, model, layer, device):\n    \"\"\"\n    A forward forcasting on full dataset\n\n    :params score_loader: the dataloader for scoring transferability\n    :params model: the model for scoring transferability\n    :params layer: before which layer features are extracted, for registering hooks\n    \n    returns\n        features: extracted features of model\n        prediction: probability outputs of model\n        targets: ground-truth labels of dataset\n    \"\"\"\n    features = []\n    outputs = []\n    targets = []\n\n    def hook_fn_forward(module, input, output):\n        features.append(input[0].detach().cpu())\n        outputs.append(output.detach().cpu())\n\n    forward_hook = layer.register_forward_hook(hook_fn_forward)\n\n    model.eval()\n    with torch.no_grad():\n        for _, (data, target) in enumerate(score_loader):\n            targets.append(target)\n            data = data.to(device)\n            _ = model(data)\n\n    forward_hook.remove()\n\n    features = torch.cat([x for x in features]).numpy()\n    outputs = torch.cat([x for x in outputs])\n    predictions = F.softmax(outputs, dim=-1).numpy()\n    targets = torch.cat([x for x in targets]).numpy()\n\n    return features, predictions, targets\n\n\ndef get_model(model_name, pretrained=True, pretrained_checkpoint=None):\n    if model_name in get_model_names():\n        # load models from common.vision.models\n        backbone = models.__dict__[model_name](pretrained=pretrained)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=pretrained)\n    if pretrained_checkpoint:\n        print(\"=> loading pre-trained model from '{}'\".format(pretrained_checkpoint))\n        pretrained_dict = torch.load(pretrained_checkpoint)\n        backbone.load_state_dict(pretrained_dict, strict=False)\n    return 
backbone\n\n\ndef get_dataset(dataset_name, root, transform, sample_rate=100, num_samples_per_classes=None, split='train'):\n    \"\"\"\n    When sample_rate < 100,  e.g. sample_rate = 50, use 50% data to train the model.\n    Otherwise,\n        if num_samples_per_classes is not None, e.g. 5, then sample 5 images for each class, and use them to train the model;\n        otherwise, keep all the data.\n    \"\"\"\n    dataset = datasets.__dict__[dataset_name]\n    if sample_rate < 100:\n        score_dataset = dataset(root=root, split=split, sample_rate=sample_rate, download=True, transform=transform)\n        num_classes = len(score_dataset.classes)\n    else:\n        score_dataset = dataset(root=root, split=split, download=True, transform=transform)\n        num_classes = len(score_dataset.classes)\n        if num_samples_per_classes is not None:\n            samples = list(range(len(score_dataset)))\n            random.shuffle(samples)\n            samples_len = min(num_samples_per_classes * num_classes, len(score_dataset))\n            print(\"Origin dataset:\", len(score_dataset), \"Sampled dataset:\", samples_len, \"Ratio:\",\n                  float(samples_len) / len(score_dataset))\n            dataset = Subset(score_dataset, samples[:samples_len])\n    return score_dataset, num_classes\n\n\ndef get_transform(resizing='res.'):\n    \"\"\"\n    resizing mode:\n        - default: resize the image to 256 and take the center crop of size 224;\n        – res.: resize the image to 224\n        – res.|crop: resize the image such that the smaller side is of size 256 and\n            then take a central crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            T.Resize(256),\n            T.CenterCrop(224),\n        ])\n    elif resizing == 'res.':\n        transform = T.Resize((224, 224))\n    elif resizing == 'res.299':\n        transform = T.Resize((299, 299))\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize((256, 256)),\n            T.CenterCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    return T.Compose([\n        transform,\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n"
  },
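The heart of `utils.forwarding_dataset` is the forward-hook pattern: hooking the scored layer makes its `input` the penultimate features and its `output` the logits, with no model surgery. A self-contained toy version of that pattern (random batch, torchvision ResNet-50; not part of the example scripts):

```python
import torch
import torch.nn.functional as F
import torchvision.models as models

model = models.resnet50(pretrained=True).eval()
features, outputs = [], []

def hook_fn(module, input, output):
    # `input` is a tuple of tensors fed to the layer; input[0] is the pooled feature
    features.append(input[0].detach().cpu())
    outputs.append(output.detach().cpu())

handle = model.fc.register_forward_hook(hook_fn)  # the layer named by `-l fc`
with torch.no_grad():
    model(torch.randn(4, 3, 224, 224))            # stand-in for a batch from score_loader
handle.remove()

feats = torch.cat(features)                       # (4, 2048) penultimate features
preds = F.softmax(torch.cat(outputs), dim=-1)     # (4, 1000) source-class probabilities
```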
  {
    "path": "examples/semi_supervised_learning/image_classification/README.md",
    "content": "# Semi-Supervised Learning for Image Classification\n\n## Installation\n\nIt’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You\nalso need to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nFollowing datasets can be downloaded automatically:\n\n- [FOOD-101](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/)\n- [CIFAR10](http://www.cs.utoronto.ca/~kriz/cifar.html)\n- [CIFAR100](http://www.cs.utoronto.ca/~kriz/cifar.html)\n- [CUB200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)\n- [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)\n- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)\n- [SUN397](https://vision.princeton.edu/projects/2010/SUN/)\n- [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html)\n- [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/)\n- [OxfordFlowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)\n- [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/)\n\n## Supported Methods\n\nSupported methods include:\n\n- [Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks (Pseudo Label, ICML 2013)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf)\n- [Temporal Ensembling for Semi-Supervised Learning (Pi Model, ICLR 2017)](https://arxiv.org/abs/1610.02242)\n- [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (Mean Teacher, NIPS 2017)](https://arxiv.org/abs/1703.01780)\n- [Self-Training With Noisy Student Improves ImageNet Classification (Noisy Student, CVPR 2020)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_Self-Training_With_Noisy_Student_Improves_ImageNet_Classification_CVPR_2020_paper.pdf)\n- [Unsupervised Data Augmentation for Consistency Training (UDA, NIPS 2020)](https://arxiv.org/pdf/1904.12848v4.pdf)\n- [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence (FixMatch, NIPS 2020)](https://arxiv.org/abs/2001.07685)\n- [Self-Tuning for Data-Efficient Deep Learning (Self-Tuning, ICML 2021)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf)\n- [FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling (FlexMatch, NIPS 2021)](https://arxiv.org/abs/2110.08263)\n- [Debiased Learning From Naturally Imbalanced Pseudo-Labels (DebiasMatch, CVPR 2022)](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Debiased_Learning_From_Naturally_Imbalanced_Pseudo-Labels_CVPR_2022_paper.pdf)\n- [Debiased Self-Training for Semi-Supervised Learning (DST)](https://arxiv.org/abs/2202.07136)\n\n## Usage\n\n### Semi-supervised learning with supervised pre-trained model\n\nThe shell files give the script to train with supervised pre-trained model with specified hyper-parameters. 
For example, if you want to train UDA on CIFAR100, use the following script:\n\n```shell script\n# Semi-supervised learning on CIFAR100 (ResNet50, 400labels).\n# Assume you have put the datasets under the path `data/cifar100`, \n# or you are glad to download the datasets automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cifar100_4_labels_per_class\n```\n\nFollowing common practice in semi-supervised learning, we select a class-balanced subset as the labeled dataset and\ntreat the other samples as unlabeled data. In the above command, `num-samples-per-class` specifies how many labeled samples\nare kept for each class. Note that the labeled subset is **deterministic with the same random seed**. Hence, if you want to\ncompare different algorithms with the same labeled subset, you can simply pass in the same random seed.\n\n### Semi-supervised learning with unsupervised pre-trained model\n\nTake MoCo as an example.\n\n1. Download MoCo pretrained checkpoints from https://github.com/facebookresearch/moco\n2. Convert the MoCo checkpoints to the standard PyTorch format\n\n```shell\nmkdir checkpoints\npython convert_moco_to_pretrained.py checkpoints/moco_v2_800ep_pretrain.pth.tar checkpoints/moco_v2_800ep_backbone.pth checkpoints/moco_v2_800ep_fc.pth\n```\n\n3. Start training\n\n```shell\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar100_4_labels_per_class\n```\n\n## Experiment and Results\n\n**Notations**\n\n- ``Avg`` is the average accuracy over all datasets.\n- ``ERM`` refers to the model trained with only labeled data.\n- ``Oracle`` refers to the model trained using all data as labeled data.\n\nBelow are the results of the implemented methods.\n
Other than _Oracle_, we randomly sample 4 labels per category.\n\n### ImageNet Supervised Pre-training (ResNet-50)\n\n| Methods      | Food101 | CIFAR10 | CIFAR100 | CUB200 | Aircraft | Cars | SUN397 | DTD  | Pets | Flowers | Caltech | Avg  |\n|--------------|---------|---------|----------|--------|----------|------|--------|------|------|---------|---------|------|\n| ERM          | 33.6    | 59.4    | 47.9     | 48.6   | 29.0     | 37.1 | 40.9   | 50.5 | 82.2 | 87.6    | 82.2    | 54.5 |\n| Pseudo Label | 36.9    | 62.8    | 52.5     | 54.9   | 30.4     | 40.4 | 41.7   | 54.1 | 89.6 | 93.5    | 85.1    | 58.4 |\n| Pi Model     | 34.2    | 66.9    | 48.5     | 47.9   | 26.7     | 37.4 | 40.9   | 51.9 | 83.5 | 92.0    | 82.2    | 55.6 |\n| Mean Teacher | 40.4    | 78.1    | 58.5     | 52.8   | 32.0     | 45.6 | 40.2   | 53.8 | 86.8 | 92.8    | 83.7    | 60.4 |\n| UDA          | 41.9    | 73.0    | 59.8     | 55.4   | 33.5     | 42.7 | 42.1   | 49.7 | 88.0 | 93.4    | 85.3    | 60.4 |\n| FixMatch     | 36.2    | 74.5    | 58.0     | 52.6   | 27.1     | 44.8 | 40.8   | 50.2 | 87.8 | 93.6    | 83.2    | 59.0 |\n| Self Tuning  | 41.4    | 70.9    | 57.2     | 60.5   | 37.0     | 59.8 | 43.5   | 51.7 | 88.4 | 93.5    | 89.1    | 63.0 |\n| FlexMatch    | 48.1    | 94.2    | 69.2     | 65.1   | 38.0     | 55.3 | 50.2   | 55.6 | 91.5 | 94.6    | 89.4    | 68.3 |\n| DebiasMatch  | 57.1    | 92.4    | 69.0     | 66.2   | 41.5     | 65.4 | 48.3   | 54.2 | 90.2 | 95.4    | 89.3    | 69.9 |\n| DST          | 58.1    | 93.5    | 67.8     | 68.6   | 44.9     | 68.6 | 47.0   | 56.3 | 91.5 | 95.1    | 90.3    | 71.1 |\n| Oracle       | 85.5    | 97.5    | 86.3     | 81.1   | 85.1     | 91.1 | 64.1   | 68.8 | 93.2 | 98.1    | 92.6    | 85.8 |\n\n### ImageNet Unsupervised Pre-training (ResNet-50, MoCo v2)\n\n| Methods      | Food101 | CIFAR10 | CIFAR100 | CUB200 | Aircraft | Cars | SUN397 | DTD  | Pets | Flowers | Caltech | Avg  |\n|--------------|---------|---------|----------|--------|----------|------|--------|------|------|---------|---------|------|\n| ERM          | 33.5    | 63.0    | 50.8     | 39.4   | 28.1     | 40.3 | 40.7   | 53.7 | 65.4 | 87.5    | 82.8    | 53.2 |\n| Pseudo Label | 33.6    | 71.9    | 53.8     | 42.7   | 30.9     | 51.2 | 41.2   | 55.2 | 69.3 | 94.2    | 86.2    | 57.3 |\n| Pi Model     | 32.7    | 77.9    | 50.9     | 33.6   | 27.2     | 34.4 | 41.1   | 54.9 | 66.7 | 91.4    | 84.1    | 54.1 |\n| Mean Teacher | 36.8    | 79.0    | 56.7     | 43.0   | 33.0     | 53.9 | 39.5   | 54.5 | 67.8 | 92.7    | 83.3    | 58.2 |\n| UDA          | 39.5    | 91.3    | 60.0     | 41.9   | 36.2     | 39.7 | 41.7   | 51.5 | 71.0 | 93.7    | 86.5    | 59.4 |\n| FixMatch     | 44.3    | 86.1    | 58.0     | 42.7   | 38.0     | 55.4 | 42.4   | 53.1 | 67.9 | 95.2    | 83.4    | 60.6 |\n| Self Tuning  | 34.0    | 63.6    | 51.7     | 43.3   | 32.2     | 50.2 | 40.7   | 52.7 | 68.2 | 91.8    | 87.7    | 56.0 |\n| FlexMatch    | 50.2    | 96.6    | 69.2     | 49.4   | 41.3     | 62.5 | 47.2   | 54.5 | 72.4 | 94.8    | 89.4    | 66.1 |\n| DebiasMatch  | 54.2    | 95.5    | 68.1     | 49.1   | 40.9     | 73.0 | 47.6   | 54.4 | 76.6 | 95.5    | 88.7    | 67.6 |\n| DST          | 57.1    | 95.0    | 68.2     | 53.6   | 47.7     | 72.0 | 46.8   | 56.0 | 76.3 | 95.6    | 90.1    | 68.9 |\n| Oracle       | 87.0    | 98.2    | 87.9     | 80.6   | 88.7     | 92.7 | 63.9   | 73.8 | 90.6 | 97.8    | 93.1    | 86.8 |\n\n## TODO\n\n1. support multi-gpu training\n2. 
add training from scratch code and results\n\n## Citation\n\nIf you use these methods in your research, please consider citing.\n\n```\n@inproceedings{pseudo_label,\n    title={Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks},\n    author={Lee, Dong-Hyun and others},\n    booktitle={ICML},\n    year={2013}\n}\n\n@inproceedings{pi_model,\n    title={Temporal ensembling for semi-supervised learning},\n    author={Laine, Samuli and Aila, Timo},\n    booktitle={ICLR},\n    year={2017}\n}\n\n@inproceedings{mean_teacher,\n    title={Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results},\n    author={Tarvainen, Antti and Valpola, Harri},\n    booktitle={NIPS},\n    year={2017}\n}\n\n@inproceedings{noisy_student,\n    title={Self-training with noisy student improves imagenet classification},\n    author={Xie, Qizhe and Luong, Minh-Thang and Hovy, Eduard and Le, Quoc V},\n    booktitle={CVPR},\n    year={2020}\n}\n\n@inproceedings{UDA,\n    title={Unsupervised data augmentation for consistency training},\n    author={Xie, Qizhe and Dai, Zihang and Hovy, Eduard and Luong, Thang and Le, Quoc},\n    booktitle={NIPS},\n    year={2020}\n}\n\n@inproceedings{FixMatch,\n    title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},\n    author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},\n    booktitle={NIPS},\n    year={2020}\n}\n\n@inproceedings{SelfTuning,\n    title={Self-tuning for data-efficient deep learning},\n    author={Wang, Ximei and Gao, Jinghan and Long, Mingsheng and Wang, Jianmin},\n    booktitle={ICML},\n    year={2021}\n}\n\n@inproceedings{FlexMatch,\n    title={Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling},\n    author={Zhang, Bowen and Wang, Yidong and Hou, Wenxin and Wu, Hao and Wang, Jindong and Okumura, Manabu and Shinozaki, Takahiro},\n    booktitle={NeurIPS},\n    year={2021}\n}\n\n@inproceedings{DebiasMatch,\n    title={Debiased Learning from Naturally Imbalanced Pseudo-Labels},\n    author={Wang, Xudong and Wu, Zhirong and Lian, Long and Yu, Stella X},\n    booktitle={CVPR},\n    year={2022}\n}\n\n@article{DST,\n    title={Debiased Self-Training for Semi-Supervised Learning},\n    author={Chen, Baixu and Jiang, Junguang and Wang, Ximei and Wang, Jianmin and Long, Mingsheng},\n    journal={arXiv preprint arXiv:2202.07136},\n    year={2022}\n}\n```\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/convert_moco_to_pretrained.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport sys\nimport torch\n\nif __name__ == \"__main__\":\n    input = sys.argv[1]\n\n    obj = torch.load(input, map_location=\"cpu\")\n    obj = obj[\"state_dict\"]\n\n    newmodel = {}\n    fc = {}\n    for k, v in obj.items():\n        if not k.startswith(\"module.encoder_q.\"):\n            continue\n        old_k = k\n        k = k.replace(\"module.encoder_q.\", \"\")\n        if k.startswith(\"fc\"):\n            print(k)\n            fc[k] = v\n        else:\n            newmodel[k] = v\n\n    with open(sys.argv[2], \"wb\") as f:\n        torch.save(newmodel, f)\n\n    with open(sys.argv[3], \"wb\") as f:\n        torch.save(fc, f)\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/debiasmatch.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 
num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # initialize q_hat\n    q_hat = (torch.ones(num_classes) / num_classes).to(device)\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, q_hat, epoch, args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,\n          lr_scheduler: LambdaLR, q_hat, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')\n    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,\n         pseudo_label_ratios],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    
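# DebiasMatch training loop: each iteration draws one (weak, strong) batch pair from\n    # the labeled iterator and one from the unlabeled iterator. Pseudo labels for the\n    # unlabeled batch come from the debiased logits y_u - tau * log(q_hat), while the\n    # strong-view logits are adjusted in the opposite direction (+ tau * log(q_hat));\n    # q_hat is an exponential moving average of the mean prediction on unlabeled data.\n    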
for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        (x_u, x_u_strong), labels_u = next(unlabeled_train_iter)\n        x_u = x_u.to(device)\n        x_u_strong = x_u_strong.to(device)\n        labels_u = labels_u.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        # cross entropy loss\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # self training loss\n        with torch.no_grad():\n            y_u = model(x_u)\n        y_u_strong = model(x_u_strong)\n\n        # update q_hat in place so that the EMA persists across epochs (rebinding the\n        # local name would silently reset q_hat at the start of every epoch)\n        q = torch.softmax(y_u, dim=1).mean(dim=0)\n        q_hat.mul_(args.momentum).add_(q, alpha=1 - args.momentum)\n\n        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong + args.tau * torch.log(q_hat),\n                                                                          y_u - args.tau * torch.log(q_hat))\n        self_training_loss = args.trade_off_self_training * self_training_loss\n        self_training_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + self_training_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        self_training_losses.update(self_training_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # ratio of pseudo labels\n        n_pseudo_labels = mask.sum()\n        ratio = n_pseudo_labels / batch_size\n        pseudo_label_ratios.update(ratio.item() * 100, batch_size)\n\n        # accuracy of pseudo labels\n        if n_pseudo_labels > 0:\n            # mark positions rejected by the confidence threshold with -1 so that they\n            # never match a ground-truth label\n            pseudo_labels = pseudo_labels * mask - (1 - mask)\n            n_correct = (pseudo_labels == labels_u).float().sum()\n            pseudo_label_acc = n_correct / n_pseudo_labels * 100\n            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='DebiasMatch for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n           
             help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n    parser.add_argument('--momentum', default=0.999, type=float,\n                        help='momentum coefficient for updating q_hat (default: 0.999)')\n    parser.add_argument('--tau', default=1, type=float,\n                        help='debiased strength (default: 1)')\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')\n    parser.add_argument('--trade-off-self-training', default=1, type=float,\n                        help='the trade-off hyper-parameter of self training loss')\n    parser.add_argument('--threshold', default=0.95, type=float,\n                        help='confidence threshold')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default:5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=90, type=int, metavar='N',\n                        help='number of total epochs to run (default: 90)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training ')\n    parser.add_argument(\"--log\", default='debiasmatch', 
type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/debiasmatch.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.8 --tau 1 --seed 0 --log logs/debiasmatch/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --threshold 0.9 --tau 1 --seed 0 --log logs/debiasmatch/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --tau 1 --seed 0 --log logs/debiasmatch/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 7 --tau 3 --seed 0 --log logs/debiasmatch/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/pets_4_labels_per_class\n\n# 
======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 
\\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch_moco_pretrain/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/caltech_4_labels_per_class\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/dst.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss\nfrom tllib.self_training.dst import ImageClassifier, WorstCaseEstimationLoss\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = 
DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, width=args.width,\n                                 pool_layer=pool_layer, finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')\n    wce_losses = AverageMeter('Worst Case Estimation Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')\n    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, self_training_losses, wce_losses, cls_accs, pseudo_label_accs,\n         pseudo_label_ratios],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)\n    worst_case_estimation_criterion = 
WorstCaseEstimationLoss(args.eta_prime).to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        (x_u, x_u_strong), labels_u = next(unlabeled_train_iter)\n        x_u = x_u.to(device)\n        x_u_strong = x_u_strong.to(device)\n        labels_u = labels_u.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output; the classifier returns a triple\n        # (main head, worst-case estimation head, pseudo head)\n\n        # ==============================================================================================================\n        # cross entropy loss (strong augment)\n        # ==============================================================================================================\n        y_l_strong, _, _ = model(x_l_strong)\n        cls_loss_strong = args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss_strong.backward()\n\n        x = torch.cat((x_l, x_u), dim=0)\n        outputs, outputs_adv, _ = model(x)\n        y_l, y_u = outputs.chunk(2, dim=0)\n        y_l_adv, y_u_adv = outputs_adv.chunk(2, dim=0)\n\n        # ==============================================================================================================\n        # cross entropy loss (weak augment)\n        # ==============================================================================================================\n        cls_loss_weak = F.cross_entropy(y_l, labels_l)\n\n        # ==============================================================================================================\n        # worst case estimation loss (eta_prime trades off the adversarial term on\n        # labeled data against that on unlabeled data)\n        # ==============================================================================================================\n        wce_loss = args.eta * worst_case_estimation_criterion(y_l, y_l_adv, y_u, y_u_adv)\n        (cls_loss_weak + wce_loss).backward()\n\n        # ==============================================================================================================\n        # self training loss (the pseudo head is trained on pseudo labels generated by\n        # the main head on the weakly augmented view)\n        # ==============================================================================================================\n        _, _, y_u_strong = model(x_u_strong)\n        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong, y_u)\n        self_training_loss = args.trade_off_self_training * self_training_loss\n        self_training_loss.backward()\n\n        # measure accuracy and record loss\n        cls_loss = cls_loss_strong + cls_loss_weak\n        cls_losses.update(cls_loss.item(), batch_size)\n        loss = cls_loss + self_training_loss + wce_loss\n        losses.update(loss.item(), batch_size)\n        wce_losses.update(wce_loss.item(), batch_size)\n        self_training_losses.update(self_training_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # ratio of pseudo labels\n        n_pseudo_labels = mask.sum()\n        ratio = n_pseudo_labels / batch_size\n        pseudo_label_ratios.update(ratio.item() * 100, batch_size)\n\n        # accuracy of pseudo labels\n        if n_pseudo_labels > 0:\n            # mark positions rejected by the confidence threshold with -1 so that they\n            # never match a ground-truth label\n            pseudo_labels = pseudo_labels * mask - (1 - mask)\n            n_correct = (pseudo_labels == 
labels_u).float().sum()\n            pseudo_label_acc = n_correct / n_pseudo_labels * 100\n            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n        model.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Debiased Self-Training for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--width', default=2048, type=int,\n                        help='width of the pseudo head and the worst-case estimation head')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')\n    parser.add_argument('--trade-off-self-training', default=1, type=float,\n                        help='the trade-off hyper-parameter of self training loss')\n    parser.add_argument('--eta', default=1, type=float,\n                        help='the trade-off hyper-parameter of adversarial loss')\n    parser.add_argument('--eta-prime', default=2, type=float,\n                        help=\"the trade-off hyper-parameter between adversarial loss on labeled data \"\n                             \"and that on unlabeled data\")\n    parser.add_argument('--threshold', default=0.7, type=float,\n                        help='confidence 
threshold')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0002, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default:5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=90, type=int, metavar='N',\n                        help='number of total epochs to run (default: 90)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training ')\n    parser.add_argument(\"--log\", default='dst', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/dst.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python dst.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.8 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.95 --trade-off-self-training 0.3 --eta-prime 2 \\\n  --seed 0 --log logs/dst/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python dst.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \\\n  --seed 0 --log logs/dst/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python dst.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python dst.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.95 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python dst.py data/pets -d OxfordIIITPets 
--num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.9 --trade-off-self-training 0.3 --eta-prime 2 \\\n  --seed 0 --log logs/dst/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python dst.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.9 --trade-off-self-training 0.3 --eta-prime 1 \\\n  --seed 0 --log logs/dst/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python dst.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 4 \\\n  --seed 0 --log logs/dst/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python dst.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \\\n  --seed 0 --log logs/dst_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \\\n  --seed 0 --log logs/dst_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \\\n  --seed 0 --log logs/dst_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python dst.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 
--trade-off-self-training 1 --eta-prime 1 \\\n  --seed 0 --log logs/dst_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python dst.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \\\n  --seed 0 --log logs/dst_moco_pretrain/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python dst.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 0.3 --eta-prime 2 \\\n  --seed 0 --log logs/dst_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python dst.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 0.1 --eta-prime 3 \\\n  --seed 0 --log logs/dst_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python dst.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 0.1 --eta-prime 1 \\\n  --seed 0 --log logs/dst_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python dst.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 1 --eta-prime 1 \\\n  --seed 0 --log logs/dst_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python dst.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 0.1 --eta-prime 1 \\\n  --seed 0 --log logs/dst_moco_pretrain/caltech_4_labels_per_class\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/erm.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader, ConcatDataset\n\nimport utils\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    train_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('train_transform: ', train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, train_transform,\n                          val_transform,\n                          seed=args.seed)\n    if args.oracle:\n        num_classes = labeled_train_dataset.num_classes\n        labeled_train_dataset = ConcatDataset([labeled_train_dataset, unlabeled_train_dataset])\n        labeled_train_dataset.num_classes = num_classes\n\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, 
pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        utils.empirical_risk_minimization(labeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args, device)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Baseline for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    parser.add_argument('--oracle', action='store_true', default=False,\n                        help='use all data as labeled data (oracle)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n   
                     help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default:5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run (default: 20)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training ')\n    parser.add_argument(\"--log\", default='baseline', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/erm.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/erm/food101_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/erm/food101_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --oracle -a resnet50 \\\n  --lr 0.01 --finetune --epochs 80 --seed 0 --log logs/erm/food101_oracle\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/cifar10_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/cifar10_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --oracle -a resnet50 \\\n  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/cifar10_oracle\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/erm/cifar100_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/erm/cifar100_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --oracle -a resnet50 \\\n  --lr 0.01 --finetune --epochs 80 --seed 0 --log logs/erm/cifar100_oracle\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/erm/cub200_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/erm/cub200_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --oracle -a resnet50 \\\n  --lr 0.003 --finetune --epochs 80 --seed 0 --log logs/erm/cub200_oracle\n\n# ======================================================================================================================\n# 
Aircraft\nCUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/aircraft_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/aircraft_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --oracle -a resnet50 \\\n  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/aircraft_oracle\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/car_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/car_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --oracle -a resnet50 \\\n  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/car_oracle\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --seed 0 --log logs/erm/sun_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.001 --finetune --seed 0 --log logs/erm/sun_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --oracle -a resnet50 \\\n  --lr 0.001 --finetune --epochs 80 --seed 0 --log logs/erm/sun_oracle\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/dtd_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/dtd_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --oracle -a resnet50 \\\n  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/dtd_oracle\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --seed 0 --log logs/erm/pets_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.001 --finetune --seed 0 --log logs/erm/pets_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --oracle -a resnet50 \\\n  --lr 0.001 --finetune --epochs 80 --seed 0 --log logs/erm/pets_oracle\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/erm/flowers_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 10 -a resnet50 \\\n  
--lr 0.03 --finetune --seed 0 --log logs/erm/flowers_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --oracle -a resnet50 \\\n  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/flowers_oracle\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/erm/caltech_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 10 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/erm/caltech_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --oracle -a resnet50 \\\n  --lr 0.003 --finetune --epochs 80 --seed 0 --log logs/erm/caltech_oracle\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/food101_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/food101_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/food101_oracle\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar10_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar10_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/cifar10_oracle\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 
0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar100_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar100_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/cifar100_oracle\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cub200_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cub200_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/cub200_oracle\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/aircraft_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/aircraft_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/aircraft_oracle\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/car_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler 
cos --seed 0 --log logs/erm_moco_pretrain/car_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/car_oracle\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/sun_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/sun_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/sun_oracle\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/dtd_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/dtd_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/dtd_oracle\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/pets_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/pets_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/pets_oracle\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log 
logs/erm_moco_pretrain/flowers_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/flowers_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/flowers_oracle\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/caltech_4_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 10 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/caltech_10_labels_per_class\nCUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --oracle -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/caltech_oracle\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/fixmatch.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 
num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')\n    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,\n         pseudo_label_ratios],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n   
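     # each iteration draws one labeled batch and one unlabeled batch; pseudo labels\n        # are generated from the weakly augmented unlabeled view (x_u) and supervise\n        # the strongly augmented view (x_u_strong)\n   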
     x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        (x_u, x_u_strong), labels_u = next(unlabeled_train_iter)\n        x_u = x_u.to(device)\n        x_u_strong = x_u_strong.to(device)\n        # ground-truth labels of the unlabeled batch are only used to measure pseudo-label accuracy\n        labels_u = labels_u.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        # cross entropy loss\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # self training loss\n        with torch.no_grad():\n            y_u = model(x_u)\n        y_u_strong = model(x_u_strong)\n        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong, y_u)\n        self_training_loss = args.trade_off_self_training * self_training_loss\n        self_training_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + self_training_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        self_training_losses.update(self_training_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # ratio of pseudo labels\n        n_pseudo_labels = mask.sum()\n        ratio = n_pseudo_labels / batch_size\n        pseudo_label_ratios.update(ratio.item() * 100, batch_size)\n\n        # accuracy of pseudo labels\n        if n_pseudo_labels > 0:\n            # mark unselected entries as -1 so that they never match a ground-truth label\n            pseudo_labels = pseudo_labels * mask - (1 - mask)\n            n_correct = (pseudo_labels == labels_u).float().sum()\n            pseudo_label_acc = n_correct / n_pseudo_labels * 100\n            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)\n\n        # update parameters (gradients were accumulated by the backward calls above)\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='FixMatch for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone 
architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')\n    parser.add_argument('--trade-off-self-training', default=1, type=float,\n                        help='the trade-off hyper-parameter of self training loss')\n    parser.add_argument('--threshold', default=0.95, type=float,\n                        help='confidence threshold')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default:5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run (default: 60)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training ')\n    parser.add_argument(\"--log\", default='fixmatch', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/fixmatch.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/fixmatch/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/fixmatch/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py 
data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.9 --seed 0 --log logs/fixmatch/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/fixmatch_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/fixmatch_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/fixmatch_moco_pretrain/car_4_labels_per_class\n\n# 
======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/fixmatch_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/fixmatch_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python fixmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/fixmatch_moco_pretrain/caltech_4_labels_per_class\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/flexmatch.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.flexmatch import DynamicThresholdingModule\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    unlabeled_train_dataset = utils.convert_dataset(unlabeled_train_dataset)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = 
DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    args.num_classes = num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # thresholding module with convex mapping function x / (2 - x)\n    thresholding_module = DynamicThresholdingModule(args.threshold, args.warmup, lambda x: x / (2 - x), num_classes,\n                                                    len(unlabeled_train_dataset), device=device)\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, thresholding_module, classifier, optimizer, lr_scheduler, epoch,\n              args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator,\n          thresholding_module: DynamicThresholdingModule, model, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,\n          args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')\n    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n       
 [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,\n         pseudo_label_ratios],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        idx_u, ((x_u, x_u_strong), labels_u) = next(unlabeled_train_iter)\n        idx_u = idx_u.to(device)\n        x_u = x_u.to(device)\n        x_u_strong = x_u_strong.to(device)\n        # ground-truth labels of the unlabeled batch are only used to measure pseudo-label accuracy\n        labels_u = labels_u.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        # cross entropy loss\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # self training loss\n        with torch.no_grad():\n            y_u = model(x_u)\n        y_u_strong = model(x_u_strong)\n\n        confidence, pseudo_labels = torch.softmax(y_u, dim=1).max(dim=1)\n        # class-wise dynamic threshold that reflects the current learning status of each class\n        dynamic_threshold = thresholding_module.get_threshold(pseudo_labels)\n        mask = (confidence > dynamic_threshold).float()\n        # mask used for updating learning status\n        selected_mask = (confidence > args.threshold).long()\n        thresholding_module.update(idx_u, selected_mask, pseudo_labels)\n\n        self_training_loss = args.trade_off_self_training * (\n                F.cross_entropy(y_u_strong, pseudo_labels, reduction='none') * mask).mean()\n        self_training_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + self_training_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        self_training_losses.update(self_training_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # ratio of pseudo labels\n        n_pseudo_labels = mask.sum()\n        ratio = n_pseudo_labels / batch_size\n        pseudo_label_ratios.update(ratio.item() * 100, batch_size)\n\n        # accuracy of pseudo labels\n        if n_pseudo_labels > 0:\n            # mark unselected entries as -1 so that they never match a ground-truth label\n            pseudo_labels = pseudo_labels * mask - (1 - mask)\n            n_correct = (pseudo_labels == labels_u).float().sum()\n            pseudo_label_acc = n_correct / n_pseudo_labels * 100\n            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)\n\n        # update parameters (gradients were accumulated by the backward calls above)\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='FlexMatch for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        
help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n    # note: a boolean flag, not `type=bool` (argparse would parse any non-empty string,\n    # including 'False', as True)\n    parser.add_argument('--warmup', action='store_true', default=False,\n                        help='whether to perform threshold warm up')\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')\n    parser.add_argument('--trade-off-self-training', default=1, type=float,\n                        help='the trade-off hyper-parameter of self training loss')\n    parser.add_argument('--threshold', default=0.95, type=float,\n                        help='confidence threshold')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default:5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=90, type=int, metavar='N',\n                        help='number of total epochs to run (default: 90)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        
help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", default='flexmatch', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
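  {
    "path": "examples/semi_supervised_learning/image_classification/sketches/flexmatch_threshold_sketch.py",
    "content": "\"\"\"\nA minimal, self-contained sketch of the class-wise dynamic thresholding\n(curriculum pseudo labeling) that flexmatch.py delegates to its thresholding\nmodule. This file and every name in it are illustrative assumptions added for\ndocumentation; the implementation actually used ships with tllib and may\ndiffer in details such as threshold warm-up and the mapping function.\n\"\"\"\nimport torch\n\n\nclass DynamicThresholdSketch:\n    def __init__(self, base_threshold: float, num_classes: int, num_unlabeled: int):\n        self.base_threshold = base_threshold\n        self.num_classes = num_classes\n        # learning status of each unlabeled sample; -1 means not yet selected\n        self.status = -torch.ones(num_unlabeled, dtype=torch.long)\n\n    def get_threshold(self, pseudo_labels: torch.Tensor) -> torch.Tensor:\n        # sigma(c): number of unlabeled samples currently learned as class c\n        counts = torch.bincount(self.status[self.status >= 0], minlength=self.num_classes).float()\n        # normalize by the best-learned class, so well-learned classes keep a\n        # high threshold while poorly learned classes get a lower one\n        beta = counts / counts.max().clamp(min=1.0)\n        return self.base_threshold * beta[pseudo_labels]\n\n    def update(self, idx: torch.Tensor, selected_mask: torch.Tensor, pseudo_labels: torch.Tensor):\n        # record the pseudo label of samples whose confidence cleared the fixed\n        # threshold; the other samples in the batch revert to unselected\n        self.status[idx] = torch.where(selected_mask.bool(), pseudo_labels, -torch.ones_like(pseudo_labels))\n\n\nif __name__ == '__main__':\n    module = DynamicThresholdSketch(base_threshold=0.95, num_classes=10, num_unlabeled=100)\n    pseudo_labels = torch.randint(0, 10, (8,))\n    module.update(torch.arange(8), torch.ones(8, dtype=torch.long), pseudo_labels)\n    print(module.get_threshold(pseudo_labels))\n"
  },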
  {
    "path": "examples/semi_supervised_learning/image_classification/flexmatch.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.9 --seed 0 --log logs/flexmatch/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.9 --seed 0 --log logs/flexmatch/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 
python flexmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/car_4_labels_per_class\n\n# 
======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/flexmatch_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/flexmatch_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python flexmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/caltech_4_labels_per_class\n"
  },
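  {
    "path": "examples/semi_supervised_learning/image_classification/sketches/masked_self_training_sketch.py",
    "content": "\"\"\"\nA minimal sketch of the confidence-masked self-training loss computed inside\nthe train loop of flexmatch.py: pseudo labels are derived from the weakly\naugmented view and supervise the strongly augmented view only where the\nprediction confidence clears a (possibly per-sample) threshold. This file is\nan illustrative assumption added for documentation, not part of the library.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\n\ndef masked_self_training_loss(y_weak: torch.Tensor, y_strong: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:\n    # pseudo labels come from the weak view and receive no gradient\n    with torch.no_grad():\n        confidence, pseudo_labels = torch.softmax(y_weak, dim=1).max(dim=1)\n        mask = (confidence > threshold).float()\n    # per-sample cross entropy on the strong view, zeroed where the mask is 0\n    loss = F.cross_entropy(y_strong, pseudo_labels, reduction='none')\n    return (loss * mask).mean()\n\n\nif __name__ == '__main__':\n    y_weak = torch.randn(8, 10)\n    y_strong = torch.randn(8, 10, requires_grad=True)\n    # a scalar threshold works here, and so does a per-sample dynamic one\n    loss = masked_self_training_loss(y_weak, y_strong, torch.tensor(0.95))\n    loss.backward()\n    print(loss.item())\n"
  },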
  {
    "path": "examples/semi_supervised_learning/image_classification/mean_teacher.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.pi_model import sigmoid_warm_up, L2ConsistencyLoss\nfrom tllib.self_training.mean_teacher import update_bn, EMATeacher\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = MultipleApply([weak_augment, weak_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = 
DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    teacher = EMATeacher(classifier, alpha=args.alpha)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, classifier, teacher, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, teacher,\n          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    con_losses = AverageMeter('Con Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, con_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    consistency_criterion = L2ConsistencyLoss(reduction='sum').to(device)\n    # switch to train mode\n    model.train()\n    teacher.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = 
x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        (x_u, x_u_teacher), _ = next(unlabeled_train_iter)\n        x_u = x_u.to(device)\n        x_u_teacher = x_u_teacher.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        # cross entropy loss\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # consistency loss\n        with torch.no_grad():\n            y_u_teacher = teacher(x_u_teacher)\n            p_u_teacher = torch.softmax(y_u_teacher, dim=1)\n        y_u = model(x_u)\n        p_u = torch.softmax(y_u, dim=1)\n        con_loss = args.trade_off_con * sigmoid_warm_up(epoch, args.warm_up_epochs) * \\\n                   consistency_criterion(p_u, p_u_teacher)\n        con_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + con_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        con_losses.update(con_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n\n        # update teacher\n        global_step = epoch * args.iters_per_epoch + i + 1\n        teacher.set_alpha(min(args.alpha, 1 - 1 / global_step))\n        teacher.update()\n        update_bn(model, teacher.teacher)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Mean Teacher for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    
parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    parser.add_argument('--alpha', default=0.999, type=float,\n                        help='ema decay factor')\n    # training parameters\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strongly augmented labeled data')\n    parser.add_argument('--trade-off-con', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of consistency loss')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0001, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=40, type=int, metavar='N',\n                        help='number of total epochs to run (default: 40)')\n    parser.add_argument('--warm-up-epochs', default=10, type=int,\n                        help='number of epochs to warm up (default: 10)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", default='mean_teacher', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
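  {
    "path": "examples/semi_supervised_learning/image_classification/sketches/ema_teacher_sketch.py",
    "content": "\"\"\"\nA minimal sketch of the exponential-moving-average update that mean_teacher.py\nperforms each iteration through teacher.set_alpha(...) and teacher.update().\nEverything here is an illustrative assumption added for documentation; the\nEMATeacher class in tllib.self_training.mean_teacher is what the example\nactually uses.\n\"\"\"\nimport copy\nimport torch\nimport torch.nn as nn\n\n\n@torch.no_grad()\ndef ema_update(teacher: nn.Module, student: nn.Module, alpha: float):\n    # theta_teacher <- alpha * theta_teacher + (1 - alpha) * theta_student\n    for t_param, s_param in zip(teacher.parameters(), student.parameters()):\n        t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)\n\n\nif __name__ == '__main__':\n    student = nn.Linear(4, 2)\n    teacher = copy.deepcopy(student)\n    for global_step in range(1, 6):\n        # ramp the decay up early in training, mirroring the\n        # min(alpha, 1 - 1 / global_step) schedule in mean_teacher.py\n        alpha = min(0.999, 1.0 - 1.0 / global_step)\n        ema_update(teacher, student, alpha)\n    print(teacher.weight)\n"
  },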
  {
    "path": "examples/semi_supervised_learning/image_classification/mean_teacher.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.05 --finetune --seed 0 --log logs/mean_teacher/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.05 --finetune --seed 0 --log logs/mean_teacher/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 
\\\n  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/mean_teacher/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/sun397 -d SUN397 --num-samples-per-class 4 -a 
resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/caltech_4_labels_per_class\n"
  },
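  {
    "path": "examples/semi_supervised_learning/image_classification/sketches/consistency_sketch.py",
    "content": "\"\"\"\nA minimal sketch of the two pieces that mean_teacher.py and pi_model.py import\nfrom tllib.self_training.pi_model: an L2 consistency loss between two predicted\ndistributions, and a sigmoid warm-up factor that ramps the consistency weight\nin over the first epochs. The formulas below are illustrative assumptions\nbased on the Pi-model and Mean Teacher papers; consult tllib for the exact ones.\n\"\"\"\nimport math\nimport torch\n\n\ndef l2_consistency(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:\n    # mean (over the batch) squared L2 distance between the two distributions\n    return ((p1 - p2) ** 2).sum(dim=1).mean()\n\n\ndef sigmoid_warm_up(epoch: int, warm_up_epochs: int) -> float:\n    # exp(-5 * (1 - t)^2) ramp, commonly used in consistency regularization;\n    # close to 0 at epoch 0 and exactly 1.0 once epoch >= warm_up_epochs\n    if warm_up_epochs <= 0:\n        return 1.0\n    t = min(max(epoch / warm_up_epochs, 0.0), 1.0)\n    return math.exp(-5.0 * (1.0 - t) ** 2)\n\n\nif __name__ == '__main__':\n    p1 = torch.softmax(torch.randn(8, 10), dim=1)\n    p2 = torch.softmax(torch.randn(8, 10), dim=1)\n    for epoch in (0, 5, 10):\n        print(epoch, sigmoid_warm_up(epoch, 10), l2_consistency(p1, p2).item())\n"
  },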
  {
    "path": "examples/semi_supervised_learning/image_classification/noisy_student.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport copy\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.vision.models.reid.loss import CrossEntropyLoss\nfrom tllib.modules.classifier import Classifier\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\nclass ImageClassifier(Classifier):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):\n        bottleneck = nn.Sequential(\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        bottleneck[0].weight.data.normal_(0, 0.005)\n        bottleneck[0].bias.data.fill_(0.1)\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n        self.dropout = nn.Dropout(0.5)\n        self.as_teacher_model = False\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\"\"\"\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        if not self.as_teacher_model:\n            f = self.dropout(f)\n        predictions = self.head(f)\n        return predictions\n\n\ndef calc_teacher_output(classifier_teacher: ImageClassifier, weak_augmented_unlabeled_dataset):\n    \"\"\"Compute outputs of the teacher network. Here, we use weak data augmentation and do not introduce an additional\n    dropout layer according to the Noisy Student paper `Self-Training With Noisy Student Improves ImageNet\n    Classification <https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_Self-Training_With_Noisy_Student_Improves\n    _ImageNet_Classification_CVPR_2020_paper.pdf>`_.\n    \"\"\"\n\n    data_loader = DataLoader(weak_augmented_unlabeled_dataset, batch_size=args.batch_size, shuffle=False,\n                             num_workers=args.workers, drop_last=False)\n    batch_time = AverageMeter('Time', ':6.3f')\n    progress = ProgressMeter(\n        len(data_loader),\n        [batch_time],\n        prefix='Computing teacher output: ')\n\n    teacher_output = []\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, _) in enumerate(data_loader):\n            images = images.to(device)\n            output = classifier_teacher(images)\n            teacher_output.append(output)\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n    teacher_output = torch.cat(teacher_output, dim=0)\n    return teacher_output\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. 
'\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('weak_augment (input transform for teacher model): ', weak_augment)\n    print('strong_augment (input transform for student model): ', strong_augment)\n    print('val_transform:', val_transform)\n\n    labeled_train_dataset, weak_augmented_unlabeled_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=weak_augment,\n                          seed=args.seed)\n    _, strong_augmented_unlabeled_dataset, _ = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=strong_augment,\n                          seed=args.seed)\n\n    strong_augmented_unlabeled_dataset = utils.convert_dataset(strong_augmented_unlabeled_dataset)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(weak_augmented_unlabeled_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(strong_augmented_unlabeled_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                 finetune=args.finetune).to(device)\n    print(classifier)\n\n    if args.pretrained_teacher:\n        # load teacher model\n        classifier_teacher = 
copy.deepcopy(classifier)\n        checkpoint = torch.load(args.pretrained_teacher)\n        classifier_teacher.load_state_dict(checkpoint)\n        classifier_teacher.eval()\n        classifier_teacher.as_teacher_model = True\n\n        print('compute outputs of the teacher network')\n        teacher_output = calc_teacher_output(classifier_teacher, weak_augmented_unlabeled_dataset)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        if args.pretrained_teacher:\n            train(labeled_train_iter, unlabeled_train_iter, classifier, teacher_output, optimizer, lr_scheduler,\n                  epoch, args)\n        else:\n            utils.empirical_risk_minimization(labeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args,\n                                              device)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, teacher_output,\n          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    self_training_criterion = CrossEntropyLoss().to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l 
= labels_l.to(device)\n\n        idx_u, (x_u_strong, _) = next(unlabeled_train_iter)\n        idx_u = idx_u.to(device)\n        x_u_strong = x_u_strong.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        # cross entropy loss\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # self training loss\n        y_u = teacher_output[idx_u]\n        y_u_strong = model(x_u_strong)\n        self_training_loss = args.trade_off_self_training * self_training_criterion(y_u_strong / args.T, y_u / args.T)\n        self_training_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + self_training_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        self_training_losses.update(self_training_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Noisy Student for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for 
backbone')\n    parser.add_argument('--pretrained-teacher', default=None, type=str,\n                        help='pretrained checkpoint of the teacher model')\n    # training parameters\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strongly augmented labeled data')\n    parser.add_argument('--trade-off-self-training', default=1, type=float,\n                        help='the trade-off hyper-parameter of self-training loss')\n    parser.add_argument('--T', default=2, type=float,\n                        help='temperature')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=40, type=int, metavar='N',\n                        help='number of total epochs to run (default: 40)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", default='noisy_student', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
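  {
    "path": "examples/semi_supervised_learning/image_classification/sketches/soft_distillation_sketch.py",
    "content": "\"\"\"\nA minimal sketch of the temperature-scaled soft cross entropy that\nnoisy_student.py computes between the student's logits on the strongly\naugmented view and the cached teacher logits (both divided by --T before the\nloss). This file is an illustrative assumption added for documentation; the\ncriterion actually imported is CrossEntropyLoss from\ntllib.vision.models.reid.loss, whose exact definition may differ.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\n\ndef soft_cross_entropy(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:\n    # the teacher distribution acts as a soft target and receives no gradient\n    soft_targets = torch.softmax(teacher_logits.detach(), dim=1)\n    log_probs = torch.log_softmax(student_logits, dim=1)\n    return -(soft_targets * log_probs).sum(dim=1).mean()\n\n\nif __name__ == '__main__':\n    T = 2.0  # temperature; a higher T softens both distributions\n    y_u = torch.randn(8, 10)  # stands in for the cached teacher output\n    y_u_strong = torch.randn(8, 10, requires_grad=True)  # student output\n    loss = soft_cross_entropy(y_u_strong / T, y_u / T)\n    loss.backward()\n    print(loss.item())\n"
  },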
  {
    "path": "examples/semi_supervised_learning/image_classification/noisy_student.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --epochs 20 --seed 0 --log logs/noisy_student/cifar100_4_labels_per_class/iter_0\n\nfor round in 0 1 2; do\n  CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n    --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n    --pretrained-teacher logs/noisy_student/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \\\n    --lr 0.01 --finetune --epochs 40 --T 0.5 --seed 0 --log logs/noisy_student/cifar100_4_labels_per_class/iter_$((round + 1))\ndone\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --epochs 20 --seed 0 \\\n  --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_0\n\nfor round in 0 1 2; do\n  CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n    --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n    --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n    --pretrained-teacher logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \\\n    --lr 0.001 --finetune --lr-scheduler cos --epochs 40 --T 1 --seed 0 \\\n    --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$((round + 1))\ndone\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/pi_model.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.pi_model import sigmoid_warm_up, L2ConsistencyLoss\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = MultipleApply([weak_augment, weak_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 
num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    con_losses = AverageMeter('Con Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, con_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    consistency_criterion = L2ConsistencyLoss().to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        (x_u1, x_u2), _ = next(unlabeled_train_iter)\n        x_u1 = x_u1.to(device)\n        x_u2 = x_u2.to(device)\n\n        # 
measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        # cross entropy loss\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # consistency loss\n        y_u1 = model(x_u1)\n        y_u2 = model(x_u2)\n        p_u1 = torch.softmax(y_u1, dim=1)\n        p_u2 = torch.softmax(y_u2, dim=1)\n        con_loss = args.trade_off_con * sigmoid_warm_up(epoch, args.warm_up_epochs) * consistency_criterion(p_u1, p_u2)\n        con_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + con_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        con_losses.update(con_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # do SGD and lr scheduler step (gradients were accumulated by the backward() calls above)\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Pi Model for Semi-Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strongly augmented labeled data')\n    parser.add_argument('--trade-off-con', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of consistency loss')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=40, type=int, metavar='N',\n                        help='number of total epochs to run (default: 40)')\n    parser.add_argument('--warm-up-epochs', default=10, type=int,\n                        help='number of epochs to warm up (default: 10)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", default='pi_model', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
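The Π-model script above pairs a supervised cross-entropy term with an L2 consistency term between two weakly augmented views of each unlabeled image, ramped up over the first `--warm-up-epochs` by a sigmoid schedule. A minimal self-contained sketch of that consistency term; `sigmoid_warm_up` and `l2_consistency` below are simplified stand-ins for the tllib implementations, not the library code itself:

```python
import math

import torch
import torch.nn.functional as F


def sigmoid_warm_up(epoch: int, warm_up_epochs: int) -> float:
    # Sigmoid-shaped ramp-up in [0, 1], as in Laine & Aila (2017): exp(-5 * (1 - t)^2).
    if warm_up_epochs <= 0:
        return 1.0
    t = min(float(epoch) / warm_up_epochs, 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)


def l2_consistency(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    # Mean squared error between softmax predictions of the same unlabeled
    # batch under two independent stochastic augmentations.
    return F.mse_loss(p1, p2)


# Toy usage: two noisy forward passes stand in for the two augmented views.
logits1 = torch.randn(8, 10)
logits2 = logits1 + 0.1 * torch.randn(8, 10)
con_loss = sigmoid_warm_up(epoch=3, warm_up_epochs=10) * \
    l2_consistency(logits1.softmax(dim=1), logits2.softmax(dim=1))
print(con_loss.item())
```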
  {
    "path": "examples/semi_supervised_learning/image_classification/pi_model.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/pi_model/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/pi_model/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/pi_model/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/pi_model/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/pi_model/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/pi_model/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/pi_model/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log logs/pi_model/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --seed 0 --log logs/pi_model/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --seed 0 --log 
logs/pi_model/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/pi_model/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune 
--lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python pi_model.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n   --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/caltech_4_labels_per_class\n"
  },
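All runs in pi_model.sh rely on the default `--lr-scheduler exp`; only the MoCo-pretrained block switches to `cos`. For intuition, here is the per-iteration multiplier that the exp schedule's LambdaLR evaluates, printed at a few steps of a default-length run (LambdaLR additionally scales this by each parameter group's relative lr, so these are schedule values, not final per-group rates):

```python
# lambda(step) = lr * (1 + lr_gamma * step) ** (-lr_decay), per pi_model.py
lr, lr_gamma, lr_decay = 0.003, 0.0004, 0.75   # defaults from the argparse above
total_steps = 40 * 500                         # --epochs 40 * --iters-per-epoch 500
for step in (0, total_steps // 4, total_steps // 2, total_steps):
    print(step, lr * (1.0 + lr_gamma * step) ** (-lr_decay))
```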
  {
    "path": "examples/semi_supervised_learning/image_classification/pseudo_label.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = weak_augment\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create 
model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        x_u, labels_u = 
next(unlabeled_train_iter)\n        x_u = x_u.to(device)\n        labels_u = labels_u.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        # cross entropy loss\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # self training loss\n        y_u = model(x_u)\n        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u, y_u)\n        self_training_loss = args.trade_off_self_training * self_training_loss\n        self_training_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + self_training_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        self_training_losses.update(self_training_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # accuracy of pseudo labels\n        n_pseudo_labels = mask.sum()\n        if n_pseudo_labels > 0:\n            pseudo_labels = pseudo_labels * mask - (1 - mask)\n            n_correct = (pseudo_labels == labels_u).float().sum()\n            pseudo_label_acc = n_correct / n_pseudo_labels * 100\n            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)\n\n        # do SGD and lr scheduler step (gradients were accumulated by the backward() calls above)\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Pseudo Label for Semi-Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n
                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strongly augmented labeled data')\n    parser.add_argument('--trade-off-self-training', default=1, type=float,\n                        help='the trade-off hyper-parameter of self training loss')\n    parser.add_argument('--threshold', default=0.95, type=float,\n                        help='confidence threshold (default: 0.95)')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=40, type=int, metavar='N',\n                        help='number of total epochs to run (default: 40)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", default='pseudo_label', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
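pseudo_label.py leans on tllib's ConfidenceBasedSelfTrainingLoss: pseudo-labels come from the model's own predictions, and only samples whose confidence clears `--threshold` contribute. A simplified standalone sketch of the same idea, including the mask bookkeeping the training loop uses to score pseudo-label accuracy (the function below is an illustrative stand-in, not the library implementation):

```python
import torch
import torch.nn.functional as F


def confidence_based_self_training_loss(y, y_target, threshold=0.95):
    # Pseudo-labels are the argmax of the (detached) target predictions; only
    # samples whose max probability clears `threshold` contribute to the loss.
    confidence, pseudo_labels = F.softmax(y_target.detach(), dim=1).max(dim=1)
    mask = (confidence >= threshold).float()
    loss = (F.cross_entropy(y, pseudo_labels, reduction='none') * mask).mean()
    return loss, mask, pseudo_labels


# Toy usage mirroring the call site above, where the same logits serve as both
# prediction and target (no strong augmentation on unlabeled data).
y_u = 5.0 * torch.randn(8, 10)
labels_u = torch.randint(0, 10, (8,))
loss, mask, pseudo_labels = confidence_based_self_training_loss(y_u, y_u)

# Accuracy bookkeeping as in train(): rejected samples are overwritten with -1
# so they can never match a ground-truth class.
pseudo_labels = pseudo_labels * mask.long() - (1 - mask.long())
n_correct = (pseudo_labels == labels_u).float().sum()
print(loss.item(), n_correct / mask.sum().clamp(min=1))
```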
  {
    "path": "examples/semi_supervised_learning/image_classification/pseudo_label.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/pseudo_label/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/pets_4_labels_per_class\n\n# 
======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/pseudo_label/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone 
checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/pseudo_label_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/caltech_4_labels_per_class\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/requirements.txt",
    "content": "timm"
  },
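timm is the single extra dependency; it supplies the RandAugment implementation behind the `--auto-augment` policy string 'rand-m10-n2-mstd2' (magnitude 10, two ops per image, magnitude std 2). A hedged sketch of building such a transform with timm directly; the repo's utils.get_train_transform handles this internally, and the `hparams` fill values below are illustrative:

```python
from PIL import Image
from timm.data.auto_augment import rand_augment_transform

# 'rand-m10-n2-mstd2' = RandAugment with magnitude 10, 2 ops, magnitude std 2.
tfm = rand_augment_transform('rand-m10-n2-mstd2', hparams={'img_mean': (124, 116, 104)})

img = Image.new('RGB', (224, 224), color=(128, 128, 128))
print(tfm)            # shows the sampled op pool
print(tfm(img).size)  # the transform keeps the PIL image size
```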
  {
    "path": "examples/semi_supervised_learning/image_classification/self_tuning.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.self_tuning import Classifier, SelfTuning\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    train_transform = MultipleApply([strong_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('train_transform: ', train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, train_transform,\n                          val_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    num_classes = labeled_train_dataset.num_classes\n\n    backbone_q = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier_q = Classifier(backbone_q, num_classes, projection_dim=args.projection_dim,\n                              bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n               
               finetune=args.finetune).to(device)\n    print(classifier_q)\n\n    backbone_k = utils.get_model(args.arch)\n    classifier_k = Classifier(backbone_k, num_classes, projection_dim=args.projection_dim,\n                              bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer).to(device)\n\n    selftuning = SelfTuning(classifier_q, classifier_k, num_classes, K=args.K, m=args.m, T=args.T).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier_q.get_parameters(args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier_q.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier_q, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, selftuning, optimizer, epoch, args)\n\n        # update lr\n        lr_scheduler.step()\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier_q, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier_q.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, selftuning: SelfTuning,\n          optimizer: SGD, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    pgc_losses_labeled = AverageMeter('Pgc Loss (Labeled Data)', ':3.2f')\n    pgc_losses_unlabeled = AverageMeter('Pgc Loss (Unlabeled Data)', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, pgc_losses_labeled, pgc_losses_unlabeled, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # define loss functions\n    criterion_kl = nn.KLDivLoss(reduction='batchmean').to(device)\n\n    # switch to train mode\n    selftuning.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (l_q, l_k), labels_l = next(labeled_train_iter)\n        (u_q, u_k), _ = next(unlabeled_train_iter)\n\n        l_q, l_k = l_q.to(device), l_k.to(device)\n        u_q, u_k = u_q.to(device), u_k.to(device)\n        labels_l = labels_l.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        pgc_logits_labeled, pgc_labels_labeled, y_l = selftuning(l_q, l_k, 
labels_l)\n        # cross entropy loss\n        cls_loss = F.cross_entropy(y_l, labels_l)\n\n        # pgc loss on labeled samples\n        pgc_loss_labeled = criterion_kl(pgc_logits_labeled, pgc_labels_labeled)\n        (cls_loss + pgc_loss_labeled).backward()\n\n        # pgc loss on unlabeled samples\n        _, y_pred = selftuning.encoder_q(u_q)\n        _, pseudo_labels = torch.max(y_pred, dim=1)\n        pgc_logits_unlabeled, pgc_labels_unlabeled, _ = selftuning(u_q, u_k, pseudo_labels)\n        pgc_loss_unlabeled = criterion_kl(pgc_logits_unlabeled, pgc_labels_unlabeled)\n        pgc_loss_unlabeled.backward()\n\n        # do SGD step (gradients were accumulated by the backward() calls above)\n        optimizer.step()\n\n        # measure accuracy and record loss\n        cls_losses.update(cls_loss.item(), batch_size)\n        pgc_losses_labeled.update(pgc_loss_labeled.item(), batch_size)\n        pgc_losses_unlabeled.update(pgc_loss_unlabeled.item(), batch_size)\n        loss = cls_loss + pgc_loss_labeled + pgc_loss_unlabeled\n        losses.update(loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Self-Tuning for Semi-Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--projection-dim', default=1024, type=int,\n                        help='dimension of projection head')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n
    # training parameters\n    parser.add_argument('--T', default=0.07, type=float,\n                        help=\"temperature (default: 0.07)\")\n    parser.add_argument('--K', default=32, type=int,\n                        help=\"queue size (default: 32)\")\n    parser.add_argument('--m', default=0.999, type=float,\n                        help=\"momentum coefficient (default: 0.999)\")\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-gamma', default=0.1, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--milestones', default=[12, 24, 36, 48], type=int, nargs='+',\n                        help='epochs to decay learning rate')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run (default: 60)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training')\n    parser.add_argument(\"--log\", default='self_tuning', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
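self_tuning.py trains two classifiers: gradients flow only through the query network classifier_q, while classifier_k trails it as a MoCo-style exponential moving average with momentum `--m` and feeds the per-class feature queues of size `--K`. A minimal sketch of that momentum update; the tllib SelfTuning module performs this internally, so the helper below is purely illustrative:

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def momentum_update(encoder_q: nn.Module, encoder_k: nn.Module, m: float = 0.999):
    # Key parameters follow the query parameters slowly: k <- m * k + (1 - m) * q.
    for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        p_k.mul_(m).add_(p_q, alpha=1.0 - m)


# Toy usage: the key encoder starts as a copy of the query encoder, then trails it.
encoder_q = nn.Linear(4, 2)
encoder_k = copy.deepcopy(encoder_q)
encoder_q.weight.data.add_(1.0)      # pretend an SGD step moved the query encoder
momentum_update(encoder_q, encoder_k, m=0.999)
print((encoder_q.weight - encoder_k.weight).abs().max())
```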
  {
    "path": "examples/semi_supervised_learning/image_classification/self_tuning.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/self_tuning/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 --finetune --seed 0 --log logs/self_tuning/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.01 
--finetune --seed 0 --log logs/self_tuning/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.003 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.01 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log 
logs/self_tuning_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python self_tuning.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/caltech_4_labels_per_class\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/uda.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.self_training.uda import StrongWeakConsistencyLoss\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                             norm_mean=args.norm_mean, norm_std=args.norm_std)\n    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,\n                                               auto_augment=args.auto_augment,\n                                               norm_mean=args.norm_mean, norm_std=args.norm_std)\n    labeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])\n    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)\n    print('labeled_train_transform: ', labeled_train_transform)\n    print('unlabeled_train_transform: ', unlabeled_train_transform)\n    print('val_transform:', val_transform)\n    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \\\n        utils.get_dataset(args.data,\n                          args.num_samples_per_class,\n                          args.root, labeled_train_transform,\n                          val_transform,\n                          unlabeled_train_transform=unlabeled_train_transform,\n                          seed=args.seed)\n    print(\"labeled_dataset_size: \", len(labeled_train_dataset))\n    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))\n    print(\"val_dataset_size: \", len(val_dataset))\n\n    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                      num_workers=args.workers, drop_last=True)\n    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,\n                                        num_workers=args.workers, drop_last=True)\n    labeled_train_iter = ForeverDataIterator(labeled_train_loader)\n    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 
num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)\n    num_classes = labeled_train_dataset.num_classes\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,\n                                       finetune=args.finetune).to(device)\n    print(classifier)\n\n    # define optimizer and lr scheduler\n    if args.lr_scheduler == 'exp':\n        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)\n        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))\n    else:\n        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,\n                        nesterov=True)\n        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    best_avg = 0.0\n    for epoch in range(args.epochs):\n        # print lr\n        print(lr_scheduler.get_lr())\n\n        # train for one epoch\n        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)\n\n        # evaluate on validation set\n        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n        best_avg = max(avg, best_avg)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    print('best_avg = {:3.1f}'.format(best_avg))\n    logger.close()\n\n\ndef train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,\n          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    con_losses = AverageMeter('Con Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, con_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    consistency_criterion = StrongWeakConsistencyLoss(args.threshold, args.T).to(device)\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        (x_u, x_u_strong), _ = next(unlabeled_train_iter)\n        x_u = x_u.to(device)\n        
x_u_strong = x_u_strong.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # clear grad\n        optimizer.zero_grad()\n\n        # compute output\n        # cross entropy loss\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n        cls_loss.backward()\n\n        # consistency loss\n        with torch.no_grad():\n            y_u = model(x_u)\n        y_u_strong = model(x_u_strong)\n        con_loss = args.trade_off_con * consistency_criterion(y_u_strong, y_u)\n        con_loss.backward()\n\n        # measure accuracy and record loss\n        loss = cls_loss + con_loss\n        losses.update(loss.item(), batch_size)\n        cls_losses.update(cls_loss.item(), batch_size)\n        con_losses.update(con_loss.item(), batch_size)\n\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # compute gradient and do SGD step\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='UDA for Semi Supervised Learning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA',\n                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))\n    parser.add_argument('--num-samples-per-class', default=4, type=int,\n                        help='number of labeled samples per class')\n    parser.add_argument('--train-resizing', default='default', type=str)\n    parser.add_argument('--val-resizing', default='default', type=str)\n    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',\n                        help='normalization mean')\n    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',\n                        help='normalization std')\n    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,\n                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),\n                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')\n    parser.add_argument('--bottleneck-dim', default=1024, type=int,\n                        help='dimension of bottleneck')\n    parser.add_argument('--no-pool', action='store_true', default=False,\n                        help='no pool layer after the feature extractor')\n    parser.add_argument('--pretrained-backbone', default=None, type=str,\n                        help=\"pretrained checkpoint of the backbone \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--finetune', action='store_true', default=False,\n                        help='whether to use 10x smaller lr for backbone')\n    # training parameters\n    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,\n                        help='the trade-off hyper-parameter of cls loss on strong 
augmented labeled data')\n    parser.add_argument('--trade-off-con', default=1, type=float,\n                        help='the trade-off hyper-parameter of consistency loss')\n    parser.add_argument('--threshold', default=0.7, type=float,\n                        help='confidence threshold')\n    parser.add_argument('--T', default=0.85, type=float,\n                        help='temperature')\n    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',\n                        help='mini-batch size (default: 32)')\n    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',\n                        help='initial learning rate')\n    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],\n                        help='learning rate decay strategy')\n    parser.add_argument('--lr-gamma', default=0.0004, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay', default=0.75, type=float,\n                        help='parameter for lr scheduler')\n    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',\n                        help='weight decay (default:5e-4)')\n    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',\n                        help='number of data loading workers (default: 4)')\n    parser.add_argument('--epochs', default=60, type=int, metavar='N',\n                        help='number of total epochs to run (default: 60)')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='number of iterations per epoch (default: 500)')\n    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',\n                        help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training ')\n    parser.add_argument(\"--log\", default='uda', type=str,\n                        help=\"where to save logs, checkpoints and debugging images\")\n    parser.add_argument(\"--phase\", default='train', type=str, choices=['train', 'test'],\n                        help=\"when phase is 'test', only test the model\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/uda.sh",
    "content": "#!/usr/bin/env bash\n\n# ImageNet Supervised Pretrain (ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python uda.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/uda/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python uda.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python uda.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python uda.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/uda/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python uda.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python uda.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.003 
--finetune --threshold 0.7 --seed 0 --log logs/uda/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python uda.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/caltech_4_labels_per_class\n\n# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)\n# ======================================================================================================================\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python uda.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/food101_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 10\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cifar10_4_labels_per_class\n\n# ======================================================================================================================\n# CIFAR 100\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \\\n  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cifar100_4_labels_per_class\n\n# ======================================================================================================================\n# CUB 200\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cub200_4_labels_per_class\n\n# ======================================================================================================================\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python uda.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/aircraft_4_labels_per_class\n\n# ======================================================================================================================\n# StanfordCars\nCUDA_VISIBLE_DEVICES=0 python uda.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/car_4_labels_per_class\n\n# ======================================================================================================================\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python uda.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \\\n  
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/sun_4_labels_per_class\n\n# ======================================================================================================================\n# DTD\nCUDA_VISIBLE_DEVICES=0 python uda.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/dtd_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Pets\nCUDA_VISIBLE_DEVICES=0 python uda.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/pets_4_labels_per_class\n\n# ======================================================================================================================\n# Oxford Flowers\nCUDA_VISIBLE_DEVICES=0 python uda.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/uda_moco_pretrain/flowers_4_labels_per_class\n\n# ======================================================================================================================\n# Caltech 101\nCUDA_VISIBLE_DEVICES=0 python uda.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \\\n  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \\\n  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/caltech_4_labels_per_class\n"
  },
  {
    "path": "examples/semi_supervised_learning/image_classification/utils.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport math\nimport sys\nimport time\nfrom PIL import Image\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom torch.utils.data.dataset import Subset, ConcatDataset\nimport torchvision.transforms as T\nimport timm\nfrom timm.data.auto_augment import auto_augment_transform, rand_augment_transform\n\nsys.path.append('../../..')\nfrom tllib.modules.classifier import Classifier\nimport tllib.vision.datasets as datasets\nimport tllib.vision.models as models\nfrom tllib.vision.transforms import ResizeImage\nfrom tllib.utils.metric import accuracy, ConfusionMatrix\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\ndef get_model_names():\n    return sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    ) + timm.list_models()\n\n\ndef get_model(model_name, pretrained=True, pretrained_checkpoint=None):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=pretrained)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=pretrained)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n        except AttributeError:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n    if pretrained_checkpoint:\n        print(\"=> loading pre-trained model from '{}'\".format(pretrained_checkpoint))\n        pretrained_dict = torch.load(pretrained_checkpoint)\n        backbone.load_state_dict(pretrained_dict, strict=False)\n    return backbone\n\n\ndef get_dataset_names():\n    return sorted(\n        name for name in datasets.__dict__\n        if not name.startswith(\"__\") and callable(datasets.__dict__[name])\n    )\n\n\ndef get_dataset(dataset_name, num_samples_per_class, root, labeled_train_transform, val_transform,\n                unlabeled_train_transform=None, seed=0):\n    if unlabeled_train_transform is None:\n        unlabeled_train_transform = labeled_train_transform\n\n    if dataset_name == 'OxfordFlowers102':\n        dataset = datasets.__dict__[dataset_name]\n        base_dataset = dataset(root=root, split='train', transform=labeled_train_transform, download=True)\n        # create labeled and unlabeled splits\n        labeled_idxes, unlabeled_idxes = x_u_split(num_samples_per_class, base_dataset.num_classes,\n                                                   base_dataset.targets, seed=seed)\n        # labeled subset\n        labeled_train_dataset = Subset(base_dataset, labeled_idxes)\n        labeled_train_dataset.num_classes = base_dataset.num_classes\n        # unlabeled subset\n        base_dataset = dataset(root=root, split='train', transform=unlabeled_train_transform, download=True)\n        unlabeled_train_dataset = ConcatDataset([\n            Subset(base_dataset, unlabeled_idxes),\n            dataset(root=root, split='validation', download=True, transform=unlabeled_train_transform)\n        ])\n        val_dataset = dataset(root=root, split='test', download=True, transform=val_transform)\n    else:\n        dataset = datasets.__dict__[dataset_name]\n        base_dataset = dataset(root=root, split='train', 
transform=labeled_train_transform, download=True)\n        # create labeled and unlabeled splits\n        labeled_idxes, unlabeled_idxes = x_u_split(num_samples_per_class, base_dataset.num_classes,\n                                                   base_dataset.targets, seed=seed)\n        # labeled subset\n        labeled_train_dataset = Subset(base_dataset, labeled_idxes)\n        labeled_train_dataset.num_classes = base_dataset.num_classes\n        # unlabeled subset\n        base_dataset = dataset(root=root, split='train', transform=unlabeled_train_transform, download=True)\n        unlabeled_train_dataset = Subset(base_dataset, unlabeled_idxes)\n        val_dataset = dataset(root=root, split='test', download=True, transform=val_transform)\n    return labeled_train_dataset, unlabeled_train_dataset, val_dataset\n\n\ndef x_u_split(num_samples_per_class, num_classes, labels, seed):\n    \"\"\"\n    Construct labeled and unlabeled subsets, where the labeled subset is class balanced. Note that the resulting\n    subsets are **deterministic** with the same random seed.\n    \"\"\"\n    labels = np.array(labels)\n    assert num_samples_per_class * num_classes <= len(labels)\n    random_state = np.random.RandomState(seed)\n\n    # labeled subset\n    labeled_idxes = []\n    for i in range(num_classes):\n        ith_class_idxes = np.where(labels == i)[0]\n        ith_class_idxes = random_state.choice(ith_class_idxes, num_samples_per_class, False)\n        labeled_idxes.extend(ith_class_idxes)\n\n    # unlabeled subset\n    unlabeled_idxes = [i for i in range(len(labels)) if i not in labeled_idxes]\n    return labeled_idxes, unlabeled_idxes\n\n\ndef get_train_transform(resizing='default', random_horizontal_flip=True, auto_augment=None,\n                        norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):\n    if resizing == 'default':\n        transform = T.RandomResizedCrop(224, scale=(0.2, 1.))\n    elif resizing == 'cifar':\n        transform = T.Compose([\n            T.RandomCrop(size=32, padding=4, padding_mode='reflect'),\n            ResizeImage(224)\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        transforms.append(T.RandomHorizontalFlip())\n    if auto_augment:\n        aa_params = dict(\n            translate_const=int(224 * 0.45),\n            img_mean=tuple([min(255, round(255 * x)) for x in norm_mean]),\n            interpolation=Image.BILINEAR\n        )\n        if auto_augment.startswith('rand'):\n            transforms.append(rand_augment_transform(auto_augment, aa_params))\n        else:\n            transforms.append(auto_augment_transform(auto_augment, aa_params))\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=norm_mean, std=norm_std)\n    ])\n    return T.Compose(transforms)\n\n\ndef get_val_transform(resizing='default', norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):\n    if resizing == 'default':\n        transform = T.Compose([\n            ResizeImage(256),\n            T.CenterCrop(224),\n        ])\n    elif resizing == 'cifar':\n        transform = ResizeImage(224)\n    else:\n        raise NotImplementedError(resizing)\n    return T.Compose([\n        transform,\n        T.ToTensor(),\n        T.Normalize(mean=norm_mean, std=norm_std)\n    ])\n\n\ndef convert_dataset(dataset):\n    \"\"\"\n    Converts a dataset which returns (img, label) pairs into one that returns (index, (img, label)) pairs,\n    so the index of each sample is also available.\n    \"\"\"\n\n  
  class DatasetWrapper:\n\n        def __init__(self):\n            self.dataset = dataset\n\n        def __getitem__(self, index):\n            return index, self.dataset[index]\n\n        def __len__(self):\n            return len(self.dataset)\n\n    return DatasetWrapper()\n\n\nclass ImageClassifier(Classifier):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):\n        bottleneck = nn.Sequential(\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(0.5)\n        )\n        bottleneck[0].weight.data.normal_(0, 0.005)\n        bottleneck[0].bias.data.fill_(0.1)\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n\n    def forward(self, x: torch.Tensor):\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        predictions = self.head(f)\n        return predictions\n\n\ndef get_cosine_scheduler_with_warmup(optimizer, T_max, num_cycles=7. / 16., num_warmup_steps=0,\n                                     last_epoch=-1):\n    \"\"\"\n    Cosine learning rate scheduler from `FixMatch: Simplifying Semi-Supervised Learning with\n    Consistency and Confidence (NIPS 2020) <https://arxiv.org/abs/2001.07685>`_.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        T_max (int): Maximum number of iterations.\n        num_cycles (float): A scalar that controls the shape of cosine function. Default: 7/16.\n        num_warmup_steps (int): Number of iterations to warm up. Default: 0.\n        last_epoch (int): The index of last epoch. Default: -1.\n\n    \"\"\"\n\n    def _lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            _lr = float(current_step) / float(max(1, num_warmup_steps))\n        else:\n            num_cos_steps = float(current_step - num_warmup_steps)\n            num_cos_steps = num_cos_steps / float(max(1, T_max - num_warmup_steps))\n            _lr = max(0.0, math.cos(math.pi * num_cycles * num_cos_steps))\n        return _lr\n\n    return LambdaLR(optimizer, _lr_lambda, last_epoch)\n\n\ndef validate(val_loader, model, args, device, num_classes):\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    top1 = AverageMeter('Acc@1', ':6.2f')\n    top5 = AverageMeter('Acc@5', ':6.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, top1, top5],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n    confmat = ConfusionMatrix(num_classes)\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            loss = F.cross_entropy(output, target)\n\n            # measure accuracy and record loss\n            acc1, acc5 = accuracy(output, target, topk=(1, 5))\n            confmat.update(target, output.argmax(1))\n            losses.update(loss.item(), images.size(0))\n            top1.update(acc1.item(), images.size(0))\n            top5.update(acc5.item(), images.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n\n        print(' * Acc@1 {top1.avg:.3f} 
Acc@5 {top5.avg:.3f}'\n              .format(top1=top1, top5=top5))\n        acc_global, acc_per_class, iu = confmat.compute()\n        mean_cls_acc = acc_per_class.mean().item() * 100\n        print(' * Mean Cls {:.3f}'.format(mean_cls_acc))\n\n    return top1.avg, mean_cls_acc\n\n\ndef empirical_risk_minimization(labeled_train_iter, model, optimizer, lr_scheduler, epoch, args, device):\n    batch_time = AverageMeter('Time', ':2.2f')\n    data_time = AverageMeter('Data', ':2.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    batch_size = args.batch_size\n    for i in range(args.iters_per_epoch):\n        (x_l, x_l_strong), labels_l = next(labeled_train_iter)\n        x_l = x_l.to(device)\n        x_l_strong = x_l_strong.to(device)\n        labels_l = labels_l.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_l = model(x_l)\n        y_l_strong = model(x_l_strong)\n        # cross entropy loss on both weak augmented and strong augmented samples\n        loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)\n\n        # measure accuracy and record loss\n        losses.update(loss.item(), batch_size)\n        cls_acc = accuracy(y_l, labels_l)[0]\n        cls_accs.update(cls_acc.item(), batch_size)\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n        lr_scheduler.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/README.md",
    "content": "# Task Adaptation for Image Classification\n\n## Installation\n\nExample scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You\nneed to install timm to use PyTorch-Image-Models.\n\n```\npip install timm\n```\n\n## Dataset\n\nThe following datasets can be downloaded automatically:\n\n- [CUB200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)\n- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)\n- [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)\n- [StanfordDogs](http://vision.stanford.edu/aditya86/ImageNetDogs/)\n- [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/)\n- [OxfordFlowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)\n- [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html)\n- [PatchCamelyon](https://patchcamelyon.grand-challenge.org/)\n- [EuroSAT](https://github.com/phelber/eurosat)\n\nYou need to prepare the following datasets manually if you want to use them:\n\n- [Retinopathy](https://www.kaggle.com/c/diabetic-retinopathy-detection/data)\n- [Resisc45](http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html)\n\nand prepare them following the documentation for [Retinopathy](/common/vision/datasets/retinopathy.py)\nand [Resisc45](/common/vision/datasets/resisc45.py).\n\n## Supported Methods\n\nSupported methods include:\n\n- [Explicit inductive bias for transfer learning with convolutional networks\n  (L2-SP, ICML 2018)](https://arxiv.org/abs/1802.01483)\n- [Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning (BSS, NIPS 2019)](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf)\n- [DEep Learning Transfer using Feature Map with Attention for convolutional networks (DELTA, ICLR 2019)](https://openreview.net/pdf?id=rkgbwsAcYm)\n- [Co-Tuning for Transfer Learning (Co-Tuning, NIPS 2020)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf)\n- [Stochastic Normalization (StochNorm, NIPS 2020)](https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf)\n- [Learning Without Forgetting (LWF, ECCV 2016)](https://arxiv.org/abs/1606.09282)\n- [Bi-tuning of Pre-trained Representations (Bi-Tuning)](https://arxiv.org/abs/2011.06182)\n\n## Experiment and Results\n\nWe follow the common practice in the community as described in\n[Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning (BSS, NIPS 2019)](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf).\n\nTraining iterations and data augmentations are kept the same for different task-adaptation methods for a fair\ncomparison.\n\nHyper-parameters of each method are selected by the performance on target validation data.\n\n### Fine-tune the supervised pre-trained model\n\nThe shell files give the scripts to reproduce the supervised pretrained benchmarks with specified hyper-parameters. 
For\nexample, if you want to run vanilla fine-tuning on CUB200, use the following script:\n\n```shell\n# Fine-tune ResNet50 on CUB200.\n# Assume you have put the dataset under the path `data/cub200`,\n# or let the script download it automatically from the Internet to this path\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/erm/cub200_100\n```\n\n#### Vision Benchmark on ResNet-50\n\n|                 | Food101 | CIFAR10 | CIFAR100 | SUN397 | Stanford Cars | FGVC Aircraft  | DTD   | Oxford-IIIT Pets | Caltech-101    | Oxford 102 Flowers | average |\n|-----------------|---------|---------|----------|--------|---------------|----------------|-------|------------------|----------------|--------------------|---------|\n| Accuracy metric | top-1   | top-1   | top-1    | top-1  | top-1         | mean per-class | top-1 | mean per-class   | mean per-class | mean per-class     |         |\n| Baseline        | 85.1    | 96.6    | 84.1     | 63.7   | 87.8          | 80.1           | 70.8  | 93.2             | 91.1           | 93.0               | 84.6    |\n| LWF             | 83.9    | 96.5    | 83.6     | 64.1   | 87.4          | 82.2           | 72.2  | 94.0             | 89.8           | 92.9               | 84.7    |\n| DELTA           | 83.8    | 95.9    | 83.7     | 64.5   | 88.1          | 82.3           | 72.2  | 94.2             | 90.1           | 93.1               | 84.8    |\n| BSS             | 85.0    | 96.6    | 84.2     | 63.5   | 88.4          | 81.8           | 70.2  | 93.3             | 91.6           | 92.7               | 84.7    |\n| StochNorm       | 85.0    | 96.8    | 83.9     | 63.0   | 87.7          | 81.5           | 71.3  | 93.6             | 90.5           | 92.9               | 84.6    |\n| Bi-Tuning       | 85.7    | 97.1    | 84.3     | 64.2   | 90.3          | 84.8           | 70.6  | 93.5             | 91.5           | 94.5               | 85.7    |\n\n#### CUB-200-2011 on ResNet-50 (Supervised Pre-trained)\n\n| CUB200    | 15%  | 30%  | 50%  | 100% | Avg  |\n|-----------|------|------|------|------|------|\n| ERM       | 51.2 | 64.6 | 74.6 | 81.8 | 68.1 |\n| LWF       | 56.7 | 66.8 | 73.4 | 81.5 | 69.6 |\n| BSS       | 53.4 | 66.7 | 76.0 | 82.0 | 69.5 |\n| DELTA     | 54.8 | 67.3 | 76.3 | 82.3 | 70.2 |\n| StochNorm | 54.8 | 66.8 | 75.8 | 82.2 | 69.9 |\n| Co-Tuning | 57.6 | 70.1 | 77.3 | 82.5 | 71.9 |\n| Bi-Tuning | 55.8 | 69.3 | 77.2 | 83.1 | 71.4 |\n\n#### Stanford Cars on ResNet-50 (Supervised Pre-trained)\n\n| Stanford Cars | 15%  | 30%  | 50%  | 100% | Avg  |\n|---------------|------|------|------|------|------|\n| ERM           | 41.1 | 65.9 | 78.4 | 87.8 | 68.3 |\n| LWF           | 44.9 | 67.0 | 77.6 | 87.5 | 69.3 |\n| BSS           | 43.3 | 67.6 | 79.6 | 88.0 | 69.6 |\n| DELTA         | 45.0 | 68.4 | 79.6 | 88.4 | 70.4 |\n| StochNorm     | 44.4 | 68.1 | 79.3 | 87.9 | 69.9 |\n| Co-Tuning     | 49.0 | 70.6 | 81.9 | 89.1 | 72.7 |\n| Bi-Tuning     | 48.3 | 72.8 | 83.3 | 90.2 | 73.7 |\n\n#### FGVC Aircraft on ResNet-50 (Supervised Pre-trained)\n\n| FGVC Aircraft | 15%  | 30%  | 50%  | 100% | Avg  |\n|---------------|------|------|------|------|------|\n| ERM           | 41.6 | 57.8 | 68.7 | 80.2 | 62.1 |\n| LWF           | 44.1 | 60.6 | 68.7 | 82.4 | 64.0 |\n| BSS           | 43.6 | 59.5 | 69.6 | 81.2 | 63.5 |\n| DELTA         | 44.4 | 61.9 | 71.4 | 82.7 | 65.1 |\n| StochNorm     | 44.3 | 60.6 | 70.1 | 81.5 | 64.1 |\n| Co-Tuning     | 45.9 | 61.2 | 71.3 | 82.2 | 65.2 |\n| Bi-Tuning     | 47.2 | 64.3 | 73.7 | 84.3 | 67.4 |
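\n\nThe tables above report accuracy at 15%, 30%, 50% and 100% sampling rates of the training set. Since the commands differ only in the `-sr` value, a sweep over sampling rates can be written compactly; a minimal sketch, reusing the CUB200 flags from the example above:\n\n```shell\n# sketch: sweep the sampling rates reported in the tables above\nfor sr in 15 30 50 100; do\n  CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr ${sr} --seed 0 --finetune --log logs/erm/cub200_${sr}\ndone\n```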
\n\n### Fine-tune the unsupervised pre-trained model\n\nTake MoCo as an example.\n\n1. Download MoCo pretrained checkpoints from https://github.com/facebookresearch/moco\n2. Convert the MoCo checkpoints to the standard PyTorch format\n\n```shell\nmkdir checkpoints\npython convert_moco_to_pretrained.py checkpoints/moco_v1_200ep_pretrain.pth.tar checkpoints/moco_v1_200ep_backbone.pth checkpoints/moco_v1_200ep_fc.pth\n```\n\n3. Start training\n\n```shell\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n```
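\n\nAfter training, the latest and best checkpoints are saved under the log directory, and bi_tuning.py supports `--phase test` for evaluation only. A minimal sketch for evaluating the best checkpoint afterwards:\n\n```shell\n# sketch: reload the best checkpoint from the log directory above and only run evaluation\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 \\\n  --log logs/moco_pretrain_bi_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth --phase test\n```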
\n\n#### CUB-200-2011 on ResNet-50 (MoCo Pre-trained)\n\n| CUB200    | 15%  | 30%  | 50%  | 100% | Avg  |\n|-----------|------|------|------|------|------|\n| ERM       | 28.0 | 48.2 | 62.7 | 75.6 | 53.6 |\n| LWF       | 28.8 | 50.1 | 62.8 | 76.2 | 54.5 |\n| BSS       | 30.9 | 50.3 | 63.7 | 75.8 | 55.2 |\n| DELTA     | 27.9 | 51.4 | 65.9 | 74.6 | 55.0 |\n| StochNorm | 20.8 | 44.9 | 60.1 | 72.8 | 49.7 |\n| Co-Tuning | 29.1 | 50.1 | 63.8 | 75.9 | 54.7 |\n| Bi-Tuning | 32.4 | 51.8 | 65.7 | 76.1 | 56.5 |\n\n#### Stanford Cars on ResNet-50 (MoCo Pre-trained)\n\n| Stanford Cars | 15%  | 30%  | 50%  | 100% | Avg  |\n|---------------|------|------|------|------|------|\n| ERM           | 42.5 | 71.2 | 83.0 | 90.1 | 71.7 |\n| LWF           | 44.2 | 71.7 | 82.9 | 90.5 | 72.3 |\n| BSS           | 45.0 | 71.5 | 83.8 | 90.1 | 72.6 |\n| DELTA         | 45.9 | 72.9 | 82.5 | 88.9 | 72.6 |\n| StochNorm     | 40.3 | 66.2 | 78.0 | 86.2 | 67.7 |\n| Co-Tuning     | 44.2 | 72.6 | 83.3 | 90.3 | 72.6 |\n| Bi-Tuning     | 45.6 | 72.8 | 83.2 | 90.8 | 73.1 |\n\n#### FGVC Aircraft on ResNet-50 (MoCo Pre-trained)\n\n| FGVC Aircraft | 15%  | 30%  | 50%  | 100% | Avg  |\n|---------------|------|------|------|------|------|\n| ERM           | 45.8 | 67.6 | 78.8 | 88.0 | 70.1 |\n| LWF           | 48.5 | 68.5 | 78.0 | 87.9 | 70.7 |\n| BSS           | 47.7 | 69.1 | 79.2 | 88.0 | 71.0 |\n| DELTA         | -    | -    | -    | -    | -    |\n| StochNorm     | 45.4 | 68.8 | 76.7 | 86.1 | 69.3 |\n| Co-Tuning     | 48.2 | 68.5 | 78.7 | 87.3 | 70.7 |\n| Bi-Tuning     | 46.4 | 69.6 | 79.4 | 87.9 | 70.8 |\n\n## Citation\n\nIf you use these methods in your research, please consider citing:\n\n```\n@inproceedings{LWF,\n    author    = {Zhizhong Li and\n                Derek Hoiem},\n    title     = {Learning without Forgetting},\n    booktitle={ECCV},\n    year      = {2016},\n}\n\n@inproceedings{L2SP,\n    title={Explicit inductive bias for transfer learning with convolutional networks},\n    author={Xuhong, LI and Grandvalet, Yves and Davoine, Franck},\n    booktitle={ICML},\n    year={2018},\n}\n\n@inproceedings{BSS,\n    title={Catastrophic forgetting meets negative transfer: Batch spectral shrinkage for safe transfer learning},\n    author={Chen, Xinyang and Wang, Sinan and Fu, Bo and Long, Mingsheng and Wang, Jianmin},\n    booktitle={NeurIPS},\n    year={2019}\n}\n\n@inproceedings{DELTA,\n    title={Delta: Deep learning transfer using feature map with attention for convolutional networks},\n    author={Li, Xingjian and Xiong, Haoyi and Wang, Hanchao and Rao, Yuxuan and Liu, Liping and Chen, Zeyu and Huan, Jun},\n    booktitle={ICLR},\n    year={2019}\n}\n\n@inproceedings{StochNorm,\n    title={Stochastic Normalization},\n    author={Kou, Zhi and You, Kaichao and Long, Mingsheng and Wang, Jianmin},\n    booktitle={NeurIPS},\n    year={2020}\n}\n\n@inproceedings{CoTuning,\n    title={Co-Tuning for Transfer Learning},\n    author={You, Kaichao and Kou, Zhi and Long, Mingsheng and Wang, Jianmin},\n    booktitle={NeurIPS},\n    year={2020}\n}\n\n@article{BiTuning,\n    title={Bi-tuning of Pre-trained Representations},\n    author={Zhong, Jincheng and Wang, Ximei and Kou, Zhi and Wang, Jianmin and Long, Mingsheng},\n    journal={arXiv preprint arXiv:2011.06182},\n    year={2020}\n}\n```\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/bi_tuning.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\n\nimport utils\nfrom tllib.vision.transforms import MultipleApply\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.regularization.bi_tuning import Classifier, BiTuning\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_augmentation = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    train_transform = MultipleApply([train_augmentation, train_augmentation])\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,\n                                                                val_transform, args.sample_rate,\n                                                                args.num_samples_per_classes)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,\n                              num_workers=args.workers, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"training dataset size: {} test dataset size: {}\".format(len(train_dataset), len(val_dataset)))\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone_q = utils.get_model(args.arch, args.pretrained)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier_q = Classifier(backbone_q, num_classes, pool_layer=pool_layer, projection_dim=args.projection_dim,\n                              finetune=args.finetune)\n    if args.pretrained_fc:\n        print(\"=> loading pre-trained fc from '{}'\".format(args.pretrained_fc))\n        pretrained_fc_dict = torch.load(args.pretrained_fc)\n        classifier_q.projector.load_state_dict(pretrained_fc_dict, strict=False)\n    classifier_q = classifier_q.to(device)\n    backbone_k = utils.get_model(args.arch)\n    classifier_k = Classifier(backbone_k, num_classes, pool_layer=pool_layer).to(device)\n\n    bituning = BiTuning(classifier_q, classifier_k, num_classes, K=args.K, m=args.m, T=args.T)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier_q.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier_q.load_state_dict(checkpoint)\n        acc1 = utils.validate(val_loader, classifier_q, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, bituning, optimizer, epoch, args)\n        lr_scheduler.step()\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier_q, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier_q.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, bituning: BiTuning, optimizer: SGD, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    cls_losses = AverageMeter('Cls Loss', ':3.2f')\n    contrastive_losses = AverageMeter('Contrastive Loss', ':3.2f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_losses, contrastive_losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    classifier_criterion = torch.nn.CrossEntropyLoss().to(device)\n    contrastive_criterion = torch.nn.KLDivLoss(reduction='batchmean').to(device)\n\n    # switch to train mode\n    bituning.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels = next(train_iter)\n        img_q, img_k = x[0], x[1]\n\n        img_q = img_q.to(device)\n        img_k = img_k.to(device)\n        labels = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y, logits_z, logits_y, bituning_labels = bituning(img_q, img_k, labels)\n        cls_loss = classifier_criterion(y, labels)\n        contrastive_loss_z = contrastive_criterion(logits_z, bituning_labels)\n        contrastive_loss_y = contrastive_criterion(logits_y, bituning_labels)\n        contrastive_loss = (contrastive_loss_z + contrastive_loss_y)\n        loss = cls_loss + contrastive_loss * args.trade_off\n\n        # measure accuracy and record loss\n        losses.update(loss.item(), x[0].size(0))\n        cls_losses.update(cls_loss.item(), x[0].size(0))\n        contrastive_losses.update(contrastive_loss.item(), x[0].size(0))\n\n        cls_acc = accuracy(y, labels)[0]\n        cls_accs.update(cls_acc.item(), x[0].size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Bi-tuning for Finetuning')\n    # dataset parameters\n    
parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per class.')\n    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')\n    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--color-jitter', action='store_true', help='apply color jitter during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor. Used in models such as ViT.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use 10x smaller lr for backbone')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--pretrained-fc', default=None,\n                        help=\"pretrained checkpoint of the fc. \"\n                             \"(default: None)\")\n    parser.add_argument('--T', default=0.07, type=float, help=\"temperature. (default: 0.07)\")\n    parser.add_argument('--K', type=int, default=40, help=\"queue size. (default: 40)\")\n    parser.add_argument('--m', type=float, default=0.999, help=\"momentum coefficient. (default: 0.999)\")\n    parser.add_argument('--projection-dim', type=int, default=128,\n                        help=\"dimension of the projection head. (default: 128)\")\n    parser.add_argument('--trade-off', type=float, default=1.0, help=\"trade-off parameter. 
(default: 1.0)\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 48)')\n    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam'])\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='bi_tuning',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/bi_tuning.sh",
    "content": "#!/usr/bin/env bash\n# Supervised Pretraining\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --finetune --seed 0 --log logs/bi_tuning/cub200_100\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 50 --finetune --seed 0 --log logs/bi_tuning/cub200_50\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 30 --finetune --seed 0 --log logs/bi_tuning/cub200_30\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 15 --finetune --seed 0 --log logs/bi_tuning/cub200_15\n\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 100 --finetune --seed 0 --log logs/bi_tuning/car_100\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 50 --finetune --seed 0 --log logs/bi_tuning/car_50\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 30 --finetune --seed 0 --log logs/bi_tuning/car_30\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 15 --finetune --seed 0 --log logs/bi_tuning/car_15\n\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/bi_tuning/aircraft_100\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/bi_tuning/aircraft_50\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/bi_tuning/aircraft_30\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/bi_tuning/aircraft_15\n\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/bi_tuning/cifar10/1e-2 --lr 1e-2\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/bi_tuning/cifar100/1e-2 --lr 1e-2\n# Flowers\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/bi_tuning/oxford_flowers102/1e-2 --lr 1e-2\n# Pets\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/bi_tuning/oxford_pet/1e-2 --lr 1e-2\n# DTD\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/dtd -d DTD --seed 0 --finetune --log logs/bi_tuning/dtd/1e-2 --lr 1e-2\n# caltech101\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/bi_tuning/caltech101/lr_1e-3 --lr 1e-3\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/bi_tuning/sun397/lr_1e-2 --lr 1e-2\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/food-101 -d Food101 --seed 0 --finetune --log logs/bi_tuning/food-101/lr_1e-2 --lr 1e-2\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/bi_tuning/stanford_cars/lr_1e-2 --lr 1e-2\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/bi_tuning/aircraft/lr_1e-2 --lr 1e-2\n\n# MoCo (Unsupervised Pretraining)\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d 
CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bi_tuning/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/bss.py",
    "content": "\"\"\"\n@author: Yifei Ji, Junguang Jiang\n@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.regularization.bss import BatchSpectralShrinkage\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,\n                                                                val_transform, args.sample_rate,\n                                                                args.num_samples_per_classes)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,\n                              num_workers=args.workers, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"training dataset size: {} test dataset size: {}\".format(len(train_dataset), len(val_dataset)))\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, args.pretrained)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)\n    bss_module = BatchSpectralShrinkage(k=args.k)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1 = utils.validate(val_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_last_lr())\n        # train for one epoch\n        train(train_iter, 
classifier, bss_module, optimizer, epoch, args)\n        lr_scheduler.step()\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model: Classifier, bss_module, optimizer: SGD,\n          epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels = next(train_iter)\n\n        x = x.to(device)\n        label = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y, f = model(x)\n        cls_loss = F.cross_entropy(y, label)\n        bss_loss = bss_module(f)\n        loss = cls_loss + args.trade_off * bss_loss\n\n        cls_acc = accuracy(y, label)[0]\n\n        losses.update(loss.item(), x.size(0))\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='BSS for Finetuning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per class.')\n    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')\n    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor. 
Used in models such as ViT.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('-k', '--k', default=1, type=int,\n                        metavar='N',\n                        help='hyper-parameter for BSS loss')\n    parser.add_argument('--trade-off', default=0.001, type=float,\n                        metavar='P', help='trade-off weight of BSS loss')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 48)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='bss',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/bss.sh",
    "content": "#!/usr/bin/env bash\n# Supervised Pretraining\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/bss/cub200_100\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/bss/cub200_50\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/bss/cub200_30\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/bss/cub200_15\n\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/bss/car_100\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/bss/car_50\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/bss/car_30\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/bss/car_15\n\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/bss/aircraft_100\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/bss/aircraft_50\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/bss/aircraft_30\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/bss/aircraft_15\n\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/bss/cifar10/1e-2 --lr 1e-2\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/bss/cifar100/1e-2 --lr 1e-2\n# Flowers\nCUDA_VISIBLE_DEVICES=0 python bss.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/bss/oxford_flowers102/1e-2 --lr 1e-2\n# Pets\nCUDA_VISIBLE_DEVICES=0 python bss.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/bss/oxford_pet/1e-2 --lr 1e-2\n# DTD\nCUDA_VISIBLE_DEVICES=0 python bss.py data/dtd -d DTD --seed 0 --finetune --log logs/bss/dtd/1e-2 --lr 1e-2\n# caltech101\nCUDA_VISIBLE_DEVICES=0 python bss.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/bss/caltech101/lr_1e-3 --lr 1e-3\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python bss.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/bss/sun397/lr_1e-2 --lr 1e-2\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python bss.py data/food-101 -d Food101 --seed 0 --finetune --log logs/bss/food-101/lr_1e-2 --lr 1e-2\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/bss/stanford_cars/lr_1e-2 --lr 1e-2\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/bss/aircraft/lr_1e-2 --lr 1e-2\n\n# MoCo (Unsupervised Pretraining)\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 
2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_bss/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/co_tuning.py",
    "content": "\"\"\"\n@author: Yifei Ji, Junguang Jiang\n@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\nimport os\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\nfrom torch.utils.data import Subset\n\nimport utils\nfrom tllib.regularization.co_tuning import CoTuningLoss, Relationship, Classifier\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.data import ForeverDataIterator\nimport tllib.vision.datasets as datasets\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef get_dataset(dataset_name, root, train_transform, val_transform, sample_rate=100, num_samples_per_classes=None):\n    dataset = datasets.__dict__[dataset_name]\n    if sample_rate < 100:\n        train_dataset = dataset(root=root, split='train', sample_rate=sample_rate, download=True,\n                                transform=train_transform)\n        determin_train_dataset = dataset(root=root, split='train', sample_rate=sample_rate, download=True,\n                                         transform=val_transform)\n        test_dataset = dataset(root=root, split='test', sample_rate=100, download=True, transform=val_transform)\n        num_classes = train_dataset.num_classes\n    else:\n        train_dataset = dataset(root=root, split='train', transform=train_transform)\n        determin_train_dataset = dataset(root=root, split='train', transform=val_transform)\n        test_dataset = dataset(root=root, split='test', transform=val_transform)\n        num_classes = train_dataset.num_classes\n        if num_samples_per_classes is not None:\n            samples = list(range(len(train_dataset)))\n            random.shuffle(samples)\n            samples_len = min(num_samples_per_classes * num_classes, len(train_dataset))\n            train_dataset = Subset(train_dataset, samples[:samples_len])\n            determin_train_dataset = Subset(determin_train_dataset, samples[:samples_len])\n    return train_dataset, determin_train_dataset, test_dataset, num_classes\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! 
'\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, determin_train_dataset, val_dataset, num_classes = get_dataset(args.data, args.root, train_transform,\n                                                                                  val_transform, args.sample_rate,\n                                                                                  args.num_samples_per_classes)\n    print(\"training dataset size: {} test dataset size: {}\".format(len(train_dataset), len(val_dataset)))\n\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,\n                              num_workers=args.workers, drop_last=True)\n    determin_train_loader = DataLoader(determin_train_dataset, batch_size=args.batch_size,\n                                       shuffle=False, num_workers=args.workers, drop_last=False)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, args.pretrained)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, head_source=backbone.copy_head(), pool_layer=pool_layer,\n                            finetune=args.finetune).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1 = utils.validate(val_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # build relationship between source classes and target classes\n    source_classifier = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.head_source)\n    relationship = Relationship(determin_train_loader, source_classifier, device,\n                                os.path.join(logger.root, args.relationship))\n    co_tuning_loss = CoTuningLoss()\n\n    # start training\n    best_acc1 = 0.0\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_iter, classifier, optimizer, epoch, relationship, co_tuning_loss, args)\n        lr_scheduler.step()\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model: 
Classifier, optimizer: SGD,\n          epoch: int, relationship, co_tuning_loss, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, label_t = next(train_iter)\n\n        x = x.to(device)\n        label_s = torch.from_numpy(relationship[label_t]).to(device).float()\n        label_t = label_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, y_t = model(x)\n        tgt_loss = F.cross_entropy(y_t, label_t)\n        src_loss = co_tuning_loss(y_s, label_s)\n        loss = tgt_loss + args.trade_off * src_loss\n\n        # measure accuracy and record loss\n        losses.update(loss.item(), x.size(0))\n        cls_acc = accuracy(y_t, label_t)[0]\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Co-Tuning for Finetuning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per class.')\n    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')\n    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor. 
Used in models such as ViT.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--trade-off', default=2.3, type=float,\n                        metavar='P', help='the trade-off hyper-parameter for co-tuning loss')\n    parser.add_argument(\"--relationship\", type=str, default='relationship.npy',\n                        help=\"Where to save relationship file.\")\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 48)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='cotuning',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/co_tuning.sh",
    "content": "#!/usr/bin/env bash\n# Supervised Pretraining\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/co_tuning/cub200_100\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/co_tuning/cub200_50\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/co_tuning/cub200_30\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/co_tuning/cub200_15\n\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/co_tuning/car_100\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/co_tuning/car_50\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/co_tuning/car_30\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/co_tuning/car_15\n\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/co_tuning/aircraft_100\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/co_tuning/aircraft_50\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/co_tuning/aircraft_30\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/co_tuning/aircraft_15\n\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/co_tuning/cifar10/1e-2 --lr 1e-2\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/co_tuning/cifar100/1e-2 --lr 1e-2\n# Flowers\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/co_tuning/oxford_flowers102/1e-2 --lr 1e-2\n# Pets\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/co_tuning/oxford_pet/1e-2 --lr 1e-2\n# DTD\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/dtd -d DTD --seed 0 --finetune --log logs/co_tuning/dtd/1e-2 --lr 1e-2\n# caltech101\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/co_tuning/caltech101/lr_1e-3 --lr 1e-3\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/co_tuning/sun397/lr_1e-2 --lr 1e-2\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/food-101 -d Food101 --seed 0 --finetune --log logs/co_tuning/food-101/lr_1e-2 --lr 1e-2\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/co_tuning/stanford_cars/lr_1e-2 --lr 1e-2\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/co_tuning/aircraft/lr_1e-2 --lr 1e-2\n\n\n# MoCo (Unsupervised Pretraining)\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d 
CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Aircraft\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_co_tuning/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/convert_moco_to_pretrained.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport sys\nimport torch\n\nif __name__ == \"__main__\":\n    # usage: python convert_moco_to_pretrained.py <moco_checkpoint> <backbone_output> <fc_output>\n    input_path = sys.argv[1]\n\n    obj = torch.load(input_path, map_location=\"cpu\")\n    obj = obj[\"state_dict\"]\n\n    # keep only the query-encoder weights and split them into backbone and fc parts\n    newmodel = {}\n    fc = {}\n    for k, v in obj.items():\n        if not k.startswith(\"module.encoder_q.\"):\n            continue\n        k = k.replace(\"module.encoder_q.\", \"\")\n        if k.startswith(\"fc\"):\n            print(k)\n            fc[k] = v\n        else:\n            newmodel[k] = v\n\n    with open(sys.argv[2], \"wb\") as f:\n        torch.save(newmodel, f)\n\n    with open(sys.argv[3], \"wb\") as f:\n        torch.save(fc, f)\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/delta.py",
    "content": "\"\"\"\n@author: Yifei Ji, Junguang Jiang\n@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com\n\"\"\"\nimport math\nimport os\nimport random\nimport time\nimport warnings\nimport sys\nimport argparse\nimport shutil\n\nimport numpy as np\nfrom tqdm import tqdm\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.regularization.delta import *\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,\n                                                                val_transform, args.sample_rate,\n                                                                args.num_samples_per_classes)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,\n                              num_workers=args.workers, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"training dataset size: {} test dataset size: {}\".format(len(train_dataset), len(val_dataset)))\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, args.pretrained)\n    backbone_source = utils.get_model(args.arch, args.pretrained)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)\n    source_classifier = Classifier(backbone_source, num_classes=backbone_source.fc.out_features,\n                                   head=backbone_source.copy_head(), pool_layer=pool_layer).to(device)\n    for param in source_classifier.parameters():\n        param.requires_grad = False\n    source_classifier.eval()\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = 
torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1 = utils.validate(val_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # create intermediate layer getter\n    if args.arch in ('resnet50', 'resnet101'):\n        return_layers = ['backbone.layer1.2.conv3', 'backbone.layer2.3.conv3', 'backbone.layer3.5.conv3',\n                         'backbone.layer4.2.conv3']\n    else:\n        raise NotImplementedError(args.arch)\n    source_getter = IntermediateLayerGetter(source_classifier, return_layers=return_layers)\n    target_getter = IntermediateLayerGetter(classifier, return_layers=return_layers)\n\n    # get regularization\n    if args.regularization_type == 'l2_sp':\n        backbone_regularization = SPRegularization(source_classifier.backbone, classifier.backbone)\n    elif args.regularization_type == 'feature_map':\n        backbone_regularization = BehavioralRegularization()\n    elif args.regularization_type == 'attention_feature_map':\n        attention_file = os.path.join(logger.root, args.attention_file)\n        if not os.path.exists(attention_file):\n            attention = calculate_channel_attention(train_dataset, return_layers, num_classes, args)\n            torch.save(attention, attention_file)\n        else:\n            print(\"Loading channel attention from\", attention_file)\n            attention = torch.load(attention_file)\n            attention = [a.to(device) for a in attention]\n        backbone_regularization = AttentionBehavioralRegularization(attention)\n    else:\n        raise NotImplementedError(args.regularization_type)\n\n    head_regularization = L2Regularization(nn.ModuleList([classifier.head, classifier.bottleneck]))\n\n    # start training\n    best_acc1 = 0.0\n\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_last_lr())\n        # train for one epoch\n        train(train_iter, classifier, backbone_regularization, head_regularization, target_getter, source_getter,\n              optimizer, epoch, args)\n        lr_scheduler.step()\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    logger.close()\n\n\ndef calculate_channel_attention(dataset, return_layers, num_classes, args):\n    backbone = utils.get_model(args.arch)\n    classifier = Classifier(backbone, num_classes).to(device)\n    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    data_loader = DataLoader(dataset, batch_size=args.attention_batch_size, shuffle=True,\n                             num_workers=args.workers, drop_last=False)\n    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=math.exp(\n        math.log(0.1) / args.attention_lr_decay_epochs))\n    criterion = nn.CrossEntropyLoss()\n\n    channel_weights = []\n    for layer_id, name in enumerate(return_layers):\n        layer = 
get_attribute(classifier, name)\n        layer_channel_weight = [0] * layer.out_channels\n        channel_weights.append(layer_channel_weight)\n\n    # train the classifier with the backbone frozen\n    classifier.train()\n    classifier.backbone.requires_grad_(False)\n    print(\"Pretrain a classifier to calculate channel attention.\")\n    for epoch in range(args.attention_epochs):\n        losses = AverageMeter('Loss', ':3.2f')\n        cls_accs = AverageMeter('Cls Acc', ':3.1f')\n        progress = ProgressMeter(\n            len(data_loader),\n            [losses, cls_accs],\n            prefix=\"Epoch: [{}]\".format(epoch))\n\n        for i, data in enumerate(data_loader):\n            inputs, labels = data\n            inputs = inputs.to(device)\n            labels = labels.to(device)\n            outputs, _ = classifier(inputs)\n            loss = criterion(outputs, labels)\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            cls_acc = accuracy(outputs, labels)[0]\n\n            losses.update(loss.item(), inputs.size(0))\n            cls_accs.update(cls_acc.item(), inputs.size(0))\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n        lr_scheduler.step()\n\n    # calculate the channel attention\n    print('Calculating channel attention.')\n    classifier.eval()\n    if args.attention_iteration_limit > 0:\n        total_iteration = min(len(data_loader), args.attention_iteration_limit)\n    else:\n        total_iteration = len(data_loader)\n\n    progress = ProgressMeter(\n        total_iteration,\n        [],\n        prefix=\"Iteration: \")\n\n    for i, data in enumerate(data_loader):\n        if i >= total_iteration:\n            break\n        inputs, labels = data\n        inputs = inputs.to(device)\n        labels = labels.to(device)\n        outputs = classifier(inputs)\n        loss_0 = criterion(outputs, labels)\n        progress.display(i)\n        for layer_id, name in enumerate(tqdm(return_layers)):\n            layer = get_attribute(classifier, name)\n            for j in range(layer.out_channels):\n                tmp = classifier.state_dict()[name + '.weight'][j,].clone()\n                classifier.state_dict()[name + '.weight'][j,] = 0.0\n                outputs = classifier(inputs)\n                loss_1 = criterion(outputs, labels)\n                difference = loss_1 - loss_0\n                difference = difference.detach().cpu().numpy().item()\n                history_value = channel_weights[layer_id][j]\n                channel_weights[layer_id][j] = 1.0 * (i * history_value + difference) / (i + 1)\n                classifier.state_dict()[name + '.weight'][j,] = tmp\n\n    channel_attention = []\n    for weight in channel_weights:\n        weight = np.array(weight)\n        weight = (weight - np.mean(weight)) / np.std(weight)\n        weight = torch.from_numpy(weight).float().to(device)\n        channel_attention.append(F.softmax(weight / 5, dim=0).detach())\n    return channel_attention\n\n\ndef train(train_iter: ForeverDataIterator, model: Classifier, backbone_regularization: nn.Module,\n          head_regularization: nn.Module,\n          target_getter: IntermediateLayerGetter,\n          source_getter: IntermediateLayerGetter,\n          optimizer: SGD, epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    losses_reg_head = AverageMeter('Loss (reg, head)', 
':3.2f')\n    losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, losses_reg_head, losses_reg_backbone, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels = next(train_iter)\n        x = x.to(device)\n        label = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        intermediate_output_s, output_s = source_getter(x)\n        intermediate_output_t, output_t = target_getter(x)\n        y, f = output_t\n\n        # measure accuracy and record loss\n        cls_acc = accuracy(y, label)[0]\n        cls_loss = F.cross_entropy(y, label)\n        if args.regularization_type in ('feature_map', 'attention_feature_map'):\n            loss_reg_backbone = backbone_regularization(intermediate_output_s, intermediate_output_t)\n        else:\n            loss_reg_backbone = backbone_regularization()\n        loss_reg_head = head_regularization()\n        loss = cls_loss + args.trade_off_backbone * loss_reg_backbone + args.trade_off_head * loss_reg_head\n\n        losses_reg_backbone.update(loss_reg_backbone.item() * args.trade_off_backbone, x.size(0))\n        losses_reg_head.update(loss_reg_head.item() * args.trade_off_head, x.size(0))\n        losses.update(loss.item(), x.size(0))\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Delta for Finetuning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per class.')\n    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')\n    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', 
action='store_true',\n                        help='no pool layer after the feature extractor. Used in models such as ViT.')\n    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    parser.add_argument('--regularization-type', choices=['l2_sp', 'feature_map', 'attention_feature_map'],\n                        default='attention_feature_map')\n    parser.add_argument('--trade-off-backbone', default=0.01, type=float,\n                        help='trade-off for backbone regularization')\n    parser.add_argument('--trade-off-head', default=0.01, type=float,\n                        help='trade-off for head regularization')\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 48)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0., type=float,\n                        metavar='W', help='weight decay (default: 0.)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. 
')\n    parser.add_argument(\"--log\", type=str, default='delta',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\"\n                             \"When phase is 'analysis', only analysis the model.\")\n\n    # parameters for calculating channel attention\n    parser.add_argument(\"--attention-file\", type=str, default='channel_attention.pt',\n                        help=\"Where to save and load channel attention file.\")\n    parser.add_argument('--attention-batch-size', default=32, type=int,\n                        metavar='N',\n                        help='mini-batch size for calculating channel attention (default: 32)')\n    parser.add_argument('--attention-epochs', default=10, type=int, metavar='N',\n                        help='number of epochs to train for training before calculating channel weight')\n    parser.add_argument('--attention-lr-decay-epochs', default=6, type=int, metavar='N',\n                        help='epochs to decay lr for training before calculating channel weight')\n    parser.add_argument('--attention-iteration-limit', default=10, type=int, metavar='N',\n                        help='iteration limits for calculating channel attention, -1 means no limits')\n\n    args = parser.parse_args()\n    main(args)\n"
  },
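For context on how delta.py uses the attention vectors: `calculate_channel_attention` scores each output channel of the regularized layers by how much the training loss grows when that channel's filter is zeroed, averages the score over batches, standardizes it per layer, and applies a temperature-5 softmax. During fine-tuning, `train` then feeds the intermediate feature maps of the frozen source backbone and the adapting target backbone to `backbone_regularization`. Below is a minimal sketch of such an `attention_feature_map` regularizer; tllib ships the actual implementation, and the reduction and scaling here are illustrative assumptions (the getters' outputs are assumed to be ordered name-to-feature-map dicts).

```python
import torch.nn as nn


class AttentionFeatureMapRegularization(nn.Module):
    """Illustrative sketch of DELTA's attention-weighted feature-map loss."""

    def __init__(self, channel_attention):
        super().__init__()
        # one weight vector per regularized layer, one entry per channel,
        # as returned by calculate_channel_attention above
        self.channel_attention = channel_attention

    def forward(self, layer_outputs_source, layer_outputs_target):
        loss = 0.
        for attention, fm_src, fm_tgt in zip(self.channel_attention,
                                             layer_outputs_source.values(),
                                             layer_outputs_target.values()):
            b, c, h, w = fm_src.shape
            # mean squared distance to the frozen source activations, per channel
            distance = ((fm_tgt - fm_src.detach()) ** 2).reshape(b, c, -1).sum(dim=2) / (h * w)
            # channels the source model relied on are penalized more strongly
            loss = loss + (attention * distance).sum() / b
        return loss
```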
  {
    "path": "examples/task_adaptation/image_classification/delta.sh",
    "content": "#!/usr/bin/env bash\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/delta/cub200_100\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/delta/cub200_50\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/delta/cub200_30\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/delta/cub200_15\n\n# Stanford Cars\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/delta/car_100\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/delta/car_50\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/delta/car_30\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/delta/car_15\n\n# Aircrafts\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/delta/aircraft_100\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/delta/aircraft_50\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/delta/aircraft_30\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/delta/aircraft_15\n\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/delta/cifar10/1e-2 --lr 1e-2\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/delta/cifar100/1e-2 --lr 1e-2\n# Flowers\nCUDA_VISIBLE_DEVICES=0 python delta.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/delta/oxford_flowers102/1e-2 --lr 1e-2\n# Pets\nCUDA_VISIBLE_DEVICES=0 python delta.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/delta/oxford_pet/1e-2 --lr 1e-2\n# DTD\nCUDA_VISIBLE_DEVICES=0 python delta.py data/dtd -d DTD --seed 0 --finetune --log logs/delta/dtd/1e-2 --lr 1e-2\n# caltech101\nCUDA_VISIBLE_DEVICES=0 python delta.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/delta/caltech101/lr_1e-3 --lr 1e-3\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python delta.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/delta/sun397/lr_1e-2 --lr 1e-2\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python delta.py data/food-101 -d Food101 --seed 0 --finetune --log logs/delta/food-101/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/delta/stanford_cars/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/delta/aircraft/lr_1e-2 --lr 1e-2\n\n# MoCo (Unsupervised Pretraining)\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cub200_100 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cub200_50 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python 
delta.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cub200_30 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cub200_15 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\n\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cars_100 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cars_50 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cars_30 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/cars_15 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\n\n# Aircrafts\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/aircraft_100 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/aircraft_50 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/aircraft_30 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\nCUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_delta/aircraft_15 --pretrained checkpoints/moco_v1_200ep_pretrain.pth\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/erm.py",
    "content": "\"\"\"\n@author: Yifei Ji, Junguang Jiang\n@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,\n                                                                val_transform, args.sample_rate,\n                                                                args.num_samples_per_classes)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,\n                              num_workers=args.workers, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"training dataset size: {} test dataset size: {}\".format(len(train_dataset), len(val_dataset)))\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, args.pretrained)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1 = utils.validate(val_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    for epoch in range(args.epochs):\n        logger.set_epoch(epoch)\n        print(lr_scheduler.get_lr())\n        # train for one epoch\n        train(train_iter, classifier, optimizer, epoch, args)\n        
lr_scheduler.step()\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,\n          epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels = next(train_iter)\n\n        x = x.to(device)\n        label = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y, f = model(x)\n        cls_loss = F.cross_entropy(y, label)\n        loss = cls_loss\n\n        # measure accuracy and record loss\n        losses.update(loss.item(), x.size(0))\n        cls_acc = accuracy(y, label)[0]\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Baseline for Finetuning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per class.')\n    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')\n    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor. Used in models such as ViT.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 48)')\n    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam'])\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='baseline',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
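A note on the `--finetune` flag used throughout these scripts: it only changes the optimizer's parameter groups, giving the pretrained backbone a learning rate 10x smaller than the newly initialized head (this is what `classifier.get_parameters(args.lr)` builds). A rough sketch of equivalent groups is below; the `backbone`/`head` attribute names and the 0.1 factor are assumptions taken from the flag's help text, and tllib's `Classifier.get_parameters` is the authoritative version.

```python
from torch.optim import SGD


def get_finetune_parameters(classifier, base_lr, finetune=True):
    # pretrained backbone: 10x smaller lr when --finetune is set;
    # freshly initialized head: the full base lr
    backbone_lr = 0.1 * base_lr if finetune else base_lr
    return [
        {'params': classifier.backbone.parameters(), 'lr': backbone_lr},
        {'params': classifier.head.parameters(), 'lr': base_lr},
    ]


# roughly what erm.py's optimizer construction amounts to:
# optimizer = SGD(get_finetune_parameters(classifier, 0.01), lr=0.01,
#                 momentum=0.9, weight_decay=5e-4, nesterov=True)
```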
  {
    "path": "examples/task_adaptation/image_classification/erm.sh",
    "content": "#!/usr/bin/env bash\n# Supervised Pretraining\n# CUB-200-2011\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/erm/cub200_100\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/erm/cub200_50\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/erm/cub200_30\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/erm/cub200_15\n\n# Standford Cars\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/erm/car_100\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/erm/car_50\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/erm/car_30\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/erm/car_15\n\n# Aircrafts\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/erm/aircraft_100\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/erm/aircraft_50\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/erm/aircraft_30\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/erm/aircraft_15\n\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/erm/cifar10/1e-2 --lr 1e-2\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/erm/cifar100/1e-2 --lr 1e-2\n# Flowers\nCUDA_VISIBLE_DEVICES=0 python erm.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/erm/oxford_flowers102/1e-2 --lr 1e-2\n# Pets\nCUDA_VISIBLE_DEVICES=0 python erm.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/erm/oxford_pet/1e-2 --lr 1e-2\n# DTD\nCUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --seed 0 --finetune --log logs/erm/dtd/1e-2 --lr 1e-2\n# caltech101\nCUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/erm/caltech101/lr_1e-3 --lr 1e-3\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/erm/sun397/lr_1e-2 --lr 1e-2\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python erm.py data/food-101 -d Food101 --seed 0 --finetune --log logs/erm/food-101/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/erm/stanford_cars/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/erm/aircraft/lr_1e-2 --lr 1e-2\n\n# MoCo (Unsupervised Pretraining)\n#CUB-200-2011\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 
--finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Standford Cars\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Aircrafts\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_erm/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/lwf.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader, TensorDataset\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.regularization.lwf import collect_pretrain_labels, Classifier\nfrom tllib.regularization.knowledge_distillation import KnowledgeDistillationLoss\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.logger import CompleteLogger\nfrom tllib.utils.data import ForeverDataIterator, CombineDataset\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,\n                                                                val_transform, args.sample_rate,\n                                                                args.num_samples_per_classes)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,\n                              num_workers=args.workers, drop_last=False)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"training dataset size: {} test dataset size: {}\".format(len(train_dataset), len(val_dataset)))\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, args.pretrained)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, head_source=backbone.copy_head(), pool_layer=pool_layer,\n                            finetune=args.finetune).to(device)\n    kd = KnowledgeDistillationLoss(args.T)\n\n    source_classifier = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.head_source)\n    pretrain_labels = collect_pretrain_labels(train_loader, source_classifier, device)\n    train_dataset = CombineDataset([train_dataset, TensorDataset(pretrain_labels)])\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,\n                              shuffle=True, num_workers=args.workers, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)\n    lr_scheduler = 
torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1 = utils.validate(val_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    for epoch in range(args.epochs):\n        # train for one epoch\n        train(train_iter, classifier, kd, optimizer, epoch, args)\n        lr_scheduler.step()\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model: Classifier, kd: KnowledgeDistillationLoss, optimizer: SGD,\n          epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    losses_kd = AverageMeter('Loss (KD)', ':5.4f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, losses_kd, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        # label_s holds the outputs recorded from the pretrained source head\n        # (see collect_pretrain_labels above); label_t is the target-task label\n        x, label_t, label_s = next(train_iter)\n\n        x = x.to(device)\n        label_s = label_s.to(device)\n        label_t = label_t.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y_s, y_t = model(x)\n        tgt_loss = F.cross_entropy(y_t, label_t)\n        src_loss = kd(y_s, label_s)\n        loss = tgt_loss + args.trade_off * src_loss\n\n        # measure accuracy and record loss\n        losses.update(tgt_loss.item(), x.size(0))\n        losses_kd.update(src_loss.item(), x.size(0))\n        cls_acc = accuracy(y_t, label_t)[0]\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='LWF (Learning without Forgetting) for Finetuning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per class.')\n    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')\n    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor. Used in models such as ViT.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--trade-off', default=4, type=float,\n                        metavar='P', help='weight of the knowledge distillation loss (default: 4)')\n    parser.add_argument(\"-T\", type=float, default=3,\n                        help=\"temperature for knowledge distillation\")\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 48)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='lwf',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
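lwf.py records the outputs of the frozen source head once (`collect_pretrain_labels`), then distills them back into the adapting model with `KnowledgeDistillationLoss(args.T)` while the new head learns the target task. Below is a minimal sketch of such a temperature-scaled distillation loss; tllib's class may differ in reduction or scaling details, so treat this as an illustration of the mechanism rather than the library's exact code.

```python
import torch.nn as nn
import torch.nn.functional as F


class DistillationLossSketch(nn.Module):
    """Hinton-style knowledge distillation with temperature T."""

    def __init__(self, T: float = 3.0):
        super().__init__()
        self.T = T

    def forward(self, y_s, teacher_logits):
        # soften both distributions with the temperature
        log_p = F.log_softmax(y_s / self.T, dim=-1)
        q = F.softmax(teacher_logits / self.T, dim=-1)
        # KL(teacher || student); the T^2 factor keeps gradient magnitudes
        # comparable across different temperatures
        return F.kl_div(log_p, q, reduction='batchmean') * self.T ** 2
```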
  {
    "path": "examples/task_adaptation/image_classification/lwf.sh",
    "content": "#!/usr/bin/env bash\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/lwf/cub200_100 --lr 0.01\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/lwf/cub200_50 --lr 0.001\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/lwf/cub200_30 --lr 0.001\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/lwf/cub200_15 --lr 0.001\n\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/lwf/car_100 --lr 0.01\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/lwf/car_50 --lr 0.01\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/lwf/car_30 --lr 0.01\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/lwf/car_15 --lr 0.01\n\n# Aircrafts\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/lwf/aircraft_100 --lr 0.001\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/lwf/aircraft_50 --lr 0.001\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/lwf/aircraft_30 --lr 0.001\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/lwf/aircraft_15 --lr 0.001\n\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/lwf/cifar10/1e-2 --lr 1e-2\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/lwf/cifar100/1e-2 --lr 1e-2\n# Flowers\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/lwf/oxford_flowers102/1e-2 --lr 1e-2\n# Pets\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/lwf/oxford_pet/1e-2 --lr 1e-2\n# DTD\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/dtd -d DTD --seed 0 --finetune --log logs/lwf/dtd/1e-2 --lr 1e-2\n# caltech101\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/lwf/caltech101/lr_1e-3 --lr 1e-3\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/lwf/sun397/lr_1e-2 --lr 1e-2\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/food-101 -d Food101 --seed 0 --finetune --log logs/lwf/food-101/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/lwf/stanford_cars/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/lwf/aircraft/lr_1e-2 --lr 1e-2\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/requirements.txt",
    "content": "timm"
  },
  {
    "path": "examples/task_adaptation/image_classification/stochnorm.py",
    "content": "\"\"\"\n@author: Yifei Ji, Junguang Jiang\n@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com\n\"\"\"\nimport random\nimport time\nimport warnings\nimport argparse\nimport shutil\n\nimport torch\nimport torch.nn as nn\nimport torch.backends.cudnn as cudnn\nfrom torch.optim import SGD\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nimport utils\nfrom tllib.normalization.stochnorm import convert_model\nfrom tllib.modules.classifier import Classifier\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.utils.data import ForeverDataIterator\nfrom tllib.utils.logger import CompleteLogger\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef main(args: argparse.Namespace):\n    logger = CompleteLogger(args.log, args.phase)\n    print(args)\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n        warnings.warn('You have chosen to seed training. '\n                      'This will turn on the CUDNN deterministic setting, '\n                      'which can slow down your training considerably! '\n                      'You may see unexpected behavior when restarting '\n                      'from checkpoints.')\n\n    cudnn.benchmark = True\n\n    # Data loading code\n    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)\n    val_transform = utils.get_val_transform(args.val_resizing)\n    print(\"train_transform: \", train_transform)\n    print(\"val_transform: \", val_transform)\n\n    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,\n                                                                val_transform, args.sample_rate,\n                                                                args.num_samples_per_classes)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,\n                              num_workers=args.workers, drop_last=True)\n    train_iter = ForeverDataIterator(train_loader)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)\n    print(\"training dataset size: {} test dataset size: {}\".format(len(train_dataset), len(val_dataset)))\n\n    # create model\n    print(\"=> using pre-trained model '{}'\".format(args.arch))\n    backbone = utils.get_model(args.arch, args.pretrained)\n    pool_layer = nn.Identity() if args.no_pool else None\n    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)\n    classifier = convert_model(classifier, p=args.prob)\n\n    # define optimizer and lr scheduler\n    optimizer = SGD(classifier.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd,\n                    nesterov=True)\n    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)\n\n    # resume from the best checkpoint\n    if args.phase == 'test':\n        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')\n        classifier.load_state_dict(checkpoint)\n        acc1 = utils.validate(val_loader, classifier, args, device)\n        print(acc1)\n        return\n\n    # start training\n    best_acc1 = 0.0\n    for epoch in range(args.epochs):\n        print(lr_scheduler.get_lr())\n        # train for one 
epoch\n        train(train_iter, classifier, optimizer, epoch, args)\n        lr_scheduler.step()\n\n        # evaluate on validation set\n        acc1 = utils.validate(val_loader, classifier, args, device)\n\n        # remember best acc@1 and save checkpoint\n        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))\n        if acc1 > best_acc1:\n            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))\n        best_acc1 = max(acc1, best_acc1)\n\n    print(\"best_acc1 = {:3.1f}\".format(best_acc1))\n    logger.close()\n\n\ndef train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,\n          epoch: int, args: argparse.Namespace):\n    batch_time = AverageMeter('Time', ':4.2f')\n    data_time = AverageMeter('Data', ':3.1f')\n    losses = AverageMeter('Loss', ':3.2f')\n    cls_accs = AverageMeter('Cls Acc', ':3.1f')\n\n    progress = ProgressMeter(\n        args.iters_per_epoch,\n        [batch_time, data_time, losses, cls_accs],\n        prefix=\"Epoch: [{}]\".format(epoch))\n\n    # switch to train mode\n    model.train()\n\n    end = time.time()\n    for i in range(args.iters_per_epoch):\n        x, labels = next(train_iter)\n\n        x = x.to(device)\n        label = labels.to(device)\n\n        # measure data loading time\n        data_time.update(time.time() - end)\n\n        # compute output\n        y, f = model(x)\n\n        cls_loss = F.cross_entropy(y, label)\n        loss = cls_loss\n\n        cls_acc = accuracy(y, label)[0]\n\n        losses.update(loss.item(), x.size(0))\n        cls_accs.update(cls_acc.item(), x.size(0))\n\n        # compute gradient and do SGD step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n        if i % args.print_freq == 0:\n            progress.display(i)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='StochNorm for Finetuning')\n    # dataset parameters\n    parser.add_argument('root', metavar='DIR',\n                        help='root path of dataset')\n    parser.add_argument('-d', '--data', metavar='DATA')\n    parser.add_argument('-sr', '--sample-rate', default=100, type=int,\n                        metavar='N',\n                        help='sample rate of training dataset (default: 100)')\n    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,\n                        help='number of samples per class.')\n    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')\n    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')\n    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')\n    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')\n    # model parameters\n    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',\n                        choices=utils.get_model_names(),\n                        help='backbone architecture: ' +\n                             ' | '.join(utils.get_model_names()) +\n                             ' (default: resnet50)')\n    parser.add_argument('--no-pool', action='store_true',\n                        help='no pool layer after the feature extractor. Used in models such as ViT.')\n    parser.add_argument('--finetune', action='store_true', help='whether to use a 10x smaller lr for the backbone')\n    parser.add_argument('--prob', '--probability', default=0.5, type=float,\n                        metavar='P', help='probability for StochNorm layers (default: 0.5)')\n    parser.add_argument('--pretrained', default=None,\n                        help=\"pretrained checkpoint of the backbone. \"\n                             \"(default: None, use the ImageNet supervised pretrained backbone)\")\n    # training parameters\n    parser.add_argument('-b', '--batch-size', default=48, type=int,\n                        metavar='N',\n                        help='mini-batch size (default: 48)')\n    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,\n                        metavar='LR', help='initial learning rate', dest='lr')\n    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')\n    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')\n    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',\n                        help='momentum')\n    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,\n                        metavar='W', help='weight decay (default: 5e-4)')\n    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',\n                        help='number of data loading workers (default: 2)')\n    parser.add_argument('--epochs', default=20, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,\n                        help='Number of iterations per epoch')\n    parser.add_argument('-p', '--print-freq', default=100, type=int,\n                        metavar='N', help='print frequency (default: 100)')\n    parser.add_argument('--seed', default=None, type=int,\n                        help='seed for initializing training. ')\n    parser.add_argument(\"--log\", type=str, default='stochnorm',\n                        help=\"Where to save logs, checkpoints and debugging images.\")\n    parser.add_argument(\"--phase\", type=str, default='train', choices=['train', 'test'],\n                        help=\"When phase is 'test', only test the model.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
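The one line that distinguishes stochnorm.py from the ERM baseline is `convert_model(classifier, p=args.prob)`, which swaps every BatchNorm layer for a stochastic variant: during training each channel is normalized either with the current mini-batch statistics or with the moving-average statistics, selected at random. A rough sketch of the idea, built on `nn.BatchNorm2d`, is below; tllib's actual layers and the exact role of `p` may differ in detail.

```python
import torch
import torch.nn as nn


class StochNorm2dSketch(nn.BatchNorm2d):
    """Illustrative stochastic normalization layer; plain BatchNorm at eval time."""

    def __init__(self, num_features, p=0.5, **kwargs):
        super().__init__(num_features, **kwargs)
        self.p = p

    def forward(self, x):
        if not self.training:
            return super().forward(x)
        # branch 0: normalize with mini-batch statistics (also updates running stats)
        y_batch = super().forward(x)
        # branch 1: normalize with the running statistics instead
        y_moving = (x - self.running_mean.view(1, -1, 1, 1)) / torch.sqrt(
            self.running_var.view(1, -1, 1, 1) + self.eps)
        if self.affine:
            y_moving = y_moving * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        # pick one branch per channel, i.i.d. with probability p
        mask = (torch.rand(1, x.size(1), 1, 1, device=x.device) < self.p).float()
        return mask * y_moving + (1.0 - mask) * y_batch
```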
  {
    "path": "examples/task_adaptation/image_classification/stochnorm.sh",
    "content": "#!/usr/bin/env bash\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/stochnorm/cub200_100\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/stochnorm/cub200_50\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/stochnorm/cub200_30\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/stochnorm/cub200_15\n\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/stochnorm/car_100\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/stochnorm/car_50\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/stochnorm/car_30\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/stochnorm/car_15\n\n# Aircrafts\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/stochnorm/aircraft_100\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/stochnorm/aircraft_50\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/stochnorm/aircraft_30\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/stochnorm/aircraft_15\n\n# CIFAR10\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/stochnorm/cifar10/1e-2 --lr 1e-2\n# CIFAR100\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/stochnorm/cifar100/1e-2 --lr 1e-2\n# Flowers\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/stochnorm/oxford_flowers102/1e-2 --lr 1e-2\n# Pets\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/stochnorm/oxford_pet/1e-2 --lr 1e-2\n# DTD\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/dtd -d DTD --seed 0 --finetune --log logs/stochnorm/dtd/1e-2 --lr 1e-2\n# caltech101\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/stochnorm/caltech101/lr_1e-3 --lr 1e-3\n# SUN397\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/stochnorm/sun397/lr_1e-2 --lr 1e-2\n# Food 101\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/food-101 -d Food101 --seed 0 --finetune --log logs/stochnorm/food-101/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/stochnorm/stanford_cars/lr_1e-2 --lr 1e-2\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/stochnorm/aircraft/lr_1e-2 --lr 1e-2\n\n# MoCo (Unsupervised Pretraining)\n# CUB-200-2011\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 
0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Standford Cars\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n\n# Aircrafts\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth\nCUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \\\n  --log logs/moco_pretrain_stochnorm/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth\n"
  },
  {
    "path": "examples/task_adaptation/image_classification/utils.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport time\nfrom PIL import Image\nimport timm\nimport numpy as np\nimport random\nimport sys\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import Subset\nimport torchvision.transforms as T\nfrom torch.optim import SGD, Adam\n\nsys.path.append('../../..')\nimport tllib.vision.datasets as datasets\nimport tllib.vision.models as models\nfrom tllib.utils.metric import accuracy\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\nfrom tllib.vision.transforms import Denormalize\n\n\ndef get_model_names():\n    return sorted(\n        name for name in models.__dict__\n        if name.islower() and not name.startswith(\"__\")\n        and callable(models.__dict__[name])\n    ) + timm.list_models()\n\n\ndef get_model(model_name, pretrained_checkpoint=None):\n    if model_name in models.__dict__:\n        # load models from tllib.vision.models\n        backbone = models.__dict__[model_name](pretrained=True)\n    else:\n        # load models from pytorch-image-models\n        backbone = timm.create_model(model_name, pretrained=True)\n        try:\n            backbone.out_features = backbone.get_classifier().in_features\n            backbone.reset_classifier(0, '')\n            backbone.copy_head = backbone.get_classifier\n        except:\n            backbone.out_features = backbone.head.in_features\n            backbone.head = nn.Identity()\n            backbone.copy_head = lambda x: x.head\n    if pretrained_checkpoint:\n        print(\"=> loading pre-trained model from '{}'\".format(pretrained_checkpoint))\n        pretrained_dict = torch.load(pretrained_checkpoint)\n        backbone.load_state_dict(pretrained_dict, strict=False)\n    return backbone\n\n\ndef get_dataset(dataset_name, root, train_transform, val_transform, sample_rate=100, num_samples_per_classes=None):\n    \"\"\"\n    When sample_rate < 100,  e.g. sample_rate = 50, use 50% data to train the model.\n    Otherwise,\n        if num_samples_per_classes is not None, e.g. 
5, then randomly sample 5 * num_classes images in total and use them to train the model;\n        otherwise, keep all the data.\n    \"\"\"\n    dataset = datasets.__dict__[dataset_name]\n    if sample_rate < 100:\n        train_dataset = dataset(root=root, split='train', sample_rate=sample_rate, download=True, transform=train_transform)\n        test_dataset = dataset(root=root, split='test', sample_rate=100, download=True, transform=val_transform)\n        num_classes = train_dataset.num_classes\n    else:\n        train_dataset = dataset(root=root, split='train', download=True, transform=train_transform)\n        test_dataset = dataset(root=root, split='test', download=True, transform=val_transform)\n        num_classes = train_dataset.num_classes\n        if num_samples_per_classes is not None:\n            samples = list(range(len(train_dataset)))\n            random.shuffle(samples)\n            samples_len = min(num_samples_per_classes * num_classes, len(train_dataset))\n            print(\"Original dataset:\", len(train_dataset), \"Sampled dataset:\", samples_len, \"Ratio:\", float(samples_len) / len(train_dataset))\n            train_dataset = Subset(train_dataset, samples[:samples_len])\n    return train_dataset, test_dataset, num_classes\n\n\ndef validate(val_loader, model, args, device, visualize=None) -> float:\n    batch_time = AverageMeter('Time', ':6.3f')\n    losses = AverageMeter('Loss', ':.4e')\n    top1 = AverageMeter('Acc@1', ':6.2f')\n    progress = ProgressMeter(\n        len(val_loader),\n        [batch_time, losses, top1],\n        prefix='Test: ')\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images, target) in enumerate(val_loader):\n            images = images.to(device)\n            target = target.to(device)\n\n            # compute output\n            output = model(images)\n            loss = F.cross_entropy(output, target)\n\n            # measure accuracy and record loss\n            acc1, = accuracy(output, target, topk=(1, ))\n            losses.update(loss.item(), images.size(0))\n            top1.update(acc1.item(), images.size(0))\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % args.print_freq == 0:\n                progress.display(i)\n                if visualize is not None:\n                    visualize(images[0], \"val_{}\".format(i))\n\n        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))\n\n    return top1.avg\n\n\ndef get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False):\n    \"\"\"\n    resizing mode:\n        - default: take a random resized crop of size 224 with scale in [0.2, 1.];\n        - res.: resize the image to 224;\n        - res.|crop: resize the image to 256 and take a random crop of size 224;\n        - res.sma|crop: resize the image keeping its aspect ratio such that the\n            smaller side is 256, then take a random crop of size 224;\n        - inc.crop: \"inception crop\" from (Szegedy et al., 2015);\n        - cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.RandomResizedCrop(224, scale=(0.2, 1.))\n    elif resizing == 'res.':\n        transform = T.Resize((224, 224))\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize((256, 256)),\n            T.RandomCrop(224)\n
        ])\n    elif resizing == \"res.sma|crop\":\n        transform = T.Compose([\n            T.Resize(256),\n            T.RandomCrop(224)\n        ])\n    elif resizing == 'inc.crop':\n        transform = T.RandomResizedCrop(224)\n    elif resizing == 'cif.crop':\n        transform = T.Compose([\n            T.Resize((224, 224)),\n            T.Pad(28),\n            T.RandomCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    transforms = [transform]\n    if random_horizontal_flip:\n        transforms.append(T.RandomHorizontalFlip())\n    if random_color_jitter:\n        transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))\n    transforms.extend([\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n    return T.Compose(transforms)\n\n\ndef get_val_transform(resizing='default'):\n    \"\"\"\n    resizing mode:\n        - default: resize the image keeping its aspect ratio such that the\n            smaller side is 256, then take the center crop of size 224;\n        - res.: resize the image to 224;\n        - res.|crop: resize the image to 256 and take the center crop of size 224.\n    \"\"\"\n    if resizing == 'default':\n        transform = T.Compose([\n            T.Resize(256),\n            T.CenterCrop(224),\n        ])\n    elif resizing == 'res.':\n        transform = T.Resize((224, 224))\n    elif resizing == 'res.|crop':\n        transform = T.Compose([\n            T.Resize((256, 256)),\n            T.CenterCrop(224),\n        ])\n    else:\n        raise NotImplementedError(resizing)\n    return T.Compose([\n        transform,\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n    ])\n\n\ndef get_optimizer(optimizer_name, params, lr, wd, momentum):\n    '''\n    Args:\n        optimizer_name:\n            - SGD\n            - Adam\n        params: iterable of parameters to optimize or dicts defining parameter groups\n        lr: learning rate\n        wd: weight decay\n        momentum: momentum factor for SGD\n    '''\n    if optimizer_name == 'SGD':\n        optimizer = SGD(params=params, lr=lr, momentum=momentum, weight_decay=wd, nesterov=True)\n    elif optimizer_name == 'Adam':\n        optimizer = Adam(params=params, lr=lr, weight_decay=wd)\n    else:\n        raise NotImplementedError(optimizer_name)\n    return optimizer\n\n\ndef visualize(image, filename):\n    \"\"\"\n    Args:\n        image (tensor): 3 x H x W\n        filename: filename of the saved image\n    \"\"\"\n    image = image.detach().cpu()\n    image = Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)\n    image = image.numpy().transpose((1, 2, 0)) * 255\n    Image.fromarray(np.uint8(image)).save(filename)\n"
  },
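For orientation, here is a minimal sketch of how these helpers compose into a fine-tuning setup. It assumes the working directory is `examples/task_adaptation/image_classification` (so `utils` is importable); the dataset, split ratio, and hyper-parameters are illustrative rather than taken from the scripts above:

```python
import torch.nn as nn
from torch.utils.data import DataLoader

# helpers from the utils.py above; paths and hyper-parameters are placeholders
from utils import get_model, get_dataset, get_train_transform, get_val_transform, get_optimizer

train_transform = get_train_transform(resizing='default', random_horizontal_flip=True)
val_transform = get_val_transform(resizing='default')

# e.g. CUB200 with 50% of the labeled training data, mirroring `-sr 50` in the shell scripts
train_dataset, test_dataset, num_classes = get_dataset(
    'CUB200', 'data/cub200', train_transform, val_transform, sample_rate=50)

backbone = get_model('resnet50')  # ImageNet-pretrained backbone exposing `out_features`
model = nn.Sequential(backbone, nn.Linear(backbone.out_features, num_classes))

optimizer = get_optimizer('SGD', model.parameters(), lr=0.1, wd=5e-4, momentum=0.9)
train_loader = DataLoader(train_dataset, batch_size=48, shuffle=True, num_workers=2)
```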
  {
    "path": "requirements.txt",
    "content": "torch>=1.7.0\ntorchvision>=0.5.0\nnumpy\nprettytable\ntqdm\nscikit-learn\nwebcolors\nmatplotlib\nopencv-python\nnumba\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\nimport re\nfrom os import path\n\nhere = path.abspath(path.dirname(__file__))\n\n# Get the version string\nwith open(path.join(here, 'tllib', '__init__.py')) as f:\n    version = re.search(r'__version__ = \\'(.*?)\\'', f.read()).group(1)\n\n# Get all runtime requirements\nREQUIRES = []\nwith open('requirements.txt') as f:\n    for line in f:\n        line, _, _ = line.partition('#')\n        line = line.strip()\n        REQUIRES.append(line)\n\nif __name__ == '__main__':\n    setup(\n        name=\"tllib\", # Replace with your own username\n        version=version,\n        author=\"THUML\",\n        author_email=\"JiangJunguang1123@outlook.com\",\n        keywords=\"domain adaptation, task adaptation, domain generalization, \"\n                 \"transfer learning, deep learning, pytorch\",\n        description=\"A Transfer Learning Library for Domain Adaptation, Task Adaptation, and Domain Generalization\",\n        long_description=open('README.md', encoding='utf8').read(),\n        long_description_content_type=\"text/markdown\",\n        url=\"https://github.com/thuml/Transfer-Learning-Library\",\n        packages=find_packages(exclude=['docs', 'examples']),\n        classifiers=[\n            # How mature is this project? Common values are\n            #   3 - Alpha\n            #   4 - Beta\n            #   5 - Production/Stable\n            'Development Status :: 3 - Alpha',\n            # Indicate who your project is intended for\n            'Intended Audience :: Science/Research',\n            'Topic :: Scientific/Engineering :: Artificial Intelligence',\n            'Topic :: Software Development :: Libraries :: Python Modules',\n            # Pick your license as you wish (should match \"license\" above)\n            'License :: OSI Approved :: MIT License',\n            # Specify the Python versions you support here. In particular, ensure\n            # that you indicate whether you support Python 2, Python 3 or both.\n            'Programming Language :: Python :: 3.6',\n            'Programming Language :: Python :: 3.7',\n            'Programming Language :: Python :: 3.8',\n        ],\n        python_requires='>=3.6',\n        install_requires=REQUIRES,\n        extras_require={\n            'dev': [\n                'Sphinx',\n                'sphinx_rtd_theme',\n            ]\n        },\n    )\n"
  },
  {
    "path": "tllib/__init__.py",
    "content": "from . import alignment\nfrom . import self_training\nfrom . import translation\nfrom . import regularization\nfrom . import utils\nfrom . import vision\nfrom . import modules\nfrom . import ranking\n\n__version__ = '0.4'\n\n__all__ = ['alignment', 'self_training', 'translation', 'regularization', 'utils', 'vision', 'modules', 'ranking']\n"
  },
  {
    "path": "tllib/alignment/__init__.py",
    "content": "from . import cdan\nfrom . import dann\nfrom . import mdd\nfrom . import dan\nfrom . import jan\nfrom . import mcd\nfrom . import osbp\nfrom . import adda\nfrom . import bsp\n"
  },
  {
    "path": "tllib/alignment/adda.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom typing import Optional, List, Dict\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom tllib.modules.classifier import Classifier as ClassifierBase\n\n\nclass DomainAdversarialLoss(nn.Module):\n    r\"\"\"Domain adversarial loss from `Adversarial Discriminative Domain Adaptation (CVPR 2017)\n    <https://arxiv.org/pdf/1702.05464.pdf>`_.\n    Similar to the original `GAN <https://arxiv.org/pdf/1406.2661.pdf>`_ paper, ADDA argues that replacing\n    :math:`\\text{log}(1-p)` with :math:`-\\text{log}(p)` in the adversarial loss provides better gradient qualities. Detailed\n    optimization process can be found `here\n    <https://github.com/thuml/Transfer-Learning-Library/blob/master/examples/domain_adaptation/image_classification/adda.py>`_.\n\n    Inputs:\n        - domain_pred (tensor): predictions of domain discriminator\n        - domain_label (str, optional): whether the data comes from source or target.\n          Must be 'source' or 'target'. Default: 'source'\n\n    Shape:\n        - domain_pred: :math:`(minibatch,)`.\n        - Outputs: scalar.\n\n    \"\"\"\n\n    def __init__(self):\n        super(DomainAdversarialLoss, self).__init__()\n\n    def forward(self, domain_pred, domain_label='source'):\n        assert domain_label in ['source', 'target']\n        if domain_label == 'source':\n            return F.binary_cross_entropy(domain_pred, torch.ones_like(domain_pred).to(domain_pred.device))\n        else:\n            return F.binary_cross_entropy(domain_pred, torch.zeros_like(domain_pred).to(domain_pred.device))\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n\n    def freeze_bn(self):\n        for m in self.modules():\n            if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):\n                m.eval()\n\n    def get_parameters(self, base_lr=1.0, optimize_head=True) -> List[Dict]:\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr}\n        ]\n        if optimize_head:\n            params.append({\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr})\n\n        return params\n"
  },
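In the style of the doctest examples used elsewhere in the library, a small usage sketch of `DomainAdversarialLoss`; the discriminator outputs here are simulated with random sigmoid activations:

```python
import torch
from tllib.alignment.adda import DomainAdversarialLoss

adv_loss = DomainAdversarialLoss()
# simulated sigmoid outputs of a domain discriminator on 32 source / 32 target samples
d_s = torch.sigmoid(torch.randn(32))
d_t = torch.sigmoid(torch.randn(32))

# discriminator update: source labeled 1, target labeled 0
loss_d = 0.5 * (adv_loss(d_s, 'source') + adv_loss(d_t, 'target'))
# feature-extractor update on target data: ADDA's inverted-label trick, i.e. -log(D(f_t))
loss_g = adv_loss(d_t, 'source')
```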
  {
    "path": "tllib/alignment/advent.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom torch import nn\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\n\nclass Discriminator(nn.Sequential):\n    \"\"\"\n    Domain discriminator model from\n    `ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation (CVPR 2019) <https://arxiv.org/abs/1811.12833>`_\n\n    Distinguish pixel-by-pixel whether the input predictions come from the source domain or the target domain.\n    The source domain label is 1 and the target domain label is 0.\n\n    Args:\n        num_classes (int): num of classes in the predictions\n        ndf (int): dimension of the hidden features\n\n    Shape:\n        - Inputs: :math:`(minibatch, C, H, W)` where :math:`C` is the number of classes\n        - Outputs: :math:`(minibatch, 1, H, W)`\n    \"\"\"\n    def __init__(self, num_classes, ndf=64):\n        super(Discriminator, self).__init__(\n            nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),\n            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),\n            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),\n            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),\n            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1),\n        )\n\n\ndef prob_2_entropy(prob):\n    \"\"\" convert probabilistic prediction maps to weighted self-information maps\n    \"\"\"\n    n, c, h, w = prob.size()\n    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)\n\n\ndef bce_loss(y_pred, y_label):\n    y_truth_tensor = torch.FloatTensor(y_pred.size())\n    y_truth_tensor.fill_(y_label)\n    y_truth_tensor = y_truth_tensor.to(y_pred.get_device())\n    return F.binary_cross_entropy_with_logits(y_pred, y_truth_tensor)\n\n\nclass DomainAdversarialEntropyLoss(nn.Module):\n    r\"\"\"The `Domain Adversarial Entropy Loss <https://arxiv.org/abs/1811.12833>`_\n\n    Minimizing entropy with adversarial learning through training a domain discriminator.\n\n    Args:\n        domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts\n          the domains of predictions. Its input shape is :math:`(minibatch, C, H, W)` and output shape is :math:`(minibatch, 1, H, W)`\n\n    Inputs:\n        - logits (tensor): logits output of segmentation model\n        - domain_label (str, optional): whether the data comes from source or target.\n          Choices: ['source', 'target']. 
Default: 'source'\n\n    Shape:\n        - logits: :math:`(minibatch, C, H, W)` where :math:`C` means the number of classes\n        - Outputs: scalar.\n\n    Examples::\n\n        >>> B, C, H, W = 2, 19, 512, 512\n        >>> discriminator = Discriminator(num_classes=C)\n        >>> dann = DomainAdversarialEntropyLoss(discriminator)\n        >>> # logits output on source domain and target domain\n        >>> y_s, y_t = torch.randn(B, C, H, W), torch.randn(B, C, H, W)\n        >>> loss = 0.5 * (dann(y_s, \"source\") + dann(y_t, \"target\"))\n    \"\"\"\n    def __init__(self, discriminator: nn.Module):\n        super(DomainAdversarialEntropyLoss, self).__init__()\n        self.discriminator = discriminator\n\n    def forward(self, logits, domain_label='source'):\n        \"\"\"\n        \"\"\"\n        assert domain_label in ['source', 'target']\n        probability = F.softmax(logits, dim=1)\n        entropy = prob_2_entropy(probability)\n        domain_prediciton = self.discriminator(entropy)\n        if domain_label == 'source':\n            return bce_loss(domain_prediciton, 1)\n        else:\n            return bce_loss(domain_prediciton, 0)\n\n    def train(self, mode=True):\n        r\"\"\"Sets the discriminator in training mode. In the training mode,\n        all the parameters in discriminator will be set requires_grad=True.\n\n        Args:\n            mode (bool): whether to set training mode (``True``) or evaluation mode (``False``). Default: ``True``.\n        \"\"\"\n        self.discriminator.train(mode)\n        for param in self.discriminator.parameters():\n            param.requires_grad = mode\n        return self\n\n    def eval(self):\n        r\"\"\"Sets the module in evaluation mode. In the training mode,\n        all the parameters in discriminator will be set requires_grad=False.\n\n        This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.\n        \"\"\"\n        return self.train(False)\n"
  },
  {
    "path": "tllib/alignment/bsp.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nfrom tllib.modules.classifier import Classifier as ClassifierBase\n\n\nclass BatchSpectralPenalizationLoss(nn.Module):\n    r\"\"\"Batch spectral penalization loss from `Transferability vs. Discriminability: Batch\n    Spectral Penalization for Adversarial Domain Adaptation (ICML 2019)\n    <http://ise.thss.tsinghua.edu.cn/~mlong/doc/batch-spectral-penalization-icml19.pdf>`_.\n\n    Given source features :math:`f_s` and target features :math:`f_t` in current mini batch, singular value\n    decomposition is first performed\n\n    .. math::\n        f_s = U_s\\Sigma_sV_s^T\n\n    .. math::\n        f_t = U_t\\Sigma_tV_t^T\n\n    Then batch spectral penalization loss is calculated as\n\n    .. math::\n        loss=\\sum_{i=1}^k(\\sigma_{s,i}^2+\\sigma_{t,i}^2)\n\n    where :math:`\\sigma_{s,i},\\sigma_{t,i}` refer to the :math:`i-th` largest singular value of source features\n    and target features respectively. We empirically set :math:`k=1`.\n\n    Inputs:\n        - f_s (tensor): feature representations on source domain, :math:`f^s`\n        - f_t (tensor): feature representations on target domain, :math:`f^t`\n\n    Shape:\n        - f_s, f_t: :math:`(N, F)` where F means the dimension of input features.\n        - Outputs: scalar.\n\n    \"\"\"\n\n    def __init__(self):\n        super(BatchSpectralPenalizationLoss, self).__init__()\n\n    def forward(self, f_s, f_t):\n        _, s_s, _ = torch.svd(f_s)\n        _, s_t, _ = torch.svd(f_t)\n        loss = torch.pow(s_s[0], 2) + torch.pow(s_t[0], 2)\n        return loss\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n"
  },
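A usage sketch of the penalty above; the feature dimensions and the 1e-4 trade-off coefficient are illustrative:

```python
import torch
from tllib.alignment.bsp import BatchSpectralPenalizationLoss

bsp_penalty = BatchSpectralPenalizationLoss()
f_s, f_t = torch.randn(32, 256), torch.randn(32, 256)  # (N, F) feature batches
penalty = bsp_penalty(f_s, f_t)  # sigma_{s,1}^2 + sigma_{t,1}^2, a scalar

# typically added to the main objective with a small trade-off coefficient, e.g.
# total_loss = cls_loss + transfer_loss + 1e-4 * penalty
```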
  {
    "path": "tllib/alignment/cdan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\nfrom tllib.utils.metric import binary_accuracy, accuracy\nfrom tllib.modules.grl import WarmStartGradientReverseLayer\nfrom tllib.modules.entropy import entropy\n\n\n__all__ = ['ConditionalDomainAdversarialLoss', 'ImageClassifier']\n\n\nclass ConditionalDomainAdversarialLoss(nn.Module):\n    r\"\"\"The Conditional Domain Adversarial Loss used in `Conditional Adversarial Domain Adaptation (NIPS 2018) <https://arxiv.org/abs/1705.10667>`_\n\n    Conditional Domain adversarial loss measures the domain discrepancy through training a domain discriminator in a\n    conditional manner. Given domain discriminator :math:`D`, feature representation :math:`f` and\n    classifier predictions :math:`g`, the definition of CDAN loss is\n\n    .. math::\n        loss(\\mathcal{D}_s, \\mathcal{D}_t) &= \\mathbb{E}_{x_i^s \\sim \\mathcal{D}_s} \\text{log}[D(T(f_i^s, g_i^s))] \\\\\n        &+ \\mathbb{E}_{x_j^t \\sim \\mathcal{D}_t} \\text{log}[1-D(T(f_j^t, g_j^t))],\\\\\n\n    where :math:`T` is a :class:`MultiLinearMap`  or :class:`RandomizedMultiLinearMap` which convert two tensors to a single tensor.\n\n    Args:\n        domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of\n          features. Its input shape is (N, F) and output shape is (N, 1)\n        entropy_conditioning (bool, optional): If True, use entropy-aware weight to reweight each training example.\n          Default: False\n        randomized (bool, optional): If True, use `randomized multi linear map`. Else, use `multi linear map`.\n          Default: False\n        num_classes (int, optional): Number of classes. Default: -1\n        features_dim (int, optional): Dimension of input features. Default: -1\n        randomized_dim (int, optional): Dimension of features after randomized. Default: 1024\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    .. note::\n        You need to provide `num_classes`, `features_dim` and `randomized_dim` **only when** `randomized`\n        is set True.\n\n    Inputs:\n        - g_s (tensor): unnormalized classifier predictions on source domain, :math:`g^s`\n        - f_s (tensor): feature representations on source domain, :math:`f^s`\n        - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`\n        - f_t (tensor): feature representations on target domain, :math:`f^t`\n\n    Shape:\n        - g_s, g_t: :math:`(minibatch, C)` where C means the number of classes.\n        - f_s, f_t: :math:`(minibatch, F)` where F means the dimension of input features.\n        - Output: scalar by default. 
If :attr:`reduction` is ``'none'``, then :math:`(minibatch, )`.\n\n    Examples::\n\n        >>> from tllib.modules.domain_discriminator import DomainDiscriminator\n        >>> from tllib.alignment.cdan import ConditionalDomainAdversarialLoss\n        >>> import torch\n        >>> num_classes = 2\n        >>> feature_dim = 1024\n        >>> batch_size = 10\n        >>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)\n        >>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')\n        >>> # features from source domain and target domain\n        >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)\n        >>> # logits output from source domain adn target domain\n        >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)\n        >>> output = loss(g_s, f_s, g_t, f_t)\n    \"\"\"\n\n    def __init__(self, domain_discriminator: nn.Module, entropy_conditioning: Optional[bool] = False,\n                 randomized: Optional[bool] = False, num_classes: Optional[int] = -1,\n                 features_dim: Optional[int] = -1, randomized_dim: Optional[int] = 1024,\n                 reduction: Optional[str] = 'mean', sigmoid=True):\n        super(ConditionalDomainAdversarialLoss, self).__init__()\n        self.domain_discriminator = domain_discriminator\n        self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True)\n        self.entropy_conditioning = entropy_conditioning\n        self.sigmoid = sigmoid\n        self.reduction = reduction\n\n        if randomized:\n            assert num_classes > 0 and features_dim > 0 and randomized_dim > 0\n            self.map = RandomizedMultiLinearMap(features_dim, num_classes, randomized_dim)\n        else:\n            self.map = MultiLinearMap()\n        self.bce = lambda input, target, weight: F.binary_cross_entropy(input, target, weight,\n                                                                        reduction=reduction) if self.entropy_conditioning \\\n            else F.binary_cross_entropy(input, target, reduction=reduction)\n        self.domain_discriminator_accuracy = None\n\n    def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:\n        f = torch.cat((f_s, f_t), dim=0)\n        g = torch.cat((g_s, g_t), dim=0)\n        g = F.softmax(g, dim=1).detach()\n        h = self.grl(self.map(f, g))\n        d = self.domain_discriminator(h)\n\n        weight = 1.0 + torch.exp(-entropy(g))\n        batch_size = f.size(0)\n        weight = weight / torch.sum(weight) * batch_size\n\n        if self.sigmoid:\n            d_label = torch.cat((\n                torch.ones((g_s.size(0), 1)).to(g_s.device),\n                torch.zeros((g_t.size(0), 1)).to(g_t.device),\n            ))\n            self.domain_discriminator_accuracy = binary_accuracy(d, d_label)\n            if self.entropy_conditioning:\n                return F.binary_cross_entropy(d, d_label, weight.view_as(d), reduction=self.reduction)\n            else:\n                return F.binary_cross_entropy(d, d_label, reduction=self.reduction)\n        else:\n            d_label = torch.cat((\n                torch.ones((g_s.size(0), )).to(g_s.device),\n                torch.zeros((g_t.size(0), )).to(g_t.device),\n            )).long()\n            self.domain_discriminator_accuracy = accuracy(d, d_label)\n            if self.entropy_conditioning:\n    
            raise NotImplementedError(\"entropy_conditioning\")\n            return F.cross_entropy(d, d_label, reduction=self.reduction)\n\n\nclass RandomizedMultiLinearMap(nn.Module):\n    \"\"\"Random multi linear map\n\n    Given two inputs :math:`f` and :math:`g`, the definition is\n\n    .. math::\n        T_{\\odot}(f,g) = \\dfrac{1}{\\sqrt{d}} (R_f f) \\odot (R_g g),\n\n    where :math:`\\odot` is element-wise product, :math:`R_f` and :math:`R_g` are random matrices\n    sampled only once and ﬁxed in training.\n\n    Args:\n        features_dim (int): dimension of input :math:`f`\n        num_classes (int): dimension of input :math:`g`\n        output_dim (int, optional): dimension of output tensor. Default: 1024\n\n    Shape:\n        - f: (minibatch, features_dim)\n        - g: (minibatch, num_classes)\n        - Outputs: (minibatch, output_dim)\n    \"\"\"\n\n    def __init__(self, features_dim: int, num_classes: int, output_dim: Optional[int] = 1024):\n        super(RandomizedMultiLinearMap, self).__init__()\n        self.Rf = torch.randn(features_dim, output_dim)\n        self.Rg = torch.randn(num_classes, output_dim)\n        self.output_dim = output_dim\n\n    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:\n        f = torch.mm(f, self.Rf.to(f.device))\n        g = torch.mm(g, self.Rg.to(g.device))\n        output = torch.mul(f, g) / np.sqrt(float(self.output_dim))\n        return output\n\n\nclass MultiLinearMap(nn.Module):\n    \"\"\"Multi linear map\n\n    Shape:\n        - f: (minibatch, F)\n        - g: (minibatch, C)\n        - Outputs: (minibatch, F * C)\n    \"\"\"\n\n    def __init__(self):\n        super(MultiLinearMap, self).__init__()\n\n    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:\n        batch_size = f.size(0)\n        output = torch.bmm(g.unsqueeze(2), f.unsqueeze(1))\n        return output.view(batch_size, -1)\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n"
  },
  {
    "path": "tllib/alignment/coral.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch\nimport torch.nn as nn\n\n\nclass CorrelationAlignmentLoss(nn.Module):\n    r\"\"\"The `Correlation Alignment Loss` in\n    `Deep CORAL: Correlation Alignment for Deep Domain Adaptation (ECCV 2016) <https://arxiv.org/pdf/1607.01719.pdf>`_.\n\n    Given source features :math:`f_S` and target features :math:`f_T`, the covariance matrices are given by\n\n    .. math::\n        C_S = \\frac{1}{n_S-1}(f_S^Tf_S-\\frac{1}{n_S}(\\textbf{1}^Tf_S)^T(\\textbf{1}^Tf_S))\n    .. math::\n        C_T = \\frac{1}{n_T-1}(f_T^Tf_T-\\frac{1}{n_T}(\\textbf{1}^Tf_T)^T(\\textbf{1}^Tf_T))\n\n    where :math:`\\textbf{1}` denotes a column vector with all elements equal to 1, :math:`n_S, n_T` denotes number of\n    source and target samples, respectively. We use :math:`d` to denote feature dimension, use\n    :math:`{\\Vert\\cdot\\Vert}^2_F` to denote the squared matrix `Frobenius norm`. The correlation alignment loss is\n    given by\n\n    .. math::\n        l_{CORAL} = \\frac{1}{4d^2}\\Vert C_S-C_T \\Vert^2_F\n\n    Inputs:\n        - f_s (tensor): feature representations on source domain, :math:`f^s`\n        - f_t (tensor): feature representations on target domain, :math:`f^t`\n\n    Shape:\n        - f_s, f_t: :math:`(N, d)` where d means the dimension of input features, :math:`N=n_S=n_T` is mini-batch size.\n        - Outputs: scalar.\n    \"\"\"\n\n    def __init__(self):\n        super(CorrelationAlignmentLoss, self).__init__()\n\n    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:\n        mean_s = f_s.mean(0, keepdim=True)\n        mean_t = f_t.mean(0, keepdim=True)\n        cent_s = f_s - mean_s\n        cent_t = f_t - mean_t\n        cov_s = torch.mm(cent_s.t(), cent_s) / (len(f_s) - 1)\n        cov_t = torch.mm(cent_t.t(), cent_t) / (len(f_t) - 1)\n\n        mean_diff = (mean_s - mean_t).pow(2).mean()\n        cov_diff = (cov_s - cov_t).pow(2).mean()\n\n        return mean_diff + cov_diff\n"
  },
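Note that `forward` implements a mild variant of the formula above: it penalizes the element-wise mean of the squared differences of both the feature means and the covariance matrices, rather than the 1/(4d^2)-scaled Frobenius norm of the covariance gap alone. A quick sanity check:

```python
import torch
from tllib.alignment.coral import CorrelationAlignmentLoss

coral_loss = CorrelationAlignmentLoss()
f_s, f_t = torch.randn(64, 128), torch.randn(64, 128)
loss = coral_loss(f_s, f_t)

# identical source/target batches have identical statistics, so the loss vanishes
assert torch.isclose(coral_loss(f_s, f_s), torch.tensor(0.))
```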
  {
    "path": "tllib/alignment/d_adapt/__init__.py",
    "content": ""
  },
  {
    "path": "tllib/alignment/d_adapt/feedback.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport itertools\nimport numpy as np\nimport copy\nimport logging\nfrom typing import List, Optional, Union\nimport torch\n\nfrom detectron2.config import configurable\nfrom detectron2.structures import BoxMode, Boxes, Instances\nfrom detectron2.data.catalog import DatasetCatalog, MetadataCatalog\nfrom detectron2.data.build import filter_images_with_only_crowd_annotations, filter_images_with_few_keypoints, \\\n    print_instances_class_histogram\nfrom detectron2.data.detection_utils import check_metadata_consistency\nimport detectron2.data.transforms as T\nimport detectron2.data.detection_utils as utils\n\nfrom .proposal import Proposal\n\n\ndef load_feedbacks_into_dataset(dataset_dicts, proposals_list: List[Proposal]):\n    \"\"\"\n    Load precomputed object feedbacks into the dataset.\n\n    Args:\n        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.\n        proposals_list (list[Proposal]): list of Proposal.\n\n    Returns:\n        list[dict]: the same format as dataset_dicts, but added feedback field.\n    \"\"\"\n    feedbacks = {}\n\n    for record in dataset_dicts:\n        image_id = str(record[\"image_id\"])\n        feedbacks[image_id] = {\n            'pred_boxes': [],\n            'pred_classes': [],\n        }\n\n    for proposals in proposals_list:\n        image_id = str(proposals.image_id)\n        feedbacks[image_id]['pred_boxes'] += proposals.pred_boxes.tolist()\n        feedbacks[image_id]['pred_classes'] += proposals.pred_classes.tolist()\n\n    # Assuming default bbox_mode of precomputed feedbacks are 'XYXY_ABS'\n    bbox_mode = BoxMode.XYXY_ABS\n\n    dataset_dicts_with_feedbacks = []\n    for record in dataset_dicts:\n        # Get the index of the feedback\n        image_id = str(record[\"image_id\"])\n        record[\"feedback_proposal_boxes\"] = feedbacks[image_id][\"pred_boxes\"]\n        record[\"feedback_gt_classes\"] = feedbacks[image_id][\"pred_classes\"]\n        record[\"feedback_gt_boxes\"] = feedbacks[image_id][\"pred_boxes\"]\n        record[\"feedback_bbox_mode\"] = bbox_mode\n        if sum(map(lambda x: x >= 0, feedbacks[image_id][\"pred_classes\"])) > 0:  # remove images without feedbacks\n            dataset_dicts_with_feedbacks.append(record)\n\n    return dataset_dicts_with_feedbacks\n\n\ndef get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposals_list=None):\n    \"\"\"\n    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.\n\n    Args:\n        names (str or list[str]): a dataset name or a list of dataset names\n        filter_empty (bool): whether to filter out images without instance annotations\n        min_keypoints (int): filter out images with fewer keypoints than\n            `min_keypoints`. 
Set to 0 to do nothing.\n        proposals_list (optional, list[Proposal]): list of Proposal.\n\n\n    Returns:\n        list[dict]: a list of dicts following the standard dataset dict format.\n    \"\"\"\n    if isinstance(names, str):\n        names = [names]\n    assert len(names), names\n    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]\n    for dataset_name, dicts in zip(names, dataset_dicts):\n        assert len(dicts), \"Dataset '{}' is empty!\".format(dataset_name)\n\n    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))\n    if proposals_list is not None:\n        # load precomputed feedbacks for each proposal\n        dataset_dicts = load_feedbacks_into_dataset(dataset_dicts, proposals_list)\n\n    has_instances = \"annotations\" in dataset_dicts[0]\n    if filter_empty and has_instances:\n        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)\n    if min_keypoints > 0 and has_instances:\n        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)\n\n    if has_instances:\n        try:\n            class_names = MetadataCatalog.get(names[0]).thing_classes\n            check_metadata_consistency(\"thing_classes\", names)\n            print_instances_class_histogram(dataset_dicts, class_names)\n        except AttributeError:  # class names are not available for this dataset\n            pass\n\n    assert len(dataset_dicts), \"No valid data found in {}.\".format(\",\".join(names))\n    return dataset_dicts\n\n\ndef transform_feedbacks(dataset_dict, image_shape, transforms, *, min_box_size=0):\n    \"\"\"\n    Apply transformations to the feedbacks in dataset_dict, if any.\n\n    Args:\n        dataset_dict (dict): a dict read from the dataset, possibly\n            contains fields \"feedback_proposal_boxes\", \"feedback_gt_boxes\",\n            \"feedback_gt_classes\" and \"feedback_bbox_mode\"\n        image_shape (tuple): height, width\n        transforms (TransformList):\n        min_box_size (int): proposals with either side smaller than this\n            threshold are removed\n\n    The input dict is modified in-place, with the above-mentioned keys removed. A new\n    key \"feedbacks\" will be added. 
Its value is an `Instances`\n    object which contains the transformed feedbacks in its fields\n    \"proposal_boxes\", \"gt_boxes\" and \"gt_classes\".\n    \"\"\"\n    if \"feedback_proposal_boxes\" in dataset_dict:\n        # Transform proposal boxes\n        proposal_boxes = transforms.apply_box(\n            BoxMode.convert(\n                dataset_dict.pop(\"feedback_proposal_boxes\"),\n                dataset_dict.get(\"feedback_bbox_mode\"),\n                BoxMode.XYXY_ABS,\n            )\n        )\n        proposal_boxes = Boxes(proposal_boxes)\n        gt_boxes = transforms.apply_box(\n            BoxMode.convert(\n                dataset_dict.pop(\"feedback_gt_boxes\"),\n                dataset_dict.get(\"feedback_bbox_mode\"),\n                BoxMode.XYXY_ABS,\n            )\n        )\n        gt_boxes = Boxes(gt_boxes)\n        gt_classes = torch.as_tensor(\n            dataset_dict.pop(\"feedback_gt_classes\")\n        )\n\n        proposal_boxes.clip(image_shape)\n        gt_boxes.clip(image_shape)\n        keep = proposal_boxes.nonempty(threshold=min_box_size) & (gt_classes >= 0)\n        # keep = boxes.nonempty(threshold=min_box_size)\n        proposal_boxes = proposal_boxes[keep]\n        gt_boxes = gt_boxes[keep]\n        gt_classes = gt_classes[keep]\n\n        feedbacks = Instances(image_shape)\n        feedbacks.proposal_boxes = proposal_boxes\n        feedbacks.gt_boxes = gt_boxes\n        feedbacks.gt_classes = gt_classes\n        dataset_dict[\"feedbacks\"] = feedbacks\n\n\nclass DatasetMapper:\n    \"\"\"\n    A callable which takes a dataset dict in Detectron2 Dataset format,\n    and maps it into a format used by the model.\n\n    This is the default callable to be used to map your dataset dict into training data.\n    You may need to follow it to implement your own for customized logic,\n    such as a different way to read or transform images.\n    See :doc:`/tutorials/data_loading` for details.\n\n    The callable currently does the following:\n\n    1. Reads the image from \"file_name\"\n    2. Applies cropping/geometric transforms to the image and annotations\n    3. Prepares data and annotations as Tensors and :class:`Instances`\n    \"\"\"\n\n    @configurable\n    def __init__(\n        self,\n        is_train: bool,\n        *,\n        augmentations: List[Union[T.Augmentation, T.Transform]],\n        image_format: str,\n        use_instance_mask: bool = False,\n        use_keypoint: bool = False,\n        instance_mask_format: str = \"polygon\",\n        keypoint_hflip_indices: Optional[np.ndarray] = None,\n        precomputed_proposal_topk: Optional[int] = None,\n        recompute_boxes: bool = False,\n    ):\n        \"\"\"\n        NOTE: this interface is experimental.\n\n        Args:\n            is_train: whether it's used in training or inference\n            augmentations: a list of augmentations or deterministic transforms to apply\n            image_format: an image format supported by :func:`detection_utils.read_image`.\n            use_instance_mask: whether to process instance segmentation annotations, if available\n            use_keypoint: whether to process keypoint annotations if available\n            instance_mask_format: one of \"polygon\" or \"bitmask\". 
Process instance segmentation\n                masks into this format.\n            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`\n            precomputed_proposal_topk: if given, will load pre-computed\n                proposals from dataset_dict and keep the top k proposals for each image.\n            recompute_boxes: whether to overwrite bounding box annotations\n                by computing tight bounding boxes from instance mask annotations.\n        \"\"\"\n        if recompute_boxes:\n            assert use_instance_mask, \"recompute_boxes requires instance masks\"\n        # fmt: off\n        self.is_train               = is_train\n        self.augmentations          = T.AugmentationList(augmentations)\n        self.image_format           = image_format\n        self.use_instance_mask      = use_instance_mask\n        self.instance_mask_format   = instance_mask_format\n        self.use_keypoint           = use_keypoint\n        self.keypoint_hflip_indices = keypoint_hflip_indices\n        self.proposal_topk          = precomputed_proposal_topk\n        self.recompute_boxes        = recompute_boxes\n        # fmt: on\n        logger = logging.getLogger(__name__)\n        mode = \"training\" if is_train else \"inference\"\n        logger.info(f\"[DatasetMapper] Augmentations used in {mode}: {augmentations}\")\n\n    @classmethod\n    def from_config(cls, cfg, is_train: bool = True):\n        augs = utils.build_augmentation(cfg, is_train)\n        if cfg.INPUT.CROP.ENABLED and is_train:\n            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))\n            recompute_boxes = cfg.MODEL.MASK_ON\n        else:\n            recompute_boxes = False\n\n        ret = {\n            \"is_train\": is_train,\n            \"augmentations\": augs,\n            \"image_format\": cfg.INPUT.FORMAT,\n            \"use_instance_mask\": cfg.MODEL.MASK_ON,\n            \"instance_mask_format\": cfg.INPUT.MASK_FORMAT,\n            \"use_keypoint\": cfg.MODEL.KEYPOINT_ON,\n            \"recompute_boxes\": recompute_boxes,\n        }\n\n        if cfg.MODEL.KEYPOINT_ON:\n            ret[\"keypoint_hflip_indices\"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)\n\n        if cfg.MODEL.LOAD_PROPOSALS:\n            ret[\"precomputed_proposal_topk\"] = (\n                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN\n                if is_train\n                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST\n            )\n        return ret\n\n    def __call__(self, dataset_dict):\n        \"\"\"\n        Args:\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\n\n        Returns:\n            dict: a format that builtin models in detectron2 accept\n        \"\"\"\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\n        # USER: Write your own image loading if it's not from a file\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.image_format)\n        utils.check_image_size(dataset_dict, image)\n\n        # USER: Remove if you don't do semantic/panoptic segmentation.\n        if \"sem_seg_file_name\" in dataset_dict:\n            sem_seg_gt = utils.read_image(dataset_dict.pop(\"sem_seg_file_name\"), \"L\").squeeze(2)\n        else:\n            sem_seg_gt = None\n\n        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)\n        transforms = self.augmentations(aug_input)\n        image, sem_seg_gt = aug_input.image, 
aug_input.sem_seg\n\n        image_shape = image.shape[:2]  # h, w\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\n        # Therefore it's important to use torch.Tensor.\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\n        if sem_seg_gt is not None:\n            dataset_dict[\"sem_seg\"] = torch.as_tensor(sem_seg_gt.astype(\"long\"))\n\n        # USER: Remove if you don't use pre-computed proposals.\n        # Most users would not need this feature.\n        if self.proposal_topk is not None:\n            utils.transform_proposals(\n                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk\n            )\n\n        transform_feedbacks(\n            dataset_dict, image_shape, transforms\n        )\n\n        if not self.is_train:\n            # USER: Modify this if you want to keep them for some reason.\n            dataset_dict.pop(\"annotations\", None)\n            dataset_dict.pop(\"sem_seg_file_name\", None)\n            return dataset_dict\n\n        if \"annotations\" in dataset_dict:\n            # USER: Modify this if you want to keep them for some reason.\n            for anno in dataset_dict[\"annotations\"]:\n                if not self.use_instance_mask:\n                    anno.pop(\"segmentation\", None)\n                if not self.use_keypoint:\n                    anno.pop(\"keypoints\", None)\n\n            # USER: Implement additional transformations if you have other types of data\n            annos = [\n                utils.transform_instance_annotations(\n                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices\n                )\n                for obj in dataset_dict.pop(\"annotations\")\n                if obj.get(\"iscrowd\", 0) == 0\n            ]\n            instances = utils.annotations_to_instances(\n                annos, image_shape, mask_format=self.instance_mask_format\n            )\n\n            # After transforms such as cropping are applied, the bounding box may no longer\n            # tightly bound the object. As an example, imagine a triangle object\n            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight\n            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\n            # the intersection of original bounding box and the cropping box.\n            if self.recompute_boxes:\n                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\n            dataset_dict[\"instances\"] = utils.filter_empty_instances(instances)\n        return dataset_dict"
  },
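To make the feedback format concrete, a minimal sketch of `load_feedbacks_into_dataset` (detectron2 must be installed for the import to succeed; `SimpleNamespace` merely stands in for the `Proposal` objects produced elsewhere in the d_adapt pipeline):

```python
from types import SimpleNamespace
import torch
from tllib.alignment.d_adapt.feedback import load_feedbacks_into_dataset

dataset_dicts = [{'image_id': 1, 'file_name': 'img1.jpg'}]  # abridged Detectron2 Dataset records
proposals = SimpleNamespace(
    image_id=1,
    pred_boxes=torch.tensor([[10., 20., 110., 220.]]),  # XYXY_ABS coordinates
    pred_classes=torch.tensor([3]),                     # negative classes mark rejected feedbacks
)
records = load_feedbacks_into_dataset(dataset_dicts, [proposals])
# each surviving record now carries the feedback_proposal_boxes / feedback_gt_boxes /
# feedback_gt_classes / feedback_bbox_mode fields consumed by transform_feedbacks()
```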
  {
    "path": "tllib/alignment/d_adapt/modeling/__init__.py",
    "content": "from . import meta_arch\nfrom . import roi_heads"
  },
  {
    "path": "tllib/alignment/d_adapt/modeling/matcher.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport torch\r\nfrom torch import Tensor, nn\r\n\r\nfrom detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple\r\n\r\n\r\nclass MaxOverlapMatcher(object):\r\n    \"\"\"\r\n    This class assigns to each predicted \"element\" (e.g., a box) a ground-truth\r\n    element. Each predicted element will have exactly zero or one matches; each\r\n    ground-truth element may be matched to one predicted elements.\r\n    \"\"\"\r\n\r\n    def __init__(self):\r\n        pass\r\n\r\n    def __call__(self, match_quality_matrix):\r\n        \"\"\"\r\n        Args:\r\n            match_quality_matrix (Tensor[float]): an MxN tensor, containing the\r\n                pairwise quality between M ground-truth elements and N predicted\r\n                elements. All elements must be >= 0 (due to the us of `torch.nonzero`\r\n                for selecting indices in :meth:`set_low_quality_matches_`).\r\n\r\n        Returns:\r\n            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched\r\n                ground-truth index in [0, M)\r\n            match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates\r\n                whether a prediction is a true or false positive or ignored\r\n        \"\"\"\r\n        assert match_quality_matrix.dim() == 2\r\n        # match_quality_matrix is M (gt) x N (predicted)\r\n        # Max over gt elements (dim 0) to find best gt candidate for each prediction\r\n        _, matched_idxs = match_quality_matrix.max(dim=0)\r\n\r\n        anchor_labels = match_quality_matrix.new_full(\r\n            (match_quality_matrix.size(1),), -1, dtype=torch.int8\r\n        )\r\n\r\n        # For each gt, find the prediction with which it has highest quality\r\n        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)\r\n        # Find the highest quality match available, even if it is low, including ties.\r\n        # Note that the matches qualities must be positive due to the use of\r\n        # `torch.nonzero`.\r\n        _, pred_inds_with_highest_quality = nonzero_tuple(\r\n            match_quality_matrix == highest_quality_foreach_gt[:, None]\r\n        )\r\n        anchor_labels[pred_inds_with_highest_quality] = 1\r\n\r\n        return matched_idxs, anchor_labels\r\n"
  },
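A small worked example of the matcher (the IoU matrix is made up; any non-negative quality matrix works, and detectron2 must be installed for `nonzero_tuple`):

```python
import torch
from tllib.alignment.d_adapt.modeling.matcher import MaxOverlapMatcher

matcher = MaxOverlapMatcher()
# pairwise quality (e.g. IoU) between M=2 ground-truth boxes and N=4 predicted boxes
iou = torch.tensor([[0.9, 0.2, 0.0, 0.1],
                    [0.1, 0.3, 0.6, 0.6]])
matched_idxs, match_labels = matcher(iou)
# matched_idxs -> tensor([0, 1, 1, 1]): the best ground-truth index for every prediction
# match_labels -> tensor([1, -1, 1, 1], dtype=torch.int8): 1 where a prediction is some
#                 gt's highest-quality match (ties included), -1 otherwise
```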
  {
    "path": "tllib/alignment/d_adapt/modeling/meta_arch/__init__.py",
    "content": "from .rcnn import DecoupledGeneralizedRCNN\r\nfrom .retinanet import DecoupledRetinaNet"
  },
  {
    "path": "tllib/alignment/d_adapt/modeling/meta_arch/rcnn.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport torch\r\nfrom typing import Optional, Callable, Tuple, Any, List, Sequence, Dict\r\nimport numpy as np\r\n\r\nfrom detectron2.utils.events import get_event_storage\r\nfrom detectron2.structures import Instances\r\nfrom detectron2.data.detection_utils import convert_image_to_rgb\r\nfrom detectron2.modeling.postprocessing import detector_postprocess\r\nfrom detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY\r\n\r\nfrom tllib.vision.models.object_detection.meta_arch import TLGeneralizedRCNN\r\n\r\n\r\n@META_ARCH_REGISTRY.register()\r\nclass DecoupledGeneralizedRCNN(TLGeneralizedRCNN):\r\n    \"\"\"\r\n    Generalized R-CNN for Decoupled Adaptation (D-adapt).\r\n    Similar to that in in Supervised Learning, DecoupledGeneralizedRCNN has the following three components:\r\n    1. Per-image feature extraction (aka backbone)\r\n    2. Region proposal generation\r\n    3. Per-region feature extraction and prediction\r\n\r\n    Different from that in Supervised Learning, DecoupledGeneralizedRCNN\r\n    1. accepts unlabeled images and uses the feedbacks from adaptors as supervision during training\r\n    2. generate foreground and background proposals during inference\r\n\r\n    Args:\r\n        backbone: a backbone module, must follow detectron2's backbone interface\r\n        proposal_generator: a module that generates proposals using backbone features\r\n        roi_heads: a ROI head that performs per-region computation\r\n        pixel_mean, pixel_std: list or tuple with #channels element,\r\n            representing the per-channel mean and std to be used to normalize\r\n            the input image\r\n        input_format: describe the meaning of channels of input. Needed by visualization\r\n        vis_period: the period to run visualization. Set to 0 to disable.\r\n        finetune (bool): whether finetune the detector or train from scratch. 
Default: True\r\n\r\n    Inputs:\r\n        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.\r\n          Each item in the list contains the inputs for one image.\r\n          For now, each item in the list is a dict that contains:\r\n            * image: Tensor, image in (C, H, W) format.\r\n            * instances (optional): groundtruth :class:`Instances`\r\n            * feedbacks (optional): :class:`Instances`, feedbacks from adaptors.\r\n            * \"height\", \"width\" (int): the output resolution of the model, used in inference.\r\n              See :meth:`postprocess` for details.\r\n        - labeled (bool, optional): whether has ground-truth label\r\n\r\n    Outputs:\r\n        - outputs (during inference): A list of dict where each dict is the output for one input image.\r\n          The dict contains a key \"instances\" whose value is a :class:`Instances`.\r\n          The :class:`Instances` object has the following keys:\r\n          \"pred_boxes\", \"pred_classes\", \"scores\", \"pred_masks\", \"pred_keypoints\"\r\n        - losses (during training): A dict of different losses\r\n    \"\"\"\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n\r\n    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]], labeled=True):\r\n        if not self.training:\r\n            return self.inference(batched_inputs)\r\n\r\n        images = self.preprocess_image(batched_inputs)\r\n        if \"instances\" in batched_inputs[0]:\r\n            gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\r\n        else:\r\n            gt_instances = None\r\n\r\n        features = self.backbone(images.tensor)\r\n\r\n        if \"feedbacks\" in batched_inputs[0]:\r\n            feedbacks = [x[\"feedbacks\"].to(self.device) for x in batched_inputs]\r\n        else:\r\n            feedbacks = None\r\n\r\n        proposals, proposal_losses = self.proposal_generator(images, features, gt_instances, labeled=labeled)\r\n        _, _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, feedbacks, labeled=labeled)\r\n        losses = {}\r\n        losses.update(detector_losses)\r\n        losses.update(proposal_losses)\r\n\r\n        if self.vis_period > 0:\r\n            storage = get_event_storage()\r\n            if storage.iter % self.vis_period == 0:\r\n                self.visualize_training(batched_inputs, proposals, feedbacks)\r\n\r\n        return losses\r\n\r\n    def visualize_training(self, batched_inputs, proposals, feedbacks=None):\r\n        \"\"\"\r\n        A function used to visualize images and proposals. It shows ground truth\r\n        bounding boxes on the original image and up to 20 top-scoring predicted\r\n        object proposals on the original image. Users can implement different\r\n        visualization functions for different models.\r\n\r\n        Args:\r\n            batched_inputs (list): a list that contains input to the model.\r\n            proposals (list): a list that contains predicted proposals. Both\r\n                batched_inputs and proposals should have the same length.\r\n            feedbacks (list): a list that contains feedbacks from adaptors. 
Both\r\n                batched_inputs and feedbacks should have the same length.\r\n        \"\"\"\r\n        from detectron2.utils.visualizer import Visualizer\r\n\r\n        storage = get_event_storage()\r\n        max_vis_prop = 20\r\n\r\n        for input, prop in zip(batched_inputs, proposals):\r\n            img = input[\"image\"]\r\n            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)\r\n            v_gt = Visualizer(img, None)\r\n            v_gt = v_gt.overlay_instances(boxes=input[\"instances\"].gt_boxes)\r\n            anno_img = v_gt.get_image()\r\n            box_size = min(len(prop.proposal_boxes), max_vis_prop)\r\n            v_pred = Visualizer(img, None)\r\n            v_pred = v_pred.overlay_instances(\r\n                boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy()\r\n            )\r\n            prop_img = v_pred.get_image()\r\n\r\n            num_classes = self.roi_heads.box_predictor.num_classes\r\n            if feedbacks is not None:\r\n                v_feedback_gt = Visualizer(img, None)\r\n                instance = feedbacks[0].to(torch.device(\"cpu\"))\r\n                v_feedback_gt = v_feedback_gt.overlay_instances(\r\n                    boxes=instance.proposal_boxes[instance.gt_classes != num_classes])\r\n                feedback_gt_img = v_feedback_gt.get_image()\r\n\r\n                v_feedback_gf = Visualizer(img, None)\r\n                v_feedback_gf = v_feedback_gf.overlay_instances(\r\n                    boxes=instance.proposal_boxes[instance.gt_classes == num_classes])\r\n                feedback_gf_img = v_feedback_gf.get_image()\r\n\r\n                vis_img = np.vstack((anno_img, prop_img, feedback_gt_img, feedback_gf_img))\r\n                vis_img = vis_img.transpose(2, 0, 1)\r\n                vis_name = f\"Top: GT; Middle: Pred; Bottom: Feedback GT, Feedback GF\"\r\n            else:\r\n                vis_img = np.concatenate((anno_img, prop_img), axis=1)\r\n                vis_img = vis_img.transpose(2, 0, 1)\r\n                vis_name = \"Left: GT bounding boxes;  Right: Predicted proposals\"\r\n            storage.put_image(vis_name, vis_img)\r\n            break  # only visualize one image in a batch\r\n\r\n    def inference(\r\n        self,\r\n        batched_inputs: Tuple[Dict[str, torch.Tensor]],\r\n        detected_instances: Optional[List[Instances]] = None,\r\n        do_postprocess: bool = True,\r\n    ):\r\n        \"\"\"\r\n        Run inference on the given inputs.\r\n\r\n        Args:\r\n            batched_inputs (list[dict]): same as in :meth:`forward`\r\n            detected_instances (None or list[Instances]): if not None, it\r\n                contains an `Instances` object per image. 
The `Instances`\r\n                object contains \"pred_boxes\" and \"pred_classes\" which are\r\n                known boxes in the image.\r\n                The inference will then skip the detection of bounding boxes,\r\n                and only predict other per-ROI outputs.\r\n            do_postprocess (bool): whether to apply post-processing on the outputs.\r\n\r\n        Returns:\r\n            When do_postprocess=True, same as in :meth:`forward`.\r\n            Otherwise, a list[Instances] containing raw network outputs.\r\n        \"\"\"\r\n        assert not self.training\r\n\r\n        images = self.preprocess_image(batched_inputs)\r\n        features = self.backbone(images.tensor)\r\n        proposals, _ = self.proposal_generator(images, features, None)\r\n\r\n        results, background_results, _ = self.roi_heads(images, features, proposals, None)\r\n        processed_results = []\r\n        for results_per_image, background_results_per_image, input_per_image, image_size in zip(\r\n                results, background_results, batched_inputs, images.image_sizes\r\n        ):\r\n            height = input_per_image.get(\"height\", image_size[0])\r\n            width = input_per_image.get(\"width\", image_size[1])\r\n            r = detector_postprocess(results_per_image, height, width)\r\n            background_r = detector_postprocess(background_results_per_image, height, width)\r\n            processed_results.append({\"instances\": r, 'background': background_r})\r\n        return processed_results\r\n"
  },
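  {
    "path": "examples/d_adapt_sketches/rcnn_inference.py",
    "content": "\"\"\"\nIllustrative usage sketch (NOT part of the original repository; this file path\nis hypothetical). It shows how the per-image {'instances', 'background'} output\ndicts of the decoupled R-CNN above could be consumed at inference time. The\nminimal config below is an assumption; a real run needs the full D-adapt config\nand trained weights, and the meta-architecture name is assumed to be the one\nregistered in rcnn.py.\n\"\"\"\nimport torch\nfrom detectron2.config import get_cfg\nfrom detectron2.modeling import build_model\n\n# importing the packages is assumed to execute the @META_ARCH_REGISTRY.register()\n# and @ROI_HEADS_REGISTRY.register() decorators defined in this repository\nimport tllib.alignment.d_adapt.modeling.meta_arch  # noqa: F401\nimport tllib.alignment.d_adapt.modeling.roi_heads  # noqa: F401\n\ncfg = get_cfg()\ncfg.MODEL.META_ARCHITECTURE = \"DecoupledGeneralizedRCNN\"  # assumed registry name\ncfg.MODEL.ROI_HEADS.NAME = \"DecoupledRes5ROIHeads\"\ncfg.MODEL.DEVICE = \"cpu\"\n\nmodel = build_model(cfg)\nmodel.eval()\n\n# one dummy image in detectron2's (C, H, W) format\ninputs = [{\"image\": torch.zeros(3, 480, 640), \"height\": 480, \"width\": 640}]\nwith torch.no_grad():\n    outputs = model(inputs)\n\n# each per-image dict carries foreground detections and sampled background boxes\nprint(outputs[0][\"instances\"].pred_boxes)\nprint(outputs[0][\"background\"].pred_boxes)\n"
  },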
  {
    "path": "tllib/alignment/d_adapt/modeling/meta_arch/retinanet.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nfrom typing import Optional, Callable, Tuple, Any, List, Sequence, Dict\r\nimport random\r\nimport numpy as np\r\n\r\nimport torch\r\nfrom torch import Tensor\r\nfrom detectron2.structures import BoxMode, Boxes, Instances, pairwise_iou, ImageList\r\nfrom detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple\r\nfrom detectron2.modeling import detector_postprocess\r\nfrom detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY\r\nfrom detectron2.data.detection_utils import convert_image_to_rgb\r\nfrom detectron2.utils.events import get_event_storage\r\n\r\nfrom tllib.vision.models.object_detection.meta_arch import TLRetinaNet\r\nfrom ..matcher import MaxOverlapMatcher\r\n\r\n\r\n@META_ARCH_REGISTRY.register()\r\nclass DecoupledRetinaNet(TLRetinaNet):\r\n    \"\"\"\r\n    RetinaNet for Decoupled Adaptation (D-adapt).\r\n\r\n    Different from that in Supervised Learning, DecoupledRetinaNet\r\n    1. accepts unlabeled images and uses the feedbacks from adaptors as supervision during training\r\n    2. generate foreground and background proposals during inference\r\n\r\n    Args:\r\n        backbone: a backbone module, must follow detectron2's backbone interface\r\n        head (nn.Module): a module that predicts logits and regression deltas\r\n            for each level from a list of per-level features\r\n        head_in_features (Tuple[str]): Names of the input feature maps to be used in head\r\n        anchor_generator (nn.Module): a module that creates anchors from a\r\n            list of features. Usually an instance of :class:`AnchorGenerator`\r\n        box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to\r\n            instance boxes\r\n        anchor_matcher (Matcher): label the anchors by matching them with ground truth.\r\n        num_classes (int): number of classes. 
Used to label background proposals.\r\n\r\n        # Loss parameters:\r\n        focal_loss_alpha (float): focal_loss_alpha\r\n        focal_loss_gamma (float): focal_loss_gamma\r\n        smooth_l1_beta (float): smooth_l1_beta\r\n        box_reg_loss_type (str): Options are \"smooth_l1\", \"giou\"\r\n\r\n        # Inference parameters:\r\n        test_score_thresh (float): Inference cls score threshold, only anchors with\r\n            score > INFERENCE_TH are considered for inference (to improve speed)\r\n        test_topk_candidates (int): Select topk candidates before NMS\r\n        test_nms_thresh (float): Overlap threshold used for non-maximum suppression\r\n            (suppress boxes with IoU >= this threshold)\r\n        max_detections_per_image (int):\r\n            Maximum number of detections to return per image during inference\r\n            (100 is based on the limit established for the COCO dataset).\r\n\r\n        # Input parameters\r\n        pixel_mean (Tuple[float]):\r\n            Values to be used for image normalization (BGR order).\r\n            To train on images with a different number of channels, set different mean & std.\r\n            Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]\r\n        pixel_std (Tuple[float]):\r\n            When using pre-trained models in Detectron1 or any MSRA models,\r\n            std has been absorbed into its conv1 weights, so the std needs to be set to 1.\r\n            Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)\r\n        vis_period (int):\r\n            The period (in terms of steps) for minibatch visualization at train time.\r\n            Set to 0 to disable.\r\n        input_format (str): Whether the model needs RGB, YUV, HSV etc.\r\n        finetune (bool): whether to finetune the detector or train from scratch. 
Default: True\r\n\r\n    Inputs:\r\n        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.\r\n          Each item in the list contains the inputs for one image.\r\n          For now, each item in the list is a dict that contains:\r\n            * image: Tensor, image in (C, H, W) format.\r\n            * instances (optional): groundtruth :class:`Instances`\r\n            * \"height\", \"width\" (int): the output resolution of the model, used in inference.\r\n              See :meth:`postprocess` for details.\r\n        - labeled (bool, optional): whether the inputs have ground-truth labels\r\n\r\n    Outputs:\r\n        - outputs: A list of dict where each dict is the output for one input image.\r\n          The dict contains a key \"instances\" whose value is a :class:`Instances`\r\n          and a key \"features\" whose value is the features of middle layers.\r\n          The :class:`Instances` object has the following keys:\r\n          \"pred_boxes\", \"pred_classes\", \"scores\", \"pred_masks\", \"pred_keypoints\"\r\n        - losses: A dict of different losses\r\n    \"\"\"\r\n    def __init__(self, *args, max_samples_per_level=25, **kwargs):\r\n        super(DecoupledRetinaNet, self).__init__(*args, **kwargs)\r\n        self.max_samples_per_level = max_samples_per_level\r\n        self.max_matcher = MaxOverlapMatcher()\r\n\r\n    def forward_training(self, images, features, predictions, gt_instances=None, feedbacks=None, labeled=True):\r\n        # Transpose the Hi*Wi*A dimension to the middle:\r\n        pred_logits, pred_anchor_deltas = self._transpose_dense_predictions(\r\n            predictions, [self.num_classes, 4]\r\n        )\r\n        anchors = self.anchor_generator(features)\r\n        if labeled:\r\n            gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)\r\n            losses = self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes)\r\n        else:\r\n            proposal_labels, proposal_boxes = self.label_pseudo_anchors(anchors, feedbacks)\r\n            losses = self.losses(anchors, pred_logits, proposal_labels, pred_anchor_deltas, proposal_boxes)\r\n            losses.pop('loss_box_reg')\r\n        return losses\r\n\r\n    def forward(self, batched_inputs: Tuple[Dict[str, Tensor]], labeled=True):\r\n        images = self.preprocess_image(batched_inputs)\r\n        features = self.backbone(images.tensor)\r\n        features = [features[f] for f in self.head_in_features]\r\n        predictions = self.head(features)\r\n\r\n        if self.training:\r\n            if \"instances\" in batched_inputs[0]:\r\n                gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\r\n            else:\r\n                gt_instances = None\r\n\r\n            if \"feedbacks\" in batched_inputs[0]:\r\n                feedbacks = [x[\"feedbacks\"].to(self.device) for x in batched_inputs]\r\n            else:\r\n                feedbacks = None\r\n\r\n            losses = self.forward_training(images, features, predictions, gt_instances, feedbacks, labeled)\r\n\r\n            if self.vis_period > 0:\r\n                storage = get_event_storage()\r\n                if storage.iter % self.vis_period == 0:\r\n                    results = self.forward_inference(images, features, predictions)\r\n                    self.visualize_training(batched_inputs, results, feedbacks)\r\n\r\n            return losses\r\n        else:\r\n            # sample_background must be called before inference\r\n            # since 
inference will change predictions\r\n            background_results = self.sample_background(images, features, predictions)\r\n            results = self.forward_inference(images, features, predictions)\r\n\r\n            processed_results = []\r\n            for results_per_image, background_results_per_image, input_per_image, image_size in zip(\r\n                    results, background_results, batched_inputs, images.image_sizes\r\n            ):\r\n                height = input_per_image.get(\"height\", image_size[0])\r\n                width = input_per_image.get(\"width\", image_size[1])\r\n                r = detector_postprocess(results_per_image, height, width)\r\n                background_r = detector_postprocess(background_results_per_image, height, width)\r\n                processed_results.append({\"instances\": r, \"background\": background_r})\r\n            return processed_results\r\n\r\n    @torch.no_grad()\r\n    def label_pseudo_anchors(self, anchors, instances):\r\n        \"\"\"\r\n        Args:\r\n            anchors (list[Boxes]): A list of #feature level Boxes.\r\n                The Boxes contains anchors of this image on the specific feature level.\r\n            instances (list[Instances]): a list of N `Instances`s. The i-th\r\n                `Instances` contains the feedback (pseudo-label) annotations\r\n                for the i-th input image.\r\n\r\n        Returns:\r\n            list[Tensor]:\r\n                List of #img tensors. i-th element is a vector of labels whose length is\r\n                the total number of anchors across all feature maps (sum(Hi * Wi * A)).\r\n                Label values are in {-1, 0, ..., K}, where -1 means ignore and K means background.\r\n            list[Tensor]:\r\n                i-th element is an Rx4 tensor, where R is the total number of anchors across\r\n                feature maps. 
The values are the matched gt boxes for each anchor.\r\n                Values are undefined for those anchors not labeled as foreground.\r\n        \"\"\"\r\n        anchors = Boxes.cat(anchors)  # Rx4\r\n\r\n        gt_labels = []\r\n        matched_gt_boxes = []\r\n        for gt_per_image in instances:\r\n            match_quality_matrix = pairwise_iou(gt_per_image.gt_boxes, anchors)\r\n            matched_idxs, anchor_labels = self.max_matcher(match_quality_matrix)\r\n            del match_quality_matrix\r\n\r\n            if len(gt_per_image) > 0:\r\n                matched_gt_boxes_i = gt_per_image.gt_boxes.tensor[matched_idxs]\r\n\r\n                gt_labels_i = gt_per_image.gt_classes[matched_idxs]\r\n                # Anchors with label -1 are ignored.\r\n                gt_labels_i[anchor_labels == -1] = -1\r\n            else:\r\n                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)\r\n                gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes\r\n\r\n            gt_labels.append(gt_labels_i)\r\n            matched_gt_boxes.append(matched_gt_boxes_i)\r\n\r\n        return gt_labels, matched_gt_boxes\r\n\r\n    def sample_background(\r\n        self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]]\r\n    ):\r\n        pred_logits, pred_anchor_deltas = self._transpose_dense_predictions(\r\n            predictions, [self.num_classes, 4]\r\n        )\r\n        anchors = self.anchor_generator(features)\r\n\r\n        results: List[Instances] = []\r\n        for img_idx, image_size in enumerate(images.image_sizes):\r\n            scores_per_image = [x[img_idx].sigmoid() for x in pred_logits]\r\n            deltas_per_image = [x[img_idx] for x in pred_anchor_deltas]\r\n            results_per_image = self.sample_background_single_image(\r\n                anchors, scores_per_image, deltas_per_image, image_size\r\n            )\r\n            results.append(results_per_image)\r\n        return results\r\n\r\n    def sample_background_single_image(\r\n            self,\r\n            anchors: List[Boxes],\r\n            box_cls: List[Tensor],\r\n            box_delta: List[Tensor],\r\n            image_size: Tuple[int, int],\r\n    ):\r\n        boxes_all = []\r\n        scores_all = []\r\n\r\n        # Iterate over every feature level\r\n        for box_cls_i, box_reg_i, anchors_i in zip(box_cls, box_delta, anchors):\r\n            # (HxWxA,): maximum class confidence for each anchor\r\n            predicted_prob = box_cls_i.max(dim=1).values\r\n\r\n            # 1. Keep boxes with confidence score lower than threshold\r\n            keep_idxs = predicted_prob < self.test_score_thresh\r\n            anchor_idxs = nonzero_tuple(keep_idxs)[0]\r\n\r\n            # 2. Randomly sample boxes\r\n            anchor_idxs = anchor_idxs[\r\n                random.sample(range(len(anchor_idxs)), k=min(len(anchor_idxs), self.max_samples_per_level))]\r\n            predicted_prob = predicted_prob[anchor_idxs]\r\n            anchors_i = anchors_i[anchor_idxs]\r\n\r\n            boxes_all.append(anchors_i.tensor)\r\n            scores_all.append(predicted_prob)\r\n\r\n        boxes_all, scores_all = [\r\n            cat(x) for x in [boxes_all, scores_all]\r\n        ]\r\n\r\n        result = Instances(image_size)\r\n        result.pred_boxes = Boxes(boxes_all)\r\n        result.scores = 1. 
- scores_all  # the confidence score of being background\r\n        result.pred_classes = torch.tensor([self.num_classes for _ in range(len(scores_all))])\r\n        return result\r\n\r\n    def visualize_training(self, batched_inputs, results, feedbacks=None):\r\n        \"\"\"\r\n        A function used to visualize ground truth images and final network predictions.\r\n        It shows ground truth bounding boxes on the original image and up to 20\r\n        predicted object bounding boxes on the original image.\r\n\r\n        Args:\r\n            batched_inputs (list): a list that contains input to the model.\r\n            results (List[Instances]): a list of #images elements returned by forward_inference().\r\n            feedbacks (list[Instances], optional): per-image feedback `Instances` from the adaptors, if available.\r\n        \"\"\"\r\n        from detectron2.utils.visualizer import Visualizer\r\n\r\n        assert len(batched_inputs) == len(\r\n            results\r\n        ), \"Cannot visualize inputs and results of different sizes\"\r\n        storage = get_event_storage()\r\n        max_boxes = 20\r\n\r\n        image_index = 0  # only visualize a single image\r\n        img = batched_inputs[image_index][\"image\"]\r\n        img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)\r\n        v_gt = Visualizer(img, None)\r\n        v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index][\"instances\"].gt_boxes)\r\n        anno_img = v_gt.get_image()\r\n        processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1])\r\n        predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy()\r\n\r\n        v_pred = Visualizer(img, None)\r\n        v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes])\r\n        prop_img = v_pred.get_image()\r\n\r\n        num_classes = self.num_classes\r\n        if feedbacks is not None:\r\n            v_feedback_gt = Visualizer(img, None)\r\n            instance = feedbacks[0].to(torch.device(\"cpu\"))\r\n            v_feedback_gt = v_feedback_gt.overlay_instances(\r\n                boxes=instance.proposal_boxes[instance.gt_classes != num_classes])\r\n            feedback_gt_img = v_feedback_gt.get_image()\r\n\r\n            v_feedback_gf = Visualizer(img, None)\r\n            v_feedback_gf = v_feedback_gf.overlay_instances(\r\n                boxes=instance.proposal_boxes[instance.gt_classes == num_classes])\r\n            feedback_gf_img = v_feedback_gf.get_image()\r\n\r\n            vis_img = np.vstack((anno_img, prop_img, feedback_gt_img, feedback_gf_img))\r\n            vis_img = vis_img.transpose(2, 0, 1)\r\n            vis_name = \"Top: GT; Middle: Pred; Bottom: Feedback GT, Feedback GF\"\r\n        else:\r\n            vis_img = np.concatenate((anno_img, prop_img), axis=1)\r\n            vis_img = vis_img.transpose(2, 0, 1)\r\n            vis_name = \"Left: GT bounding boxes;  Right: Predicted proposals\"\r\n\r\n        storage.put_image(vis_name, vis_img)\r\n"
  },
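  {
    "path": "examples/d_adapt_sketches/retinanet_feedbacks.py",
    "content": "\"\"\"\nIllustrative sketch (NOT part of the original repository; this file path is\nhypothetical). It shows the shape of the 'feedbacks' entry that\nDecoupledRetinaNet consumes when trained on unlabeled target images\n(labeled=False). Field names follow label_pseudo_anchors() and\nvisualize_training(); all box and class values below are made up.\n\"\"\"\nimport torch\nfrom detectron2.structures import Boxes, Instances\n\nh, w = 480, 640\nfeedback = Instances((h, w))\n# proposals the adaptors gave feedback on; gt_classes == num_classes marks a\n# proposal that the adaptors judged to be background (assuming num_classes == 20)\nfeedback.proposal_boxes = Boxes(torch.tensor([[10., 20., 200., 220.],\n                                              [300., 40., 420., 160.]]))\nfeedback.gt_boxes = feedback.proposal_boxes  # matched boxes; reused here for brevity\nfeedback.gt_classes = torch.tensor([3, 20])\n\n# one dict per image, as produced by the data-loading pipeline\nbatched_inputs = [{\"image\": torch.zeros(3, h, w), \"feedbacks\": feedback}]\n# losses = model(batched_inputs, labeled=False)  # model: a built DecoupledRetinaNet\n"
  },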
  {
    "path": "tllib/alignment/d_adapt/modeling/roi_heads/__init__.py",
    "content": "from .roi_heads import DecoupledRes5ROIHeads\n"
  },
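  {
    "path": "examples/d_adapt_sketches/select_roi_heads.py",
    "content": "\"\"\"\nIllustrative sketch (NOT part of the original repository; this file path is\nhypothetical). It shows how a detectron2 config is pointed at the ROI heads\nregistered by this package.\n\"\"\"\nfrom detectron2.config import get_cfg\n\n# importing the package is assumed to execute the @ROI_HEADS_REGISTRY.register()\n# decorators in roi_heads.py\nimport tllib.alignment.d_adapt.modeling.roi_heads  # noqa: F401\n\ncfg = get_cfg()\ncfg.MODEL.ROI_HEADS.NAME = \"DecoupledRes5ROIHeads\"  # or \"DecoupledStandardROIHeads\"\n"
  },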
  {
    "path": "tllib/alignment/d_adapt/modeling/roi_heads/fast_rcnn.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nfrom typing import Dict\r\n\r\nfrom detectron2.layers import cat\r\nfrom detectron2.modeling.roi_heads.fast_rcnn import (\r\n    _log_classification_stats,\r\n    FastRCNNOutputLayers\r\n)\r\nfrom detectron2.structures import Instances\r\nfrom tllib.modules.loss import LabelSmoothSoftmaxCEV1\r\n\r\nimport torch\r\n\r\n\r\ndef label_smoothing_cross_entropy(input, target, *, reduction=\"mean\", **kwargs):\r\n    \"\"\"\r\n    Same as `tllib.modules.loss.LabelSmoothSoftmaxCEV1`, but returns 0 (instead of nan)\r\n    for empty inputs.\r\n    \"\"\"\r\n    if target.numel() == 0 and reduction == \"mean\":\r\n        return input.sum() * 0.0  # connect the gradient\r\n    return LabelSmoothSoftmaxCEV1(reduction=reduction, **kwargs)(input, target)\r\n\r\n\r\nclass DecoupledFastRCNNOutputLayers(FastRCNNOutputLayers):\r\n    \"\"\"\r\n    Two linear layers for predicting Fast R-CNN outputs:\r\n\r\n    1. proposal-to-detection box regression deltas\r\n    2. classification scores\r\n\r\n    Replace cross-entropy with label-smoothing cross-entropy\r\n    \"\"\"\r\n\r\n    def losses(self, predictions, proposals):\r\n        \"\"\"\r\n        Args:\r\n            predictions: return values of :meth:`forward()`.\r\n            proposals (list[Instances]): proposals that match the features that were used\r\n                to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``,\r\n                ``gt_classes`` are expected.\r\n\r\n        Returns:\r\n            Dict[str, Tensor]: dict of losses\r\n        \"\"\"\r\n        scores, proposal_deltas = predictions\r\n\r\n        # parse classification outputs\r\n        gt_classes = (\r\n            cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0)\r\n        )\r\n        _log_classification_stats(scores, gt_classes)\r\n\r\n        # parse box regression outputs\r\n        if len(proposals):\r\n            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4\r\n            assert not proposal_boxes.requires_grad, \"Proposals should not require gradients!\"\r\n            # If \"gt_boxes\" does not exist, the proposals must be all negative and\r\n            # should not be included in regression loss computation.\r\n            # Here we just use proposal_boxes as an arbitrary placeholder because its\r\n            # value won't be used in self.box_reg_loss().\r\n            gt_boxes = cat(\r\n                [(p.gt_boxes if p.has(\"gt_boxes\") else p.proposal_boxes).tensor for p in proposals],\r\n                dim=0,\r\n            )\r\n        else:\r\n            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)\r\n\r\n        losses = {\r\n            \"loss_cls\": label_smoothing_cross_entropy(scores, gt_classes, reduction=\"mean\"),\r\n            \"loss_box_reg\": self.box_reg_loss(\r\n                proposal_boxes, gt_boxes, proposal_deltas, gt_classes\r\n            ),\r\n        }\r\n        return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()}\r\n"
  },
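  {
    "path": "examples/d_adapt_sketches/label_smoothing_empty_batch.py",
    "content": "\"\"\"\nIllustrative sketch (NOT part of the original repository; this file path is\nhypothetical). It demonstrates the empty-input behavior that motivates\nlabel_smoothing_cross_entropy(): a batch with no sampled proposals yields a\nzero loss that still participates in autograd, instead of NaN.\n\"\"\"\nimport torch\nfrom tllib.alignment.d_adapt.modeling.roi_heads.fast_rcnn import label_smoothing_cross_entropy\n\nnum_classes = 20\nscores = torch.randn(0, num_classes + 1, requires_grad=True)  # no proposals survived sampling\ntargets = torch.empty(0, dtype=torch.long)\n\nloss = label_smoothing_cross_entropy(scores, targets)  # tensor(0.), not nan\nloss.backward()  # gradients are zero but well-defined, so the training step stays safe\n\n# with a non-empty batch it reduces to tllib's LabelSmoothSoftmaxCEV1\nscores = torch.randn(4, num_classes + 1, requires_grad=True)\ntargets = torch.tensor([0, 3, 20, 20])  # index num_classes == 20 is the background class\nloss = label_smoothing_cross_entropy(scores, targets)\n"
  },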
  {
    "path": "tllib/alignment/d_adapt/modeling/roi_heads/roi_heads.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport torch\r\nimport numpy as np\r\nimport random\r\nfrom typing import List, Tuple, Dict\r\n\r\nfrom detectron2.structures import Boxes, Instances\r\nfrom detectron2.utils.events import get_event_storage\r\nfrom detectron2.layers import ShapeSpec, batched_nms\r\nfrom detectron2.modeling.roi_heads import (\r\n    ROI_HEADS_REGISTRY,\r\n    Res5ROIHeads,\r\n    StandardROIHeads\r\n)\r\nfrom detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference\r\nfrom detectron2.modeling.sampling import subsample_labels\r\nfrom detectron2.layers import nonzero_tuple\r\n\r\n\r\nfrom .fast_rcnn import DecoupledFastRCNNOutputLayers\r\n\r\n\r\n@ROI_HEADS_REGISTRY.register()\r\nclass DecoupledRes5ROIHeads(Res5ROIHeads):\r\n    \"\"\"\r\n    The ROIHeads in a typical \"C4\" R-CNN model, where\r\n    the box and mask head share the cropping and\r\n    the per-region feature computation by a Res5 block.\r\n\r\n    It typically contains logic to\r\n\r\n      1. when training on labeled source domain, match proposals with ground truth and sample them\r\n      2. when training on unlabeled target domain, match proposals with feedbacks from adaptors and sample them\r\n      3. crop the regions and extract per-region features using proposals\r\n      4. make per-region predictions with different heads\r\n    \"\"\"\r\n    def __init__(self, *args, **kwargs):\r\n        super(DecoupledRes5ROIHeads, self).__init__(*args, **kwargs)\r\n\r\n    @classmethod\r\n    def from_config(cls, cfg, input_shape):\r\n        # fmt: off\r\n        ret = super().from_config(cfg, input_shape)\r\n        ret[\"res5\"], out_channels = cls._build_res5_block(cfg)\r\n        box_predictor = DecoupledFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1))\r\n        ret[\"box_predictor\"] = box_predictor\r\n        return ret\r\n\r\n    def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True):\r\n        \"\"\"\r\n        Prepare some proposals to be used to train the ROI heads.\r\n        When training on labeled source domain, it performs box matching between `proposals` and `targets`, and assigns\r\n        training labels to the proposals.\r\n        When training on unlabeled target domain, it performs box matching between `proposals` and `feedbacks`, and assigns\r\n        training labels to the proposals.\r\n        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth\r\n        boxes, with a fraction of positives that is no larger than\r\n        ``self.positive_fraction``.\r\n\r\n        Args:\r\n            images (ImageList):\r\n            features (dict[str,Tensor]): input data as a mapping from feature\r\n                map name to tensor. Axis 0 represents the number of images `N` in\r\n                the input data; axes 1-3 are channels, height, and width, which may\r\n                vary between feature maps (e.g., if a feature pyramid is used).\r\n            proposals (list[Instances]): length `N` list of `Instances`. The i-th\r\n                `Instances` contains object proposals for the i-th input image,\r\n                with fields \"proposal_boxes\" and \"objectness_logits\".\r\n            targets (list[Instances], optional): length `N` list of `Instances`. The i-th\r\n                `Instances` contains the ground-truth per-instance annotations\r\n                for the i-th input image.  
Specify `targets` during training only.\r\n                It may have the following fields:\r\n\r\n                - gt_boxes: the bounding box of each instance.\r\n                - gt_classes: the label for each instance with a category ranging in [0, #class].\r\n                - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.\r\n                - gt_keypoints: NxKx3, the ground-truth keypoints for each instance.\r\n            feedbacks (list[Instances], optional): length `N` list of `Instances`. The i-th\r\n                `Instances` contains the per-instance feedback annotations\r\n                for the i-th input image.  Specify `feedbacks` during training only.\r\n                It has the same fields as `targets`.\r\n            labeled (bool, optional): whether the inputs have ground-truth labels\r\n\r\n        Returns:\r\n            tuple[list[Instances], list[Instances], dict]:\r\n                a tuple containing foreground proposals (`Instances`), background proposals (`Instances`) and a dict of different losses.\r\n\r\n            Each `Instances` has the following fields:\r\n\r\n                - proposal_boxes: the proposal boxes\r\n                - gt_boxes: the ground-truth box that the proposal is assigned to\r\n                  (this is only meaningful if the proposal has a label > 0; if label = 0\r\n                  then the ground-truth box is random)\r\n\r\n                Other fields such as \"gt_classes\" and \"gt_masks\" that are included in `targets`.\r\n        \"\"\"\r\n        del images\r\n\r\n        if self.training:\r\n            assert targets\r\n            if labeled:\r\n                proposals = self.label_and_sample_proposals(proposals, targets)\r\n            else:\r\n                proposals = self.label_and_sample_feedbacks(feedbacks)\r\n            del targets\r\n        proposal_boxes = [x.proposal_boxes for x in proposals]\r\n\r\n        box_features = self._shared_roi_transform(\r\n            [features[f] for f in self.in_features], proposal_boxes\r\n        )\r\n        predictions = self.box_predictor(box_features.mean(dim=[2, 3]))\r\n\r\n        if self.training:\r\n            del features\r\n            losses = self.box_predictor.losses(predictions, proposals)\r\n            if not labeled:\r\n                losses.pop(\"loss_box_reg\")\r\n            return [], [], losses\r\n        else:\r\n            pred_instances, _ = self.box_predictor.inference(predictions, proposals)\r\n            boxes = self.box_predictor.predict_boxes(predictions, proposals)\r\n            scores = self.box_predictor.predict_probs(predictions, proposals)\r\n            image_shapes = [x.image_size for x in proposals]\r\n            pred_instances, _ = fast_rcnn_inference(\r\n                boxes,\r\n                scores,\r\n                image_shapes,\r\n                self.box_predictor.test_score_thresh,\r\n                self.box_predictor.test_nms_thresh,\r\n                self.box_predictor.test_topk_per_image,\r\n            )\r\n            background_instances, _ = fast_rcnn_sample_background(\r\n                [box.tensor for box in proposal_boxes],\r\n                scores,\r\n                image_shapes,\r\n                self.box_predictor.test_score_thresh,\r\n                self.box_predictor.test_nms_thresh,\r\n                self.box_predictor.test_topk_per_image,\r\n            )\r\n            pred_instances = self.forward_with_given_boxes(features, pred_instances)\r\n            
background_instances = self.forward_with_given_boxes(features, background_instances)\r\n            return pred_instances, background_instances, {}\r\n\r\n    @torch.no_grad()\r\n    def label_and_sample_feedbacks(\r\n            self, feedbacks, batch_size_per_image=256\r\n    ) -> List[Instances]:\r\n        \"\"\"\r\n        Prepare some proposals to be used to train the ROI heads.\r\n        It performs box matching between `proposals` and `feedbacks`, and assigns\r\n        training labels to the proposals.\r\n        It returns ``batch_size_per_image`` random samples from proposals and groundtruth\r\n        boxes, with a fraction of positives that is no larger than\r\n        ``self.positive_fraction``.\r\n\r\n        Args:\r\n            feedbacks (list[Instances], optional): length `N` list of `Instances`. The i-th\r\n                `Instances` contains the per-instance feedback annotations\r\n                for the i-th input image.  Specify `feedbacks` during training only.\r\n                It has the same fields as `targets`.\r\n\r\n        Returns:\r\n            list[Instances]:\r\n                length `N` list of `Instances`s containing the proposals\r\n                sampled for training. Each `Instances` has the following fields:\r\n\r\n                - proposal_boxes: the proposal boxes\r\n                - gt_boxes: the ground-truth box that the proposal is assigned to\r\n                  (this is only meaningful if the proposal has a label > 0; if label = 0\r\n                  then the ground-truth box is random)\r\n\r\n                Other fields such as \"gt_classes\" and \"gt_masks\" that are included in `targets`.\r\n        \"\"\"\r\n\r\n        proposals_with_gt = []\r\n\r\n        num_fg_samples = []\r\n        num_bg_samples = []\r\n        for feedbacks_per_image in feedbacks:\r\n            gt_classes = feedbacks_per_image.gt_classes\r\n            positive = nonzero_tuple((gt_classes != -1) & (gt_classes != self.num_classes))[0]\r\n            # ensure each batch consists of the same number of bg and fg boxes\r\n            batch_size = min(batch_size_per_image, max(2 * positive.numel(), 1))\r\n            sampled_fg_idxs, sampled_bg_idxs = subsample_labels(\r\n                gt_classes, batch_size, self.positive_fraction, self.num_classes\r\n            )\r\n\r\n            sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)\r\n            gt_classes = gt_classes[sampled_idxs]\r\n\r\n            # Set target attributes of the sampled proposals:\r\n            proposals_per_image = feedbacks_per_image[sampled_idxs]\r\n            proposals_per_image.gt_classes = gt_classes\r\n\r\n            num_bg_samples.append((gt_classes == self.num_classes).sum().item())\r\n            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])\r\n            proposals_with_gt.append(proposals_per_image)\r\n\r\n        # Log the number of fg/bg samples that are selected for training ROI heads\r\n        storage = get_event_storage()\r\n        storage.put_scalar(\"roi_head_pseudo/num_fg_samples\", np.mean(num_fg_samples))\r\n        storage.put_scalar(\"roi_head_pseudo/num_bg_samples\", np.mean(num_bg_samples))\r\n\r\n        return proposals_with_gt\r\n\r\n\r\n@ROI_HEADS_REGISTRY.register()\r\nclass DecoupledStandardROIHeads(StandardROIHeads):\r\n    \"\"\"\r\n    The Standard ROIHeads used by most models, such as FPN and C5.\r\n    It's \"standard\" in the sense that there is no ROI transform sharing\r\n    or feature sharing between 
tasks.\r\n    Each head independently processes the input features by each head's\r\n    own pooler and head.\r\n\r\n    It typically contains logic to\r\n\r\n      1. when training on labeled source domain, match proposals with ground truth and sample them\r\n      2. when training on unlabeled target domain, match proposals with feedbacks from adaptors and sample them\r\n      3. crop the regions and extract per-region features using proposals\r\n      4. make per-region predictions with different heads\r\n    \"\"\"\r\n    def __init__(self, *args, **kwargs):\r\n        super(DecoupledStandardROIHeads, self).__init__(*args, **kwargs)\r\n\r\n    @classmethod\r\n    def from_config(cls, cfg, input_shape):\r\n        # fmt: off\r\n        ret = super().from_config(cfg, input_shape)\r\n        box_predictor = DecoupledFastRCNNOutputLayers(cfg, ret['box_head'].output_shape)\r\n        ret[\"box_predictor\"] = box_predictor\r\n        return ret\r\n\r\n    def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True):\r\n        \"\"\"\r\n        Prepare some proposals to be used to train the ROI heads.\r\n        When training on labeled source domain, it performs box matching between `proposals` and `targets`, and assigns\r\n        training labels to the proposals.\r\n        When training on unlabeled target domain, it performs box matching between `proposals` and `feedbacks`, and assigns\r\n        training labels to the proposals.\r\n        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth\r\n        boxes, with a fraction of positives that is no larger than\r\n        ``self.positive_fraction``.\r\n\r\n        Args:\r\n            images (ImageList):\r\n            features (dict[str,Tensor]): input data as a mapping from feature\r\n                map name to tensor. Axis 0 represents the number of images `N` in\r\n                the input data; axes 1-3 are channels, height, and width, which may\r\n                vary between feature maps (e.g., if a feature pyramid is used).\r\n            proposals (list[Instances]): length `N` list of `Instances`. The i-th\r\n                `Instances` contains object proposals for the i-th input image,\r\n                with fields \"proposal_boxes\" and \"objectness_logits\".\r\n            targets (list[Instances], optional): length `N` list of `Instances`. The i-th\r\n                `Instances` contains the ground-truth per-instance annotations\r\n                for the i-th input image.  Specify `targets` during training only.\r\n                It may have the following fields:\r\n\r\n                - gt_boxes: the bounding box of each instance.\r\n                - gt_classes: the label for each instance with a category ranging in [0, #class].\r\n                - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.\r\n                - gt_keypoints: NxKx3, the ground-truth keypoints for each instance.\r\n            feedbacks (list[Instances], optional): length `N` list of `Instances`. The i-th\r\n                `Instances` contains the per-instance feedback annotations\r\n                for the i-th input image.  
Specify `feedbacks` during training only.\r\n                It has the same fields as `targets`.\r\n            labeled (bool, optional): whether the inputs have ground-truth labels\r\n\r\n        Returns:\r\n            tuple[list[Instances], list[Instances], dict]:\r\n                a tuple containing foreground proposals (`Instances`), background proposals (`Instances`) and a dict of different losses.\r\n\r\n            Each `Instances` has the following fields:\r\n\r\n                - proposal_boxes: the proposal boxes\r\n                - gt_boxes: the ground-truth box that the proposal is assigned to\r\n                  (this is only meaningful if the proposal has a label > 0; if label = 0\r\n                  then the ground-truth box is random)\r\n\r\n                Other fields such as \"gt_classes\" and \"gt_masks\" that are included in `targets`.\r\n        \"\"\"\r\n        del images\r\n\r\n        if self.training:\r\n            assert targets\r\n            if labeled:\r\n                proposals = self.label_and_sample_proposals(proposals, targets)\r\n            else:\r\n                proposals = self.label_and_sample_feedbacks(feedbacks)\r\n            del targets\r\n\r\n        if self.training:\r\n            losses = self._forward_box(features, proposals)\r\n            # Usually the original proposals used by the box head are used by the mask, keypoint\r\n            # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes\r\n            # predicted by the box head.\r\n            losses.update(self._forward_mask(features, proposals))\r\n            losses.update(self._forward_keypoint(features, proposals))\r\n\r\n            if not labeled:\r\n                losses.pop('loss_box_reg')\r\n            return [], [], losses\r\n        else:\r\n            pred_instances, predictions = self._forward_box(features, proposals)\r\n            scores = self.box_predictor.predict_probs(predictions, proposals)\r\n            image_shapes = [x.image_size for x in proposals]\r\n            proposal_boxes = [x.proposal_boxes for x in proposals]\r\n            background_instances, _ = fast_rcnn_sample_background(\r\n                [box.tensor for box in proposal_boxes],\r\n                scores,\r\n                image_shapes,\r\n                self.box_predictor.test_score_thresh,\r\n                self.box_predictor.test_nms_thresh,\r\n                self.box_predictor.test_topk_per_image,\r\n            )\r\n            pred_instances = self.forward_with_given_boxes(features, pred_instances)\r\n            background_instances = self.forward_with_given_boxes(features, background_instances)\r\n            return pred_instances, background_instances, {}\r\n\r\n    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):\r\n        \"\"\"\r\n        Forward logic of the box prediction branch. 
If `self.train_on_pred_boxes is True`,\r\n            the function puts predicted boxes in the `proposal_boxes` field of the `proposals` argument.\r\n\r\n        Args:\r\n            features (dict[str, Tensor]): mapping from feature map names to tensor.\r\n                Same as in :meth:`ROIHeads.forward`.\r\n            proposals (list[Instances]): the per-image object proposals with\r\n                their matching ground truth.\r\n                Each has fields \"proposal_boxes\", and \"objectness_logits\",\r\n                \"gt_classes\", \"gt_boxes\".\r\n\r\n        Returns:\r\n            In training, a dict of losses.\r\n            In inference, a tuple of (predicted instances, raw box predictor outputs).\r\n        \"\"\"\r\n        features = [features[f] for f in self.box_in_features]\r\n        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])\r\n        box_features = self.box_head(box_features)\r\n        predictions = self.box_predictor(box_features)\r\n        del box_features\r\n\r\n        if self.training:\r\n            losses = self.box_predictor.losses(predictions, proposals)\r\n            # proposals is modified in-place below, so losses must be computed first.\r\n            if self.train_on_pred_boxes:\r\n                with torch.no_grad():\r\n                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(\r\n                        predictions, proposals\r\n                    )\r\n                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):\r\n                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)\r\n            return losses\r\n        else:\r\n            pred_instances, _ = self.box_predictor.inference(predictions, proposals)\r\n            return pred_instances, predictions\r\n\r\n    @torch.no_grad()\r\n    def label_and_sample_feedbacks(\r\n            self, feedbacks, batch_size_per_image=256\r\n    ) -> List[Instances]:\r\n        \"\"\"\r\n        Prepare some proposals to be used to train the ROI heads.\r\n        It performs box matching between `proposals` and `feedbacks`, and assigns\r\n        training labels to the proposals.\r\n        It returns ``batch_size_per_image`` random samples from proposals and groundtruth\r\n        boxes, with a fraction of positives that is no larger than\r\n        ``self.positive_fraction``.\r\n\r\n        Args:\r\n            See :meth:`ROIHeads.forward`\r\n\r\n        Returns:\r\n            list[Instances]:\r\n                length `N` list of `Instances`s containing the proposals\r\n                sampled for training. 
Each `Instances` has the following fields:\r\n\r\n                - proposal_boxes: the proposal boxes\r\n                - gt_boxes: the ground-truth box that the proposal is assigned to\r\n                  (this is only meaningful if the proposal has a label > 0; if label = 0\r\n                  then the ground-truth box is random)\r\n\r\n                Other fields such as \"gt_classes\" and \"gt_masks\" that are included in `targets`.\r\n        \"\"\"\r\n\r\n        proposals_with_gt = []\r\n\r\n        num_fg_samples = []\r\n        num_bg_samples = []\r\n        for feedbacks_per_image in feedbacks:\r\n            gt_classes = feedbacks_per_image.gt_classes\r\n            positive = nonzero_tuple((gt_classes != -1) & (gt_classes != self.num_classes))[0]\r\n            # ensure each batch consists of the same number of bg and fg boxes\r\n            batch_size = min(batch_size_per_image, max(2 * positive.numel(), 1))\r\n            sampled_fg_idxs, sampled_bg_idxs = subsample_labels(\r\n                gt_classes, batch_size, self.positive_fraction, self.num_classes\r\n            )\r\n\r\n            sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)\r\n            gt_classes = gt_classes[sampled_idxs]\r\n\r\n            # Set target attributes of the sampled proposals:\r\n            proposals_per_image = feedbacks_per_image[sampled_idxs]\r\n            proposals_per_image.gt_classes = gt_classes\r\n\r\n            num_bg_samples.append((gt_classes == self.num_classes).sum().item())\r\n            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])\r\n            proposals_with_gt.append(proposals_per_image)\r\n\r\n        # Log the number of fg/bg samples that are selected for training ROI heads\r\n        storage = get_event_storage()\r\n        storage.put_scalar(\"roi_head_pseudo/num_fg_samples\", np.mean(num_fg_samples))\r\n        storage.put_scalar(\"roi_head_pseudo/num_bg_samples\", np.mean(num_bg_samples))\r\n\r\n        return proposals_with_gt\r\n\r\n\r\ndef fast_rcnn_sample_background(\r\n    boxes: List[torch.Tensor],\r\n    scores: List[torch.Tensor],\r\n    image_shapes: List[Tuple[int, int]],\r\n    score_thresh: float,\r\n    nms_thresh: float,\r\n    topk_per_image: int,\r\n):\r\n    \"\"\"\r\n    Call `fast_rcnn_sample_background_single_image` for all images.\r\n\r\n    Args:\r\n        boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic\r\n            boxes for each image. Element i has shape (Ri, K * 4) if doing\r\n            class-specific regression, or (Ri, 4) if doing class-agnostic\r\n            regression, where Ri is the number of predicted objects for image i.\r\n            This is compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`.\r\n        scores (list[Tensor]): A list of Tensors of predicted class scores for each image.\r\n            Element i has shape (Ri, K + 1), where Ri is the number of predicted objects\r\n            for image i. Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`.\r\n        image_shapes (list[tuple]): A list of (width, height) tuples for each image in the batch.\r\n        score_thresh (float): Only return detections with a confidence score exceeding this\r\n            threshold.\r\n        nms_thresh (float):  The threshold to use for box non-maximum suppression. Value in [0, 1].\r\n        topk_per_image (int): The maximum number of background proposals to keep per image. 
Set < 0 to return\r\n            all of them.\r\n\r\n    Returns:\r\n        instances: (list[Instances]): A list of N instances, one for each image in the batch,\r\n            that stores the background proposals.\r\n        kept_indices: (list[Tensor]): A list of N 1D tensors; each element indicates\r\n            the corresponding boxes/scores index in [0, Ri) from the input, for image i.\r\n    \"\"\"\r\n    result_per_image = [\r\n        fast_rcnn_sample_background_single_image(\r\n            boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image\r\n        )\r\n        for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes)\r\n    ]\r\n    return [x[0] for x in result_per_image], [x[1] for x in result_per_image]\r\n\r\n\r\ndef fast_rcnn_sample_background_single_image(\r\n    boxes,\r\n    scores,\r\n    image_shape: Tuple[int, int],\r\n    score_thresh: float,\r\n    nms_thresh: float,\r\n    topk_per_image: int,\r\n):\r\n    \"\"\"\r\n    Single-image background sampling.\r\n\r\n    Args:\r\n        Same as `fast_rcnn_sample_background`, but with boxes, scores, and image shapes\r\n        per image.\r\n\r\n    Returns:\r\n        Same as `fast_rcnn_sample_background`, but for only one image.\r\n    \"\"\"\r\n    valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)\r\n    if not valid_mask.all():\r\n        boxes = boxes[valid_mask]\r\n        scores = scores[valid_mask]\r\n\r\n    num_classes = scores.shape[1]\r\n    # Only keep background proposals\r\n    scores = scores[:, -1:]\r\n    # Convert to Boxes to use the `clip` function ...\r\n    boxes = Boxes(boxes.reshape(-1, 4))\r\n    boxes.clip(image_shape)\r\n    boxes = boxes.tensor.view(-1, 1, 4)  # R x 1 x 4 (only the background column is kept)\r\n\r\n    # 1. Filter results based on detection scores. It can make NMS more efficient\r\n    #    by filtering out low-confidence detections.\r\n    filter_mask = scores > score_thresh  # R x 1\r\n    # R' x 2. First column contains indices of the R predictions;\r\n    # Second column contains indices of classes.\r\n    filter_inds = filter_mask.nonzero()\r\n    boxes = boxes[filter_mask]\r\n    scores = scores[filter_mask]\r\n\r\n    # 2. Apply NMS only for background class\r\n    keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)\r\n    if 0 <= topk_per_image < len(keep):\r\n        idx = list(range(len(keep)))\r\n        idx = random.sample(idx, k=topk_per_image)\r\n        idx = sorted(idx)\r\n        keep = keep[idx]\r\n    boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]\r\n\r\n    result = Instances(image_shape)\r\n    result.pred_boxes = Boxes(boxes)\r\n    result.scores = scores\r\n    result.pred_classes = filter_inds[:, 1] + num_classes - 1\r\n    return result, filter_inds[:, 0]\r\n"
  },
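  {
    "path": "examples/d_adapt_sketches/sample_background_toy.py",
    "content": "\"\"\"\nIllustrative sketch (NOT part of the original repository; this file path is\nhypothetical). It calls fast_rcnn_sample_background_single_image() on toy\ntensors: background proposals are those whose background probability (the last\nscore column) exceeds score_thresh, and at most topk_per_image of them survive\nthe random sampling applied after NMS.\n\"\"\"\nimport torch\nfrom tllib.alignment.d_adapt.modeling.roi_heads.roi_heads import (\n    fast_rcnn_sample_background_single_image,\n)\n\nnum_proposals, num_fg_classes = 8, 20\nboxes = torch.rand(num_proposals, 4) * 100.\nboxes[:, 2:] += boxes[:, :2]  # make valid XYXY boxes\nscores = torch.softmax(torch.randn(num_proposals, num_fg_classes + 1), dim=1)\n\nresult, kept = fast_rcnn_sample_background_single_image(\n    boxes, scores, image_shape=(128, 128),\n    score_thresh=0.05, nms_thresh=0.5, topk_per_image=4,\n)\nprint(result.pred_boxes, result.scores)  # sampled background boxes and their scores\n"
  },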
  {
    "path": "tllib/alignment/d_adapt/proposal.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport torch\r\nimport copy\r\nimport numpy as np\r\nimport os\r\nimport json\r\nfrom typing import Optional, Callable, List\r\nimport random\r\nimport pprint\r\n\r\nimport torchvision.datasets as datasets\r\nfrom torchvision.datasets.folder import default_loader\r\nfrom torchvision.transforms.functional import crop\r\nfrom detectron2.structures import pairwise_iou\r\nfrom detectron2.evaluation.evaluator import DatasetEvaluator\r\nfrom detectron2.data.dataset_mapper import DatasetMapper\r\nimport detectron2.data.detection_utils as utils\r\nimport detectron2.data.transforms as T\r\n\r\n\r\nclass ProposalMapper(DatasetMapper):\r\n    \"\"\"\r\n    A callable which takes a dataset dict in Detectron2 Dataset format,\r\n    and map it into a format used by the model.\r\n\r\n    This is the default callable to be used to map your dataset dict into training data.\r\n    You may need to follow it to implement your own one for customized logic,\r\n    such as a different way to read or transform images.\r\n    See :doc:`/tutorials/data_loading` for details.\r\n\r\n    The callable currently does the following:\r\n\r\n    1. Read the image from \"file_name\"\r\n    2. Prepare data and annotations to Tensor and :class:`Instances`\r\n    \"\"\"\r\n\r\n    def __call__(self, dataset_dict):\r\n        \"\"\"\r\n        Args:\r\n            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.\r\n\r\n        Returns:\r\n            dict: a format that builtin models in detectron2 accept\r\n        \"\"\"\r\n        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below\r\n        # USER: Write your own image loading if it's not from a file\r\n        image = utils.read_image(dataset_dict[\"file_name\"], format=self.image_format)\r\n        utils.check_image_size(dataset_dict, image)\r\n        origin_image_shape = image.shape[:2]  # h, w\r\n\r\n        aug_input = T.AugInput(image)\r\n        image = aug_input.image\r\n\r\n        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,\r\n        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.\r\n        # Therefore it's important to use torch.Tensor.\r\n        dataset_dict[\"image\"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))\r\n\r\n        if \"annotations\" in dataset_dict:\r\n            # USER: Modify this if you want to keep them for some reason.\r\n            for anno in dataset_dict[\"annotations\"]:\r\n                if not self.use_instance_mask:\r\n                    anno.pop(\"segmentation\", None)\r\n                if not self.use_keypoint:\r\n                    anno.pop(\"keypoints\", None)\r\n\r\n            # USER: Implement additional transformations if you have other types of data\r\n            annos = [\r\n                obj\r\n                for obj in dataset_dict.pop(\"annotations\")\r\n                if obj.get(\"iscrowd\", 0) == 0\r\n            ]\r\n            instances = utils.annotations_to_instances(\r\n                annos, origin_image_shape, mask_format=self.instance_mask_format\r\n            )\r\n\r\n            # After transforms such as cropping are applied, the bounding box may no longer\r\n            # tightly bound the object. As an example, imagine a triangle object\r\n            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). 
The tight\r\n            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to\r\n            # the intersection of original bounding box and the cropping box.\r\n            if self.recompute_boxes:\r\n                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()\r\n            dataset_dict[\"instances\"] = utils.filter_empty_instances(instances)\r\n        return dataset_dict\r\n\r\n\r\nclass ProposalGenerator(DatasetEvaluator):\r\n    \"\"\"\r\n    The function :func:`inference_on_dataset` runs the model over\r\n    all samples in the dataset, and uses a ProposalGenerator to generate proposals for each input/output pair.\r\n\r\n    This class will accumulate information of the inputs/outputs (by :meth:`process`),\r\n    and generate proposal results in the end (by :meth:`evaluate`).\r\n    \"\"\"\r\n    def __init__(self, iou_threshold=(0.4, 0.5), num_classes=20, *args, **kwargs):\r\n        super(ProposalGenerator, self).__init__(*args, **kwargs)\r\n        self.fg_proposal_list = []\r\n        self.bg_proposal_list = []\r\n        self.iou_threshold = iou_threshold\r\n        self.num_classes = num_classes\r\n\r\n    def process_type(self, inputs, outputs, type='instances'):\r\n        cpu_device = torch.device('cpu')\r\n        input_instance = inputs[0]['instances'].to(cpu_device)\r\n        output_instance = outputs[0][type].to(cpu_device)\r\n        filename = inputs[0]['file_name']\r\n        pred_boxes = output_instance.pred_boxes\r\n        pred_scores = output_instance.scores\r\n        pred_classes = output_instance.pred_classes\r\n        proposal = Proposal(\r\n            image_id=inputs[0]['image_id'],\r\n            filename=filename,\r\n            pred_boxes=pred_boxes.tensor.numpy(),\r\n            pred_classes=pred_classes.numpy(),\r\n            pred_scores=pred_scores.numpy(),\r\n        )\r\n\r\n        if hasattr(input_instance, 'gt_boxes'):\r\n            gt_boxes = input_instance.gt_boxes\r\n            # assign a gt label for each pred_box\r\n            if pred_boxes.tensor.shape[0] == 0:\r\n                proposal.gt_fg_classes = proposal.gt_classes = proposal.gt_ious = proposal.gt_boxes = np.array([])\r\n            elif gt_boxes.tensor.shape[0] == 0:\r\n                proposal.gt_fg_classes = proposal.gt_classes = np.array([self.num_classes for _ in range(pred_boxes.tensor.shape[0])])\r\n                proposal.gt_ious = np.array([0. 
for _ in range(pred_boxes.tensor.shape[0])])\r\n                proposal.gt_boxes = np.array([[0, 0, 0, 0] for _ in range(pred_boxes.tensor.shape[0])])\r\n            else:\r\n                gt_ious, gt_classes_idx = pairwise_iou(pred_boxes, gt_boxes).max(dim=1)\r\n                gt_classes = input_instance.gt_classes[gt_classes_idx]\r\n                proposal.gt_fg_classes = copy.deepcopy(gt_classes.numpy())\r\n                gt_classes[gt_ious <= self.iou_threshold[0]] = self.num_classes  # background classes\r\n                gt_classes[(self.iou_threshold[0] < gt_ious) & (gt_ious <= self.iou_threshold[1])] = -1  # ignore\r\n                proposal.gt_classes = gt_classes.numpy()\r\n                proposal.gt_ious = gt_ious.numpy()\r\n                proposal.gt_boxes = input_instance.gt_boxes[gt_classes_idx].tensor.numpy()\r\n\r\n        return proposal\r\n\r\n    def process(self, inputs, outputs):\r\n        self.fg_proposal_list.append(self.process_type(inputs, outputs, \"instances\"))\r\n        self.bg_proposal_list.append(self.process_type(inputs, outputs, \"background\"))\r\n\r\n    def evaluate(self):\r\n        return self.fg_proposal_list, self.bg_proposal_list\r\n\r\n\r\nclass Proposal:\r\n    \"\"\"\r\n    A data structure that stores the proposals for a single image.\r\n\r\n    Args:\r\n        image_id (str): unique image identifier\r\n        filename (str): image filename\r\n        pred_boxes (numpy.ndarray): predicted boxes\r\n        pred_classes (numpy.ndarray): predicted classes\r\n        pred_scores (numpy.ndarray): class confidence score\r\n        gt_classes (numpy.ndarray, optional): ground-truth classes, including background classes\r\n        gt_boxes (numpy.ndarray, optional): ground-truth boxes\r\n        gt_ious (numpy.ndarray, optional): IoU between predicted boxes and ground-truth boxes\r\n        gt_fg_classes (numpy.ndarray, optional): ground-truth foreground classes, not including background classes\r\n\r\n    \"\"\"\r\n    def __init__(self, image_id, filename, pred_boxes, pred_classes, pred_scores,\r\n                 gt_classes=None, gt_boxes=None, gt_ious=None, gt_fg_classes=None):\r\n        self.image_id = image_id\r\n        self.filename = filename\r\n        self.pred_boxes = pred_boxes\r\n        self.pred_classes = pred_classes\r\n        self.pred_scores = pred_scores\r\n        self.gt_classes = gt_classes\r\n        self.gt_boxes = gt_boxes\r\n        self.gt_ious = gt_ious\r\n        self.gt_fg_classes = gt_fg_classes\r\n\r\n    def to_dict(self):\r\n        return {\r\n            \"__proposal__\": True,\r\n            \"image_id\": self.image_id,\r\n            \"filename\": self.filename,\r\n            \"pred_boxes\": self.pred_boxes.tolist(),\r\n            \"pred_classes\": self.pred_classes.tolist(),\r\n            \"pred_scores\": self.pred_scores.tolist(),\r\n            \"gt_classes\": self.gt_classes.tolist(),\r\n            \"gt_boxes\": self.gt_boxes.tolist(),\r\n            \"gt_ious\": self.gt_ious.tolist(),\r\n            \"gt_fg_classes\": self.gt_fg_classes.tolist()\r\n        }\r\n\r\n    def __str__(self):\r\n        pp = pprint.PrettyPrinter(indent=2)\r\n        return pp.pformat(self.to_dict())\r\n\r\n    def __len__(self):\r\n        return len(self.pred_boxes)\r\n\r\n    def __getitem__(self, item):\r\n        return Proposal(\r\n            image_id=self.image_id,\r\n            filename=self.filename,\r\n            pred_boxes=self.pred_boxes[item],\r\n            
pred_classes=self.pred_classes[item],\r\n            pred_scores=self.pred_scores[item],\r\n            gt_classes=self.gt_classes[item],\r\n            gt_boxes=self.gt_boxes[item],\r\n            gt_ious=self.gt_ious[item],\r\n            gt_fg_classes=self.gt_fg_classes[item]\r\n        )\r\n\r\n\r\nclass ProposalEncoder(json.JSONEncoder):\r\n    def default(self, obj):\r\n        if isinstance(obj, Proposal):\r\n            return obj.to_dict()\r\n        return json.JSONEncoder.default(self, obj)\r\n\r\n\r\ndef asProposal(dict):\r\n    if '__proposal__' in dict:\r\n        return Proposal(\r\n            dict[\"image_id\"],\r\n            dict[\"filename\"],\r\n            np.array(dict[\"pred_boxes\"]),\r\n            np.array(dict[\"pred_classes\"]),\r\n            np.array(dict[\"pred_scores\"]),\r\n            np.array(dict[\"gt_classes\"]),\r\n            np.array(dict[\"gt_boxes\"]),\r\n            np.array(dict[\"gt_ious\"]),\r\n            np.array(dict[\"gt_fg_classes\"])\r\n        )\r\n    return dict\r\n\r\n\r\nclass PersistentProposalList(list):\r\n    \"\"\"\r\n    A data structure that stores the proposals for a dataset.\r\n\r\n    Args:\r\n        filename (str, optional): filename indicating where to cache\r\n    \"\"\"\r\n    def __init__(self, filename=None):\r\n        super(PersistentProposalList, self).__init__()\r\n        self.filename = filename\r\n\r\n    def load(self):\r\n        \"\"\"\r\n        Load from cache.\r\n\r\n        Returns:\r\n            whether the load succeeded\r\n        \"\"\"\r\n        if os.path.exists(self.filename):\r\n            print(\"Reading from cache: {}\".format(self.filename))\r\n            with open(self.filename, \"r\") as f:\r\n                self.extend(json.load(f, object_hook=asProposal))\r\n            return True\r\n        else:\r\n            return False\r\n\r\n    def flush(self):\r\n        \"\"\"\r\n        Flush to cache.\r\n        \"\"\"\r\n        os.makedirs(os.path.dirname(self.filename), exist_ok=True)\r\n        with open(self.filename, \"w\") as f:\r\n            json.dump(self, f, cls=ProposalEncoder)\r\n        print(\"Write to cache: {}\".format(self.filename))\r\n\r\n\r\ndef flatten(proposal_list, max_number=10000):\r\n    \"\"\"\r\n    Flatten a list of proposals.\r\n\r\n    Args:\r\n        proposal_list (list):  a list of proposals grouped by images\r\n        max_number (int): maximum number of kept proposals for each image\r\n\r\n    \"\"\"\r\n    flattened_list = []\r\n    for proposals in proposal_list:\r\n        for i in range(min(len(proposals), max_number)):\r\n            flattened_list.append(proposals[i:i+1])\r\n    return flattened_list\r\n\r\n\r\nclass ProposalDataset(datasets.VisionDataset):\r\n    \"\"\"\r\n    A dataset for proposals.\r\n\r\n    Args:\r\n        proposal_list (list): list of Proposal\r\n        transform (callable, optional): A function/transform that takes in a PIL image\r\n            and returns a transformed version. 
E.g., ``transforms.RandomCrop``\r\n        crop_func (ExpandCrop, optional): a callable that adjusts the crop region (top, left, height, width)\r\n            before cropping, e.g. :class:`ExpandCrop`. Default: None\r\n    \"\"\"\r\n    def __init__(self, proposal_list: List[Proposal], transform: Optional[Callable] = None, crop_func=None):\r\n        super(ProposalDataset, self).__init__(\"\", transform=transform)\r\n        self.proposal_list = list(filter(lambda p: len(p) > 0, proposal_list))  # remove images without proposals\r\n        self.loader = default_loader\r\n        self.crop_func = crop_func\r\n\r\n    def __getitem__(self, index: int):\r\n        # get proposals for the index-th image\r\n        proposals = self.proposal_list[index]\r\n        img = self.loader(proposals.filename)\r\n\r\n        # randomly sample a proposal\r\n        proposal = proposals[random.randint(0, len(proposals)-1)]\r\n        image_width, image_height = img.width, img.height\r\n\r\n        # crop the proposal from the whole image\r\n        x1, y1, x2, y2 = proposal.pred_boxes\r\n        top, left, height, width = int(y1), int(x1), int(y2 - y1), int(x2 - x1)\r\n        if self.crop_func is not None:\r\n            top, left, height, width = self.crop_func(img, top, left, height, width)\r\n        img = crop(img, top, left, height, width)\r\n\r\n        if self.transform is not None:\r\n            img = self.transform(img)\r\n\r\n        # cast to concrete dtypes (np.float / np.long were removed from recent NumPy versions)\r\n        return img, {\r\n            \"image_id\": proposal.image_id,\r\n            \"filename\": proposal.filename,\r\n            \"pred_boxes\": proposal.pred_boxes.astype(np.float64),\r\n            \"pred_classes\": proposal.pred_classes.astype(np.int64),\r\n            \"pred_scores\": proposal.pred_scores.astype(np.float64),\r\n            \"gt_classes\": proposal.gt_classes.astype(np.int64),\r\n            \"gt_boxes\": proposal.gt_boxes.astype(np.float64),\r\n            \"gt_ious\": proposal.gt_ious.astype(np.float64),\r\n            \"gt_fg_classes\": proposal.gt_fg_classes.astype(np.int64),\r\n            \"width\": image_width,\r\n            \"height\": image_height\r\n        }\r\n\r\n    def __len__(self):\r\n        return len(self.proposal_list)\r\n\r\n\r\nclass ExpandCrop:\r\n    \"\"\"\r\n    Expand the crop fed to the bounding box adapter so that it is larger than the original\r\n    predicted box and the adapter can access more location information.\r\n    \"\"\"\r\n    def __init__(self, expand=1.):\r\n        self.expand = expand\r\n\r\n    def __call__(self, img, top, left, height, width):\r\n        # expand the box around its center by the given factor\r\n        cx = left + width / 2.\r\n        cy = top + height / 2.\r\n        height = round(height * self.expand)\r\n        width = round(width * self.expand)\r\n        new_top = round(cy - height / 2.)\r\n        new_left = round(cx - width / 2.)\r\n        return new_top, new_left, height, width"
  },
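A minimal usage sketch for the proposal utilities above, in the repo's doctest style. The cache path, the transform, and the expand factor are illustrative assumptions rather than values taken from the file::

    >>> import torchvision.transforms as T
    >>> proposal_list = PersistentProposalList("cache/proposals_fg.json")  # hypothetical path
    >>> if not proposal_list.load():
    ...     pass  # append Proposal objects here, then call proposal_list.flush()
    >>> # one single-proposal item per box, keeping at most 100 boxes per image
    >>> flattened = flatten(proposal_list, max_number=100)
    >>> dataset = ProposalDataset(flattened,
    ...                           transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    ...                           crop_func=ExpandCrop(expand=1.5))
    >>> # each item is a cropped proposal image plus a dict of pred_* / gt_* fields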
  {
    "path": "tllib/alignment/dan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, Sequence\nimport torch\nimport torch.nn as nn\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\n\n\n__all__ = ['MultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']\n\n\nclass MultipleKernelMaximumMeanDiscrepancy(nn.Module):\n    r\"\"\"The Multiple Kernel Maximum Mean Discrepancy (MK-MMD) used in\n    `Learning Transferable Features with Deep Adaptation Networks (ICML 2015) <https://arxiv.org/pdf/1502.02791>`_\n\n    Given source domain :math:`\\mathcal{D}_s` of :math:`n_s` labeled points and target domain :math:`\\mathcal{D}_t`\n    of :math:`n_t` unlabeled points drawn i.i.d. from P and Q respectively, the deep networks will generate\n    activations as :math:`\\{z_i^s\\}_{i=1}^{n_s}` and :math:`\\{z_i^t\\}_{i=1}^{n_t}`.\n    The MK-MMD :math:`D_k (P, Q)` between probability distributions P and Q is defined as\n\n    .. math::\n        D_k(P, Q) \\triangleq \\| E_p [\\phi(z^s)] - E_q [\\phi(z^t)] \\|^2_{\\mathcal{H}_k},\n\n    :math:`k` is a kernel function in the function space\n\n    .. math::\n        \\mathcal{K} \\triangleq \\{ k=\\sum_{u=1}^{m}\\beta_{u} k_{u} \\}\n\n    where :math:`k_{u}` is a single kernel.\n\n    Using kernel trick, MK-MMD can be computed as\n\n    .. math::\n        \\hat{D}_k(P, Q) &=\n        \\dfrac{1}{n_s^2} \\sum_{i=1}^{n_s}\\sum_{j=1}^{n_s} k(z_i^{s}, z_j^{s})\\\\\n        &+ \\dfrac{1}{n_t^2} \\sum_{i=1}^{n_t}\\sum_{j=1}^{n_t} k(z_i^{t}, z_j^{t})\\\\\n        &- \\dfrac{2}{n_s n_t} \\sum_{i=1}^{n_s}\\sum_{j=1}^{n_t} k(z_i^{s}, z_j^{t}).\\\\\n\n    Args:\n        kernels (tuple(torch.nn.Module)): kernel functions.\n        linear (bool): whether use the linear version of DAN. Default: False\n\n    Inputs:\n        - z_s (tensor): activations from the source domain, :math:`z^s`\n        - z_t (tensor): activations from the target domain, :math:`z^t`\n\n    Shape:\n        - Inputs: :math:`(minibatch, *)`  where * means any dimension\n        - Outputs: scalar\n\n    .. note::\n        Activations :math:`z^{s}` and :math:`z^{t}` must have the same shape.\n\n    .. 
note::\n        The kernel values will add up when there are multiple kernels.\n\n    Examples::\n\n        >>> from tllib.modules.kernels import GaussianKernel\n        >>> feature_dim = 1024\n        >>> batch_size = 10\n        >>> kernels = (GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.), GaussianKernel(alpha=2.))\n        >>> loss = MultipleKernelMaximumMeanDiscrepancy(kernels)\n        >>> # features from source domain and target domain\n        >>> z_s, z_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)\n        >>> output = loss(z_s, z_t)\n    \"\"\"\n\n    def __init__(self, kernels: Sequence[nn.Module], linear: Optional[bool] = False):\n        super(MultipleKernelMaximumMeanDiscrepancy, self).__init__()\n        self.kernels = kernels\n        self.index_matrix = None\n        self.linear = linear\n\n    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:\n        features = torch.cat([z_s, z_t], dim=0)\n        batch_size = int(z_s.size(0))\n        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s.device)\n\n        kernel_matrix = sum([kernel(features) for kernel in self.kernels])  # Add up the matrix of each kernel\n        # Add 2 / (n-1) to make up for the value on the diagonal\n        # to ensure loss is positive in the non-linear version\n        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)\n\n        return loss\n\n\ndef _update_index_matrix(batch_size: int, index_matrix: Optional[torch.Tensor] = None,\n                         linear: Optional[bool] = True) -> torch.Tensor:\n    r\"\"\"\n    Update the `index_matrix`, which converts the `kernel_matrix` into the loss.\n    If `index_matrix` is a tensor with shape (2 x batch_size, 2 x batch_size), then return `index_matrix`.\n    Else return a new tensor with shape (2 x batch_size, 2 x batch_size).\n    \"\"\"\n    if index_matrix is None or index_matrix.size(0) != batch_size * 2:\n        index_matrix = torch.zeros(2 * batch_size, 2 * batch_size)\n        if linear:\n            # linear-time estimate: each sample is paired with its cyclic neighbour\n            for i in range(batch_size):\n                s1, s2 = i, (i + 1) % batch_size\n                t1, t2 = s1 + batch_size, s2 + batch_size\n                index_matrix[s1, s2] = 1. / float(batch_size)\n                index_matrix[t1, t2] = 1. / float(batch_size)\n                index_matrix[s1, t2] = -1. / float(batch_size)\n                index_matrix[s2, t1] = -1. / float(batch_size)\n        else:\n            # full quadratic estimate; the diagonal terms are excluded\n            for i in range(batch_size):\n                for j in range(batch_size):\n                    if i != j:\n                        index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1))\n                        index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1))\n            for i in range(batch_size):\n                for j in range(batch_size):\n                    index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size)\n                    index_matrix[i + batch_size][j] = -1. 
/ float(batch_size * batch_size)\n    return index_matrix\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(0.5)\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)"
  },
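A minimal training-step sketch for MK-MMD with the `ImageClassifier` above, in the repo's doctest style. `FakeBackbone`, the class count, and the kernel family are illustrative assumptions, not part of dan.py; any module that returns a 4-d feature map and exposes an `out_features` attribute can serve as the backbone::

    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> from tllib.modules.kernels import GaussianKernel
    >>> class FakeBackbone(nn.Module):  # hypothetical stand-in for a real CNN backbone
    ...     out_features = 64
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.conv = nn.Conv2d(3, 64, 3, padding=1)
    ...     def forward(self, x):
    ...         return self.conv(x)
    >>> classifier = ImageClassifier(FakeBackbone(), num_classes=31, bottleneck_dim=256).train()
    >>> mkmmd = MultipleKernelMaximumMeanDiscrepancy(
    ...     kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)], linear=False)
    >>> x_s, x_t = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
    >>> labels_s = torch.randint(31, (4,))
    >>> y_s, f_s = classifier(x_s)   # training mode returns (predictions, features)
    >>> y_t, f_t = classifier(x_t)
    >>> loss = F.cross_entropy(y_s, labels_s) + mkmmd(f_s, f_t)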
  {
    "path": "tllib/alignment/dann.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tllib.modules.grl import WarmStartGradientReverseLayer\nfrom tllib.modules.classifier import Classifier as ClassifierBase\nfrom tllib.utils.metric import binary_accuracy, accuracy\n\n__all__ = ['DomainAdversarialLoss']\n\n\nclass DomainAdversarialLoss(nn.Module):\n    r\"\"\"\n    The Domain Adversarial Loss proposed in\n    `Domain-Adversarial Training of Neural Networks (ICML 2015) <https://arxiv.org/abs/1505.07818>`_\n\n    Domain adversarial loss measures the domain discrepancy through training a domain discriminator.\n    Given domain discriminator :math:`D`, feature representation :math:`f`, the definition of DANN loss is\n\n    .. math::\n        loss(\\mathcal{D}_s, \\mathcal{D}_t) = \\mathbb{E}_{x_i^s \\sim \\mathcal{D}_s} \\text{log}[D(f_i^s)]\n            + \\mathbb{E}_{x_j^t \\sim \\mathcal{D}_t} \\text{log}[1-D(f_j^t)].\n\n    Args:\n        domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of features. Its input shape is (N, F) and output shape is (N, 1)\n        reduction (str, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n            ``'mean'``: the sum of the output will be divided by the number of\n            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n        grl (WarmStartGradientReverseLayer, optional): Default: None.\n\n    Inputs:\n        - f_s (tensor): feature representations on source domain, :math:`f^s`\n        - f_t (tensor): feature representations on target domain, :math:`f^t`\n        - w_s (tensor, optional): a rescaling weight given to each instance from source domain.\n        - w_t (tensor, optional): a rescaling weight given to each instance from target domain.\n\n    Shape:\n        - f_s, f_t: :math:`(N, F)` where F means the dimension of input features.\n        - Outputs: scalar by default. 
If :attr:`reduction` is ``'none'``, then :math:`(N, )`.\n\n    Examples::\n\n        >>> from tllib.modules.domain_discriminator import DomainDiscriminator\n        >>> discriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)\n        >>> loss = DomainAdversarialLoss(discriminator, reduction='mean')\n        >>> # features from source domain and target domain\n        >>> f_s, f_t = torch.randn(20, 1024), torch.randn(20, 1024)\n        >>> # If you want to assign different weights to each instance, you should pass in w_s and w_t\n        >>> w_s, w_t = torch.randn(20), torch.randn(20)\n        >>> output = loss(f_s, f_t, w_s, w_t)\n    \"\"\"\n\n    def __init__(self, domain_discriminator: nn.Module, reduction: Optional[str] = 'mean',\n                 grl: Optional = None, sigmoid=True):\n        super(DomainAdversarialLoss, self).__init__()\n        self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True) if grl is None else grl\n        self.domain_discriminator = domain_discriminator\n        self.sigmoid = sigmoid\n        self.reduction = reduction\n        self.bce = lambda input, target, weight: \\\n            F.binary_cross_entropy(input, target, weight=weight, reduction=reduction)\n        self.domain_discriminator_accuracy = None\n\n    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor,\n                w_s: Optional[torch.Tensor] = None, w_t: Optional[torch.Tensor] = None) -> torch.Tensor:\n        f = self.grl(torch.cat((f_s, f_t), dim=0))\n        d = self.domain_discriminator(f)\n        if self.sigmoid:\n            d_s, d_t = d.chunk(2, dim=0)\n            d_label_s = torch.ones((f_s.size(0), 1)).to(f_s.device)\n            d_label_t = torch.zeros((f_t.size(0), 1)).to(f_t.device)\n            self.domain_discriminator_accuracy = 0.5 * (\n                        binary_accuracy(d_s, d_label_s) + binary_accuracy(d_t, d_label_t))\n\n            if w_s is None:\n                w_s = torch.ones_like(d_label_s)\n            if w_t is None:\n                w_t = torch.ones_like(d_label_t)\n            return 0.5 * (\n                F.binary_cross_entropy(d_s, d_label_s, weight=w_s.view_as(d_s), reduction=self.reduction) +\n                F.binary_cross_entropy(d_t, d_label_t, weight=w_t.view_as(d_t), reduction=self.reduction)\n            )\n        else:\n            d_label = torch.cat((\n                torch.ones((f_s.size(0),)).to(f_s.device),\n                torch.zeros((f_t.size(0),)).to(f_t.device),\n            )).long()\n            if w_s is None:\n                w_s = torch.ones((f_s.size(0),)).to(f_s.device)\n            if w_t is None:\n                w_t = torch.ones((f_t.size(0),)).to(f_t.device)\n            self.domain_discriminator_accuracy = accuracy(d, d_label)\n            loss = F.cross_entropy(d, d_label, reduction='none') * torch.cat([w_s, w_t], dim=0)\n            if self.reduction == \"mean\":\n                return loss.mean()\n            elif self.reduction == \"sum\":\n                return loss.sum()\n            elif self.reduction == \"none\":\n                return loss\n            else:\n                raise NotImplementedError(self.reduction)\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, 
bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n"
  },
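A minimal adversarial training-step sketch combining the `ImageClassifier` in this file with `DomainAdversarialLoss`, in the repo's doctest style. `FakeBackbone` and the class count are illustrative assumptions; the `DomainDiscriminator` sizes simply match the bottleneck dimension::

    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> from tllib.modules.domain_discriminator import DomainDiscriminator
    >>> class FakeBackbone(nn.Module):  # hypothetical stand-in for a real CNN backbone
    ...     out_features = 64
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.conv = nn.Conv2d(3, 64, 3, padding=1)
    ...     def forward(self, x):
    ...         return self.conv(x)
    >>> classifier = ImageClassifier(FakeBackbone(), num_classes=31, bottleneck_dim=256).train()
    >>> domain_adv = DomainAdversarialLoss(DomainDiscriminator(in_feature=256, hidden_size=1024)).train()
    >>> x_s, x_t = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
    >>> labels_s = torch.randint(31, (4,))
    >>> y_s, f_s = classifier(x_s)
    >>> y_t, f_t = classifier(x_t)
    >>> # the GRL inside DomainAdversarialLoss reverses gradients, so one combined loss suffices
    >>> loss = F.cross_entropy(y_s, labels_s) + domain_adv(f_s, f_t)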
  {
    "path": "tllib/alignment/jan.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, Sequence\nimport torch\nimport torch.nn as nn\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\nfrom tllib.modules.grl import GradientReverseLayer\nfrom tllib.modules.kernels import GaussianKernel\nfrom tllib.alignment.dan import _update_index_matrix\n\n\n__all__ = ['JointMultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']\n\n\n\nclass JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):\n    r\"\"\"The Joint Multiple Kernel Maximum Mean Discrepancy (JMMD) used in\n    `Deep Transfer Learning with Joint Adaptation Networks (ICML 2017) <https://arxiv.org/abs/1605.06636>`_\n\n    Given source domain :math:`\\mathcal{D}_s` of :math:`n_s` labeled points and target domain :math:`\\mathcal{D}_t`\n    of :math:`n_t` unlabeled points drawn i.i.d. from P and Q respectively, the deep networks will generate\n    activations in layers :math:`\\mathcal{L}` as :math:`\\{(z_i^{s1}, ..., z_i^{s|\\mathcal{L}|})\\}_{i=1}^{n_s}` and\n    :math:`\\{(z_i^{t1}, ..., z_i^{t|\\mathcal{L}|})\\}_{i=1}^{n_t}`. The empirical estimate of\n    :math:`\\hat{D}_{\\mathcal{L}}(P, Q)` is computed as the squared distance between the empirical kernel mean\n    embeddings as\n\n    .. math::\n        \\hat{D}_{\\mathcal{L}}(P, Q) &=\n        \\dfrac{1}{n_s^2} \\sum_{i=1}^{n_s}\\sum_{j=1}^{n_s} \\prod_{l\\in\\mathcal{L}} k^l(z_i^{sl}, z_j^{sl}) \\\\\n        &+ \\dfrac{1}{n_t^2} \\sum_{i=1}^{n_t}\\sum_{j=1}^{n_t} \\prod_{l\\in\\mathcal{L}} k^l(z_i^{tl}, z_j^{tl}) \\\\\n        &- \\dfrac{2}{n_s n_t} \\sum_{i=1}^{n_s}\\sum_{j=1}^{n_t} \\prod_{l\\in\\mathcal{L}} k^l(z_i^{sl}, z_j^{tl}). \\\\\n\n    Args:\n        kernels (tuple(tuple(torch.nn.Module))): kernel functions, where `kernels[r]` corresponds to kernel :math:`k^{\\mathcal{L}[r]}`.\n        linear (bool): whether use the linear version of JAN. Default: False\n        thetas (list(Theta): use adversarial version JAN if not None. Default: None\n\n    Inputs:\n        - z_s (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`\n        - z_t (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`\n\n    Shape:\n        - :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)`  where * means any dimension\n        - Outputs: scalar\n\n    .. note::\n        Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape.\n\n    .. 
note::\n        The kernel values will add up when there are multiple kernels for a certain layer.\n\n    Examples::\n\n        >>> from tllib.modules.kernels import GaussianKernel\n        >>> feature_dim = 1024\n        >>> batch_size = 10\n        >>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.), GaussianKernel(alpha=2.))\n        >>> layer2_kernels = (GaussianKernel(alpha=1.), )\n        >>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))\n        >>> # layer1 features from source domain and target domain\n        >>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)\n        >>> # layer2 features from source domain and target domain\n        >>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)\n        >>> output = loss((z1_s, z2_s), (z1_t, z2_t))\n    \"\"\"\n\n    def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: Optional[bool] = True, thetas: Optional[Sequence[nn.Module]] = None):\n        super(JointMultipleKernelMaximumMeanDiscrepancy, self).__init__()\n        self.kernels = kernels\n        self.index_matrix = None\n        self.linear = linear\n        if thetas:\n            self.thetas = thetas\n        else:\n            self.thetas = [nn.Identity() for _ in kernels]\n\n    def forward(self, z_s: Sequence[torch.Tensor], z_t: Sequence[torch.Tensor]) -> torch.Tensor:\n        batch_size = int(z_s[0].size(0))\n        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)\n\n        kernel_matrix = torch.ones_like(self.index_matrix)\n        for layer_z_s, layer_z_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):\n            layer_features = torch.cat([layer_z_s, layer_z_t], dim=0)\n            layer_features = theta(layer_features)\n            kernel_matrix *= sum(\n                [kernel(layer_features) for kernel in layer_kernels])  # Add up the matrix of each kernel\n\n        # Add 2 / (n-1) to make up for the value on the diagonal\n        # to ensure loss is positive in the non-linear version\n        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)\n        return loss\n\n\nclass Theta(nn.Module):\n    r\"\"\"\n    Maximize the loss with respect to :math:`\\theta`,\n    minimize the loss with respect to the features.\n    \"\"\"\n    def __init__(self, dim: int):\n        super(Theta, self).__init__()\n        self.grl1 = GradientReverseLayer()\n        self.grl2 = GradientReverseLayer()\n        self.layer1 = nn.Linear(dim, dim)\n        nn.init.eye_(self.layer1.weight)\n        nn.init.zeros_(self.layer1.bias)\n\n    def forward(self, features: torch.Tensor) -> torch.Tensor:\n        features = self.grl1(features)\n        return self.grl2(self.layer1(features))\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(0.5)\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)"
  },
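A sketch of the adversarial variant, wiring `Theta` into the JMMD loss above. The layer dimensions (256 bottleneck features and 31 class logits) are illustrative assumptions that mirror the JAN setup of adapting the last feature layer together with the classifier output::

    >>> import torch
    >>> from tllib.modules.kernels import GaussianKernel
    >>> layer_dims = (256, 31)             # one adapted layer dim per kernel tuple
    >>> thetas = [Theta(dim) for dim in layer_dims]
    >>> jmmd = JointMultipleKernelMaximumMeanDiscrepancy(
    ...     kernels=((GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.)),
    ...              (GaussianKernel(alpha=1.),)),
    ...     linear=True, thetas=thetas)
    >>> f_s, f_t = torch.randn(10, 256), torch.randn(10, 256)   # bottleneck activations
    >>> y_s, y_t = torch.randn(10, 31), torch.randn(10, 31)     # classifier outputs
    >>> loss = jmmd((f_s, y_s), (f_t, y_t))  # the GRLs inside Theta make this adversarial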
  {
    "path": "tllib/alignment/mcd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport torch.nn as nn\nimport torch\n\n\ndef classifier_discrepancy(predictions1: torch.Tensor, predictions2: torch.Tensor) -> torch.Tensor:\n    r\"\"\"The `Classifier Discrepancy` in\n    `Maximum Classiﬁer Discrepancy for Unsupervised Domain Adaptation (CVPR 2018) <https://arxiv.org/abs/1712.02560>`_.\n\n    The classfier discrepancy between predictions :math:`p_1` and :math:`p_2` can be described as:\n\n    .. math::\n        d(p_1, p_2) = \\dfrac{1}{K} \\sum_{k=1}^K | p_{1k} - p_{2k} |,\n\n    where K is number of classes.\n\n    Args:\n        predictions1 (torch.Tensor): Classifier predictions :math:`p_1`. Expected to contain raw, normalized scores for each class\n        predictions2 (torch.Tensor): Classifier predictions :math:`p_2`\n    \"\"\"\n    return torch.mean(torch.abs(predictions1 - predictions2))\n\n\ndef entropy(predictions: torch.Tensor) -> torch.Tensor:\n    r\"\"\"Entropy of N predictions :math:`(p_1, p_2, ..., p_N)`.\n    The definition is:\n\n    .. math::\n        d(p_1, p_2, ..., p_N) = -\\dfrac{1}{K} \\sum_{k=1}^K \\log \\left( \\dfrac{1}{N} \\sum_{i=1}^N p_{ik} \\right)\n\n    where K is number of classes.\n\n    .. note::\n        This entropy function is specifically used in MCD and different from the usual :meth:`~tllib.modules.entropy.entropy` function.\n\n    Args:\n        predictions (torch.Tensor): Classifier predictions. Expected to contain raw, normalized scores for each class\n    \"\"\"\n    return -torch.mean(torch.log(torch.mean(predictions, 0) + 1e-6))\n\n\nclass ImageClassifierHead(nn.Module):\n    r\"\"\"Classifier Head for MCD.\n\n    Args:\n        in_features (int): Dimension of input features\n        num_classes (int): Number of classes\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024\n\n    Shape:\n        - Inputs: :math:`(minibatch, F)` where F = `in_features`.\n        - Output: :math:`(minibatch, C)` where C = `num_classes`.\n    \"\"\"\n\n    def __init__(self, in_features: int, num_classes: int, bottleneck_dim: Optional[int] = 1024, pool_layer=None):\n        super(ImageClassifierHead, self).__init__()\n        self.num_classes = num_classes\n        if pool_layer is None:\n            self.pool_layer = nn.Sequential(\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            )\n        else:\n            self.pool_layer = pool_layer\n        self.head = nn.Sequential(\n            nn.Dropout(0.5),\n            nn.Linear(in_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(0.5),\n            nn.Linear(bottleneck_dim, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Linear(bottleneck_dim, num_classes)\n        )\n\n    def forward(self, inputs: torch.Tensor) -> torch.Tensor:\n        return self.head(self.pool_layer(inputs))"
  },
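A small doctest-style sketch for the two MCD functions above. Both expect probabilities, so raw logits are passed through softmax first; the shapes are arbitrary::

    >>> import torch
    >>> import torch.nn.functional as F
    >>> logits1, logits2 = torch.randn(10, 5), torch.randn(10, 5)  # two classifier heads
    >>> p1, p2 = F.softmax(logits1, dim=1), F.softmax(logits2, dim=1)
    >>> d = classifier_discrepancy(p1, p2)  # mean L1 distance between the two predictions
    >>> h = entropy(p1)                     # MCD-style entropy of the averaged prediction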
  {
    "path": "tllib/alignment/mdd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, List, Dict, Tuple, Callable\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch\n\nfrom tllib.modules.grl import WarmStartGradientReverseLayer\n\n\nclass MarginDisparityDiscrepancy(nn.Module):\n    r\"\"\"The margin disparity discrepancy (MDD) proposed in `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019) <https://arxiv.org/abs/1904.05801>`_.\n\n    MDD can measure the distribution discrepancy in domain adaptation.\n\n    The :math:`y^s` and :math:`y^t` are logits output by the main head on the source and target domain respectively.\n    The :math:`y_{adv}^s` and :math:`y_{adv}^t` are logits output by the adversarial head.\n\n    The definition can be described as:\n\n    .. math::\n        \\mathcal{D}_{\\gamma}(\\hat{\\mathcal{S}}, \\hat{\\mathcal{T}}) =\n        -\\gamma \\mathbb{E}_{y^s, y_{adv}^s \\sim\\hat{\\mathcal{S}}} L_s (y^s, y_{adv}^s) +\n        \\mathbb{E}_{y^t, y_{adv}^t \\sim\\hat{\\mathcal{T}}} L_t (y^t, y_{adv}^t),\n\n    where :math:`\\gamma` is a margin hyper-parameter, :math:`L_s` refers to the disparity function defined on the source domain\n    and :math:`L_t` refers to the disparity function defined on the target domain.\n\n    Args:\n        source_disparity (callable): The disparity function defined on the source domain, :math:`L_s`.\n        target_disparity (callable): The disparity function defined on the target domain, :math:`L_t`.\n        margin (float): margin :math:`\\gamma`. Default: 4\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. 
Default: ``'mean'``\n\n    Inputs:\n        - y_s: output :math:`y^s` by the main head on the source domain\n        - y_s_adv: output :math:`y_{adv}^s` by the adversarial head on the source domain\n        - y_t: output :math:`y^t` by the main head on the target domain\n        - y_t_adv: output :math:`y_{adv}^t` by the adversarial head on the target domain\n        - w_s (optional): instance weights for source domain\n        - w_t (optional): instance weights for target domain\n\n    Examples::\n\n        >>> num_outputs = 2\n        >>> batch_size = 10\n        >>> loss = MarginDisparityDiscrepancy(margin=4., source_disparity=F.l1_loss, target_disparity=F.l1_loss)\n        >>> # output from source domain and target domain\n        >>> y_s, y_t = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)\n        >>> # adversarial output from source domain and target domain\n        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)\n        >>> output = loss(y_s, y_s_adv, y_t, y_t_adv)\n    \"\"\"\n\n    def __init__(self, source_disparity: Callable, target_disparity: Callable,\n                 margin: Optional[float] = 4, reduction: Optional[str] = 'mean'):\n        super(MarginDisparityDiscrepancy, self).__init__()\n        self.margin = margin\n        self.reduction = reduction\n        self.source_disparity = source_disparity\n        self.target_disparity = target_disparity\n\n    def forward(self, y_s: torch.Tensor, y_s_adv: torch.Tensor, y_t: torch.Tensor, y_t_adv: torch.Tensor,\n                w_s: Optional[torch.Tensor] = None, w_t: Optional[torch.Tensor] = None) -> torch.Tensor:\n\n        source_loss = -self.margin * self.source_disparity(y_s, y_s_adv)\n        target_loss = self.target_disparity(y_t, y_t_adv)\n        if w_s is None:\n            w_s = torch.ones_like(source_loss)\n        source_loss = source_loss * w_s\n        if w_t is None:\n            w_t = torch.ones_like(target_loss)\n        target_loss = target_loss * w_t\n\n        loss = source_loss + target_loss\n        if self.reduction == 'mean':\n            loss = loss.mean()\n        elif self.reduction == 'sum':\n            loss = loss.sum()\n        return loss\n\n\nclass ClassificationMarginDisparityDiscrepancy(MarginDisparityDiscrepancy):\n    r\"\"\"\n    The margin disparity discrepancy (MDD) proposed in `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019) <https://arxiv.org/abs/1904.05801>`_.\n\n    It measures the distribution discrepancy in domain adaptation\n    for classification.\n\n    When margin is equal to 1, it's also called disparity discrepancy (DD).\n\n    The :math:`y^s` and :math:`y^t` are logits output by the main classifier on the source and target domain respectively.\n    The :math:`y_{adv}^s` and :math:`y_{adv}^t` are logits output by the adversarial classifier.\n    They are expected to contain raw, unnormalized scores for each class.\n\n    The definition can be described as:\n\n    .. 
math::\n        \\mathcal{D}_{\\gamma}(\\hat{\\mathcal{S}}, \\hat{\\mathcal{T}}) =\n        \\gamma \\mathbb{E}_{y^s, y_{adv}^s \\sim\\hat{\\mathcal{S}}} \\log\\left(\\frac{\\exp(y_{adv}^s[h_{y^s}])}{\\sum_j \\exp(y_{adv}^s[j])}\\right) +\n        \\mathbb{E}_{y^t, y_{adv}^t \\sim\\hat{\\mathcal{T}}} \\log\\left(1-\\frac{\\exp(y_{adv}^t[h_{y^t}])}{\\sum_j \\exp(y_{adv}^t[j])}\\right),\n\n    where :math:`\\gamma` is a margin hyper-parameter and :math:`h_y` refers to the predicted label when the logits output is :math:`y`.\n    You can see more details in `Bridging Theory and Algorithm for Domain Adaptation <https://arxiv.org/abs/1904.05801>`_.\n\n    Args:\n        margin (float): margin :math:`\\gamma`. Default: 4\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Inputs:\n        - y_s: logits output :math:`y^s` by the main classifier on the source domain\n        - y_s_adv: logits output :math:`y_{adv}^s` by the adversarial classifier on the source domain\n        - y_t: logits output :math:`y^t` by the main classifier on the target domain\n        - y_t_adv: logits output :math:`y_{adv}^t` by the adversarial classifier on the target domain\n\n    Shape:\n        - Inputs: :math:`(minibatch, C)` where C = number of classes, or :math:`(minibatch, C, d_1, d_2, ..., d_K)`\n          with :math:`K \\geq 1` in the case of `K`-dimensional loss.\n        - Output: scalar. If :attr:`reduction` is ``'none'``, then the same size as the target: :math:`(minibatch)`, or\n          :math:`(minibatch, d_1, d_2, ..., d_K)` with :math:`K \\geq 1` in the case of K-dimensional loss.\n\n    Examples::\n\n        >>> num_classes = 2\n        >>> batch_size = 10\n        >>> loss = ClassificationMarginDisparityDiscrepancy(margin=4.)\n        >>> # logits output from source domain and target domain\n        >>> y_s, y_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)\n        >>> # adversarial logits output from source domain and target domain\n        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)\n        >>> output = loss(y_s, y_s_adv, y_t, y_t_adv)\n    \"\"\"\n\n    def __init__(self, margin: Optional[float] = 4, **kwargs):\n        def source_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):\n            _, prediction = y.max(dim=1)\n            return F.cross_entropy(y_adv, prediction, reduction='none')\n\n        def target_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):\n            _, prediction = y.max(dim=1)\n            return -F.nll_loss(shift_log(1. 
- F.softmax(y_adv, dim=1)), prediction, reduction='none')\n\n        super(ClassificationMarginDisparityDiscrepancy, self).__init__(source_discrepancy, target_discrepancy, margin,\n                                                                       **kwargs)\n\n\nclass RegressionMarginDisparityDiscrepancy(MarginDisparityDiscrepancy):\n    r\"\"\"\n    The margin disparity discrepancy (MDD) proposed in `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019) <https://arxiv.org/abs/1904.05801>`_.\n\n    It measures the distribution discrepancy in domain adaptation\n    for regression.\n\n    The :math:`y^s` and :math:`y^t` are logits output by the main regressor on the source and target domain respectively.\n    The :math:`y_{adv}^s` and :math:`y_{adv}^t` are logits output by the adversarial regressor.\n    They are expected to contain ``normalized`` values for each factors.\n\n    The definition can be described as:\n\n    .. math::\n        \\mathcal{D}_{\\gamma}(\\hat{\\mathcal{S}}, \\hat{\\mathcal{T}}) =\n        -\\gamma \\mathbb{E}_{y^s, y_{adv}^s \\sim\\hat{\\mathcal{S}}} L (y^s, y_{adv}^s) +\n        \\mathbb{E}_{y^t, y_{adv}^t \\sim\\hat{\\mathcal{T}}} L (y^t, y_{adv}^t),\n\n    where :math:`\\gamma` is a margin hyper-parameter and :math:`L` refers to the disparity function defined on both domains.\n    You can see more details in `Bridging Theory and Algorithm for Domain Adaptation <https://arxiv.org/abs/1904.05801>`_.\n\n    Args:\n        loss_function (callable): The disparity function defined on both domains, :math:`L`.\n        margin (float): margin :math:`\\gamma`. Default: 1\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Inputs:\n        - y_s: logits output :math:`y^s` by the main regressor on the source domain\n        - y_s_adv: logits output :math:`y^s` by the adversarial regressor on the source domain\n        - y_t: logits output :math:`y^t` by the main regressor on the target domain\n        - y_t_adv: logits output :math:`y_{adv}^t` by the adversarial regressor on the target domain\n\n    Shape:\n        - Inputs: :math:`(minibatch, F)` where F = number of factors, or :math:`(minibatch, F, d_1, d_2, ..., d_K)`\n          with :math:`K \\geq 1` in the case of `K`-dimensional loss.\n        - Output: scalar. 
The same size as the target: :math:`(minibatch)`, or\n          :math:`(minibatch, d_1, d_2, ..., d_K)` with :math:`K \\geq 1` in the case of K-dimensional loss.\n\n    Examples::\n\n        >>> num_outputs = 2\n        >>> batch_size = 10\n        >>> loss = RegressionMarginDisparityDiscrepancy(margin=4., loss_function=F.l1_loss)\n        >>> # output from source domain and target domain\n        >>> y_s, y_t = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)\n        >>> # adversarial output from source domain and target domain\n        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)\n        >>> output = loss(y_s, y_s_adv, y_t, y_t_adv)\n\n    \"\"\"\n\n    def __init__(self, margin: Optional[float] = 1, loss_function=F.l1_loss, **kwargs):\n        def source_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):\n            return loss_function(y_adv, y.detach(), reduction='none')\n\n        def target_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):\n            return loss_function(y_adv, y.detach(), reduction='none')\n\n        super(RegressionMarginDisparityDiscrepancy, self).__init__(source_discrepancy, target_discrepancy, margin,\n                                                                   **kwargs)\n\n\ndef shift_log(x: torch.Tensor, offset: Optional[float] = 1e-6) -> torch.Tensor:\n    r\"\"\"\n    First shift, then calculate log, which can be described as:\n\n    .. math::\n        y = \\min(\\log(x+\\text{offset}), 0)\n\n    Used to avoid the gradient explosion problem of log(x) when x=0.\n\n    Args:\n        x (torch.Tensor): input tensor\n        offset (float, optional): offset size. Default: 1e-6\n\n    .. note::\n        Input tensor falls in [0., 1.] 
and the output tensor falls in [log(offset), 0]\n    \"\"\"\n    return torch.log(torch.clamp(x + offset, max=1.))\n\n\nclass GeneralModule(nn.Module):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: nn.Module,\n                 head: nn.Module, adv_head: nn.Module, grl: Optional[WarmStartGradientReverseLayer] = None,\n                 finetune: Optional[bool] = True):\n        super(GeneralModule, self).__init__()\n        self.backbone = backbone\n        self.num_classes = num_classes\n        self.bottleneck = bottleneck\n        self.head = head\n        self.adv_head = adv_head\n        self.finetune = finetune\n        self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000,\n                                                       auto_step=False) if grl is None else grl\n\n    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\"\"\"\n        features = self.backbone(x)\n        features = self.bottleneck(features)\n        outputs = self.head(features)\n        features_adv = self.grl_layer(features)\n        outputs_adv = self.adv_head(features_adv)\n        if self.training:\n            return outputs, outputs_adv\n        else:\n            return outputs\n\n    def step(self):\n        r\"\"\"\n        Gradually increase :math:`\\lambda` in GRL layer.\n        \"\"\"\n        self.grl_layer.step()\n\n    def get_parameters(self, base_lr=1.0) -> List[Dict]:\n        \"\"\"\n        Return a parameter list which decides optimization hyper-parameters,\n        such as the relative learning rate of each layer.\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": base_lr},\n            {\"params\": self.head.parameters(), \"lr\": base_lr},\n            {\"params\": self.adv_head.parameters(), \"lr\": base_lr}\n        ]\n        return params\n\n\nclass ImageClassifier(GeneralModule):\n    r\"\"\"Classifier for MDD.\n\n    Classifier for MDD has one backbone, one bottleneck, and two classifier heads.\n    The first classifier head is used for final predictions.\n    The adversarial classifier head is only used when calculating MarginDisparityDiscrepancy.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 1-d features from data\n        num_classes (int): Number of classes\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024\n        width (int, optional): Feature dimension of the classifier head. Default: 1024\n        grl (nn.Module): Gradient reverse layer. Will use default parameters if None. Default: None.\n        finetune (bool, optional): Whether to use a 10x smaller learning rate in the backbone. Default: True\n\n    Inputs:\n        - x (tensor): input data\n\n    Outputs:\n        - outputs: logits outputs by the main classifier\n        - outputs_adv: logits outputs by the adversarial classifier\n\n    Shape:\n        - x: :math:`(minibatch, *)`, same shape as the input of the `backbone`.\n        - outputs, outputs_adv: :math:`(minibatch, C)`, where C means the number of classes.\n\n    .. note::\n        Remember to call function `step()` after function `forward()` **during training phase**! 
For instance,\n\n            >>> # x is inputs, classifier is an ImageClassifier\n            >>> outputs, outputs_adv = classifier(x)\n            >>> classifier.step()\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int,\n                 bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024,\n                 grl: Optional[WarmStartGradientReverseLayer] = None, finetune=True, pool_layer=None):\n        grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000,\n                                                       auto_step=False) if grl is None else grl\n\n        if pool_layer is None:\n            pool_layer = nn.Sequential(\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            )\n        bottleneck = nn.Sequential(\n            pool_layer,\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(0.5)\n        )\n        bottleneck[1].weight.data.normal_(0, 0.005)\n        bottleneck[1].bias.data.fill_(0.1)\n\n        # The classifier head used for final predictions.\n        head = nn.Sequential(\n            nn.Linear(bottleneck_dim, width),\n            nn.ReLU(),\n            nn.Dropout(0.5),\n            nn.Linear(width, num_classes)\n        )\n        # The adversarial classifier head\n        adv_head = nn.Sequential(\n            nn.Linear(bottleneck_dim, width),\n            nn.ReLU(),\n            nn.Dropout(0.5),\n            nn.Linear(width, num_classes)\n        )\n        for dep in range(2):\n            head[dep * 3].weight.data.normal_(0, 0.01)\n            head[dep * 3].bias.data.fill_(0.0)\n            adv_head[dep * 3].weight.data.normal_(0, 0.01)\n            adv_head[dep * 3].bias.data.fill_(0.0)\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck,\n                                              head, adv_head, grl_layer, finetune)\n\n\nclass ImageRegressor(GeneralModule):\n    r\"\"\"Regressor for MDD.\n\n    Regressor for MDD has one backbone, one bottleneck, while two regressor heads.\n    The first regressor head is used for final predictions.\n    The adversarial regressor head is only used when calculating MarginDisparityDiscrepancy.\n\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 1-d features from data\n        num_factors (int): Number of factors\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024\n        width (int, optional): Feature dimension of the classifier head. Default: 1024\n        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: True\n\n    Inputs:\n        - x (Tensor): input data\n\n    Outputs: (outputs, outputs_adv)\n        - outputs: outputs by the main regressor\n        - outputs_adv: outputs by the adversarial regressor\n\n    Shape:\n        - x: :math:`(minibatch, *)`, same shape as the input of the `backbone`.\n        - outputs, outputs_adv: :math:`(minibatch, F)`, where F means the number of factors.\n\n    .. note::\n        Remember to call function `step()` after function `forward()` **during training phase**! 
For instance,\n\n            >>> # x is inputs, regressor is an ImageRegressor\n            >>> outputs, outputs_adv = regressor(x)\n            >>> regressor.step()\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_factors: int, bottleneck = None, head=None, adv_head=None,\n                 bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024, finetune=True):\n        grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)\n        if bottleneck is None:\n            bottleneck = nn.Sequential(\n                nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1),\n                nn.BatchNorm2d(bottleneck_dim),\n                nn.ReLU(),\n            )\n\n        # The regressor head used for final predictions.\n        if head is None:\n            head = nn.Sequential(\n                nn.Conv2d(bottleneck_dim, width, kernel_size=3, stride=1, padding=1),\n                nn.BatchNorm2d(width),\n                nn.ReLU(),\n                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),\n                nn.BatchNorm2d(width),\n                nn.ReLU(),\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten(),\n                nn.Linear(width, num_factors),\n                nn.Sigmoid()\n            )\n            for layer in head:\n                if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n                    nn.init.normal_(layer.weight, 0, 0.01)\n                    nn.init.constant_(layer.bias, 0)\n        # The adversarial regressor head\n        if adv_head is None:\n            adv_head = nn.Sequential(\n                nn.Conv2d(bottleneck_dim, width, kernel_size=3, stride=1, padding=1),\n                nn.BatchNorm2d(width),\n                nn.ReLU(),\n                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),\n                nn.BatchNorm2d(width),\n                nn.ReLU(),\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten(),\n                nn.Linear(width, num_factors),\n                nn.Sigmoid()\n            )\n            for layer in adv_head:\n                if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):\n                    nn.init.normal_(layer.weight, 0, 0.01)\n                    nn.init.constant_(layer.bias, 0)\n        super(ImageRegressor, self).__init__(backbone, num_factors, bottleneck,\n                                              head, adv_head, grl_layer, finetune)\n        self.num_factors = num_factors\n"
  },
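A minimal training-step sketch for `ClassificationMarginDisparityDiscrepancy` with the MDD `ImageClassifier` above, in the repo's doctest style. `FakeBackbone` and the class count are illustrative assumptions; the minus sign reflects that the GRL in front of the adversarial head already flips its gradients::

    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> class FakeBackbone(nn.Module):  # hypothetical stand-in for a real CNN backbone
    ...     out_features = 64
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.conv = nn.Conv2d(3, 64, 3, padding=1)
    ...     def forward(self, x):
    ...         return self.conv(x)
    >>> classifier = ImageClassifier(FakeBackbone(), num_classes=31,
    ...                              bottleneck_dim=1024, width=1024).train()
    >>> mdd = ClassificationMarginDisparityDiscrepancy(margin=4.)
    >>> x_s, x_t = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
    >>> labels_s = torch.randint(31, (4,))
    >>> outputs, outputs_adv = classifier(torch.cat((x_s, x_t), dim=0))
    >>> y_s, y_t = outputs.chunk(2, dim=0)
    >>> y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)
    >>> loss = F.cross_entropy(y_s, labels_s) - mdd(y_s, y_s_adv, y_t, y_t_adv)
    >>> classifier.step()  # advance the warm-start GRL schedule each iteration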
  {
    "path": "tllib/alignment/osbp.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\nfrom tllib.modules.grl import GradientReverseLayer\n\n\nclass UnknownClassBinaryCrossEntropy(nn.Module):\n    r\"\"\"\n    Binary cross entropy loss to make a boundary for unknown samples, proposed by\n    `Open Set Domain Adaptation by Backpropagation (ECCV 2018) <https://arxiv.org/abs/1804.10427>`_.\n\n    Given a sample on target domain :math:`x_t` and its classifcation outputs :math:`y`, the binary cross entropy\n    loss is defined as\n\n    .. math::\n        L_{\\text{adv}}(x_t) = -t \\text{log}(p(y=C+1|x_t)) - (1-t)\\text{log}(1-p(y=C+1|x_t))\n\n    where t is a hyper-parameter and C is the number of known classes.\n\n    Args:\n        t (float): Predefined hyper-parameter. Default: 0.5\n\n    Inputs:\n        - y (tensor): classification outputs (before softmax).\n\n    Shape:\n        - y: :math:`(minibatch, C+1)`  where C is the number of known classes.\n        - Outputs: scalar\n\n    \"\"\"\n    def __init__(self, t: Optional[float]=0.5):\n        super(UnknownClassBinaryCrossEntropy, self).__init__()\n        self.t = t\n\n    def forward(self, y):\n        # y : N x (C+1)\n        softmax_output = F.softmax(y, dim=1)\n        unknown_class_prob = softmax_output[:, -1].contiguous().view(-1, 1)\n        known_class_prob = 1. - unknown_class_prob\n\n        unknown_target = torch.ones((y.size(0), 1)).to(y.device) * self.t\n        known_target = 1. - unknown_target\n        return - torch.mean(unknown_target * torch.log(unknown_class_prob + 1e-6)) \\\n               - torch.mean(known_target * torch.log(known_class_prob + 1e-6))\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(),\n            nn.Linear(bottleneck_dim, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n        self.grl = GradientReverseLayer()\n\n    def forward(self, x: torch.Tensor, grad_reverse: Optional[bool] = False):\n        features = self.pool_layer(self.backbone(x))\n        features = self.bottleneck(features)\n        if grad_reverse:\n            features = self.grl(features)\n        outputs = self.head(features)\n        if self.training:\n            return outputs, features\n        else:\n            return outputs\n\n\n"
  },
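A short doctest-style sketch for `UnknownClassBinaryCrossEntropy`; the class count and batch size are arbitrary. During training it would be computed on target logits obtained with `grad_reverse=True`, so that the feature extractor and the classifier play the adversarial game::

    >>> import torch
    >>> num_known_classes = 20
    >>> adv_loss = UnknownClassBinaryCrossEntropy(t=0.5)
    >>> y_t = torch.randn(16, num_known_classes + 1)  # logits, last column = unknown class
    >>> loss = adv_loss(y_t)
    >>> # in a training loop (sketch): y_t, _ = classifier(x_t, grad_reverse=True)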
  {
    "path": "tllib/alignment/regda.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom tllib.modules.gl import WarmStartGradientLayer\nfrom tllib.utils.metric.keypoint_detection import get_max_preds\n\n\nclass FastPseudoLabelGenerator2d(nn.Module):\n    def __init__(self, sigma=2):\n        super().__init__()\n        self.sigma = sigma\n    \n    def forward(self, heatmap: torch.Tensor):\n        heatmap = heatmap.detach()\n        height, width = heatmap.shape[-2:]\n        idx = heatmap.flatten(-2).argmax(dim=-1) # B, K\n        pred_h, pred_w = idx.div(width, rounding_mode='floor'), idx.remainder(width) # B, K\n        delta_h = torch.arange(height, device=heatmap.device) - pred_h.unsqueeze(-1) # B, K, H\n        delta_w = torch.arange(width, device=heatmap.device) - pred_w.unsqueeze(-1) # B, K, W\n        gaussian = (delta_h.square().unsqueeze(-1) + delta_w.square().unsqueeze(-2)).div(-2 * self.sigma * self.sigma).exp() # B, K, H, W\n        ground_truth = F.threshold(gaussian, threshold=1e-2, value=0.)\n\n        ground_false = (ground_truth.sum(dim=1, keepdim=True) - ground_truth).clamp(0., 1.)\n        return ground_truth, ground_false\n\n\nclass PseudoLabelGenerator2d(nn.Module):\n    \"\"\"\n    Generate ground truth heatmap and ground false heatmap from a prediction.\n\n    Args:\n        num_keypoints (int): Number of keypoints\n        height (int): height of the heatmap. Default: 64\n        width (int): width of the heatmap. Default: 64\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\n\n    Inputs:\n        - y: predicted heatmap\n\n    Outputs:\n        - ground_truth: heatmap conforming to Gaussian distribution\n        - ground_false: ground false heatmap\n\n    Shape:\n        - y: :math:`(minibatch, K, H, W)` where K means the number of keypoints,\n          H and W is the height and width of the heatmap respectively.\n        - ground_truth: :math:`(minibatch, K, H, W)`\n        - ground_false: :math:`(minibatch, K, H, W)`\n    \"\"\"\n    def __init__(self, num_keypoints, height=64, width=64, sigma=2):\n        super(PseudoLabelGenerator2d, self).__init__()\n        self.height = height\n        self.width = width\n        self.sigma = sigma\n\n        heatmaps = np.zeros((width, height, height, width), dtype=np.float32)\n\n        tmp_size = sigma * 3\n        for mu_x in range(width):\n            for mu_y in range(height):\n                # Check that any part of the gaussian is in-bounds\n                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]\n                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]\n\n                # Generate gaussian\n                size = 2 * tmp_size + 1\n                x = np.arange(0, size, 1, np.float32)\n                y = x[:, np.newaxis]\n                x0 = y0 = size // 2\n                # The gaussian is not normalized, we want the center value to equal 1\n                g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\n\n                # Usable gaussian range\n                g_x = max(0, -ul[0]), min(br[0], width) - ul[0]\n                g_y = max(0, -ul[1]), min(br[1], height) - ul[1]\n                # Image range\n                img_x = max(0, ul[0]), min(br[0], width)\n                img_y = max(0, ul[1]), min(br[1], height)\n\n                heatmaps[mu_x][mu_y][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \\\n     
               g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\n\n        self.heatmaps = heatmaps\n        self.false_matrix = 1. - np.eye(num_keypoints, dtype=np.float32)\n\n    def forward(self, y):\n        B, K, H, W = y.shape\n        y = y.detach()\n        preds, max_vals = get_max_preds(y.cpu().numpy())  # B x K x (x, y)\n        preds = preds.reshape(-1, 2).astype(int)  # np.int was removed from recent NumPy versions\n        ground_truth = self.heatmaps[preds[:, 0], preds[:, 1], :, :].copy().reshape(B, K, H, W).copy()\n\n        ground_false = ground_truth.reshape(B, K, -1).transpose((0, 2, 1))\n        ground_false = ground_false.dot(self.false_matrix).clip(max=1., min=0.).transpose((0, 2, 1)).reshape(B, K, H, W).copy()\n        return torch.from_numpy(ground_truth).to(y.device), torch.from_numpy(ground_false).to(y.device)\n\n\nclass RegressionDisparity(nn.Module):\n    \"\"\"\n    Regression Disparity proposed by `Regressive Domain Adaptation for Unsupervised Keypoint Detection (CVPR 2021) <https://arxiv.org/abs/2103.06175>`_.\n\n    Args:\n        pseudo_label_generator (PseudoLabelGenerator2d): generate ground truth heatmap and ground false heatmap\n          from a prediction.\n        criterion (torch.nn.Module): the loss function to calculate distance between two predictions.\n\n    Inputs:\n        - y: output by the main head\n        - y_adv: output by the adversarial head\n        - weight (optional): instance weights\n        - mode (str): whether to minimize or maximize the disparity. Choices include ``min``, ``max``.\n          Default: ``min``.\n\n    Shape:\n        - y: :math:`(minibatch, K, H, W)` where K means the number of keypoints,\n          H and W are the height and width of the heatmap respectively.\n        - y_adv: :math:`(minibatch, K, H, W)`\n        - weight: :math:`(minibatch, K)`.\n        - Output: depends on the ``criterion``.\n\n    Examples::\n\n        >>> num_keypoints = 5\n        >>> batch_size = 10\n        >>> H = W = 64\n        >>> pseudo_label_generator = PseudoLabelGenerator2d(num_keypoints)\n        >>> from tllib.vision.models.keypoint_detection.loss import JointsKLLoss\n        >>> loss = RegressionDisparity(pseudo_label_generator, JointsKLLoss())\n        >>> # output from source domain and target domain\n        >>> y_s, y_t = torch.randn(batch_size, num_keypoints, H, W), torch.randn(batch_size, num_keypoints, H, W)\n        >>> # adversarial output from source domain and target domain\n        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_keypoints, H, W), torch.randn(batch_size, num_keypoints, H, W)\n        >>> # minimize regression disparity on source domain\n        >>> output = loss(y_s, y_s_adv, mode='min')\n        >>> # maximize regression disparity on target domain\n        >>> output = loss(y_t, y_t_adv, mode='max')\n    \"\"\"\n    def __init__(self, pseudo_label_generator: PseudoLabelGenerator2d, criterion: nn.Module):\n        super(RegressionDisparity, self).__init__()\n        self.criterion = criterion\n        self.pseudo_label_generator = pseudo_label_generator\n\n    def forward(self, y, y_adv, weight=None, mode='min'):\n        assert mode in ['min', 'max']\n        ground_truth, ground_false = self.pseudo_label_generator(y.detach())\n        self.ground_truth = ground_truth\n        self.ground_false = ground_false\n        if mode == 'min':\n            return self.criterion(y_adv, ground_truth, weight)\n        else:\n            return self.criterion(y_adv, ground_false, weight)\n\n\nclass PoseResNet2d(nn.Module):\n    \"\"\"\n    Pose ResNet for 
RegDA has one backbone, one upsampling module, and two regression heads.\n\n    Args:\n        backbone (torch.nn.Module): Backbone to extract 2-d features from data\n        upsampling (torch.nn.Module): Layer to upsample image feature to heatmap size\n        feature_dim (int): The dimension of the features from upsampling layer.\n        num_keypoints (int): Number of keypoints\n        gl (WarmStartGradientLayer): Warm-start gradient layer. Will use default parameters if None. Default: None\n        finetune (bool, optional): Whether to use a 10x smaller learning rate in the backbone. Default: True\n        num_head_layers (int): Number of head layers. Default: 2\n\n    Inputs:\n        - x (tensor): input data\n\n    Outputs:\n        - outputs: logits outputs by the main regressor\n        - outputs_adv: logits outputs by the adversarial regressor\n\n    Shape:\n        - x: :math:`(minibatch, *)`, same shape as the input of the `backbone`.\n        - outputs, outputs_adv: :math:`(minibatch, K, H, W)`, where K means the number of keypoints.\n\n    .. note::\n        Remember to call function `step()` after function `forward()` **during training phase**! For instance,\n\n            >>> # x is inputs, model is a PoseResNet2d\n            >>> outputs, outputs_adv = model(x)\n            >>> model.step()\n    \"\"\"\n    def __init__(self, backbone, upsampling, feature_dim, num_keypoints,\n                 gl: Optional[WarmStartGradientLayer] = None, finetune: Optional[bool] = True, num_head_layers=2):\n        super(PoseResNet2d, self).__init__()\n        self.backbone = backbone\n        self.upsampling = upsampling\n        self.head = self._make_head(num_head_layers, feature_dim, num_keypoints)\n        self.head_adv = self._make_head(num_head_layers, feature_dim, num_keypoints)\n        self.finetune = finetune\n        self.gl_layer = WarmStartGradientLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False) if gl is None else gl\n\n    @staticmethod\n    def _make_head(num_layers, channel_dim, num_keypoints):\n        layers = []\n        for i in range(num_layers-1):\n            layers.extend([\n                nn.Conv2d(channel_dim, channel_dim, 3, 1, 1),\n                nn.BatchNorm2d(channel_dim),\n                nn.ReLU(),\n            ])\n        layers.append(\n            nn.Conv2d(\n                in_channels=channel_dim,\n                out_channels=num_keypoints,\n                kernel_size=1,\n                stride=1,\n                padding=0\n            )\n        )\n        layers = nn.Sequential(*layers)\n        for m in layers.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.normal_(m.weight, std=0.001)\n                nn.init.constant_(m.bias, 0)\n        return layers\n\n    def forward(self, x):\n        x = self.backbone(x)\n        f = self.upsampling(x)\n        f_adv = self.gl_layer(f)\n        y = self.head(f)\n        y_adv = self.head_adv(f_adv)\n\n        if self.training:\n            return y, y_adv\n        else:\n            return y\n\n    def get_parameters(self, lr=1.):\n        return [\n            {'params': self.backbone.parameters(), 'lr': 0.1 * lr if self.finetune else lr},\n            {'params': self.upsampling.parameters(), 'lr': lr},\n            {'params': self.head.parameters(), 'lr': lr},\n            {'params': self.head_adv.parameters(), 'lr': lr},\n        ]\n\n    def step(self):\n        r\"\"\"Call step() each iteration during training.\n        Will increase :math:`\\lambda` in GL layer.\n        \"\"\"\n        self.gl_layer.step()"
  },
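  {
    "path": "examples_sketch/regda_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for the RegDA modules defined above; not part of the\nlibrary itself. The import path `tllib.alignment.regda` is assumed from the\nsurrounding files.\n\"\"\"\nimport torch\nfrom tllib.alignment.regda import PseudoLabelGenerator2d, RegressionDisparity\nfrom tllib.vision.models.keypoint_detection.loss import JointsKLLoss\n\nnum_keypoints, batch_size, H = 5, 10, 64\npseudo_label_generator = PseudoLabelGenerator2d(num_keypoints)\nregression_disparity = RegressionDisparity(pseudo_label_generator, JointsKLLoss())\n\n# heatmaps from the main and adversarial heads on the target domain\ny_t = torch.randn(batch_size, num_keypoints, H, H)\ny_t_adv = torch.randn(batch_size, num_keypoints, H, H)\n\n# RegDA minimizes the disparity on source data and maximizes it on target data\nloss_t = regression_disparity(y_t, y_t_adv, mode='max')\n"
  },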
  {
    "path": "tllib/alignment/rsd.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport torch.nn as nn\r\nimport torch\r\n\r\n\r\nclass RepresentationSubspaceDistance(nn.Module):\r\n    \"\"\"\r\n    `Representation Subspace Distance (ICML 2021) <http://ise.thss.tsinghua.edu.cn/~mlong/doc/Representation-Subspace-Distance-for-Domain-Adaptation-Regression-icml21.pdf>`_\r\n\r\n    Args:\r\n        trade_off (float):  The trade-off value between Representation Subspace Distance\r\n            and Base Mismatch Penalization. Default: 0.1\r\n\r\n    Inputs:\r\n        - f_s (tensor): feature representations on source domain, :math:`f^s`\r\n        - f_t (tensor): feature representations on target domain, :math:`f^t`\r\n\r\n    \"\"\"\r\n    def __init__(self, trade_off=0.1):\r\n        super(RepresentationSubspaceDistance, self).__init__()\r\n        self.trade_off = trade_off\r\n\r\n    def forward(self, f_s, f_t):\r\n        U_s, _, _ = torch.svd(f_s.t())\r\n        U_t, _, _ = torch.svd(f_t.t())\r\n        P_s, cosine, P_t = torch.svd(torch.mm(U_s.t(), U_t))\r\n        sine = torch.sqrt(1 - torch.pow(cosine, 2))\r\n        rsd = torch.norm(sine, 1)  # Representation Subspace Distance\r\n        bmp = torch.norm(torch.abs(P_s) - torch.abs(P_t), 2)  # Base Mismatch Penalization\r\n        return rsd + self.trade_off * bmp"
  },
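  {
    "path": "examples_sketch/rsd_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for RepresentationSubspaceDistance above; not part\nof the library itself. Batch and feature sizes are arbitrary.\n\"\"\"\nimport torch\nfrom tllib.alignment.rsd import RepresentationSubspaceDistance\n\nrsd = RepresentationSubspaceDistance(trade_off=0.1)\nf_s = torch.randn(32, 256)  # source-domain features (minibatch, feature_dim)\nf_t = torch.randn(32, 256)  # target-domain features\ntransfer_loss = rsd(f_s, f_t)  # scalar, added to the task (regression) loss\n"
  },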
  {
    "path": "tllib/modules/__init__.py",
    "content": "from .classifier import *\nfrom .regressor import *\nfrom .grl import *\nfrom .domain_discriminator import *\nfrom .kernels import *\nfrom .entropy import *\n\n__all__ = ['classifier', 'regressor', 'grl', 'kernels', 'domain_discriminator', 'entropy']"
  },
  {
    "path": "tllib/modules/classifier.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Tuple, Optional, List, Dict\nimport torch.nn as nn\nimport torch\n\n__all__ = ['Classifier']\n\n\nclass Classifier(nn.Module):\n    \"\"\"A generic Classifier class for domain adaptation.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data\n        num_classes (int): Number of classes\n        bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1\n        head (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default\n        finetune (bool): Whether finetune the classifier or train from scratch. Default: True\n\n    .. note::\n        Different classifiers are used in different domain adaptation algorithms to achieve better accuracy\n        respectively, and we provide a suggested `Classifier` for different algorithms.\n        Remember they are not the core of algorithms. You can implement your own `Classifier` and combine it with\n        the domain adaptation algorithm in this algorithm library.\n\n    .. note::\n        The learning rate of this classifier is set 10 times to that of the feature extractor for better accuracy\n        by default. If you have other optimization strategies, please over-ride :meth:`~Classifier.get_parameters`.\n\n    Inputs:\n        - x (tensor): input data fed to `backbone`\n\n    Outputs:\n        - predictions: classifier's predictions\n        - features: features after `bottleneck` layer and before `head` layer\n\n    Shape:\n        - Inputs: (minibatch, *) where * means, any number of additional dimensions\n        - predictions: (minibatch, `num_classes`)\n        - features: (minibatch, `features_dim`)\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,\n                 bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, finetune=True, pool_layer=None):\n        super(Classifier, self).__init__()\n        self.backbone = backbone\n        self.num_classes = num_classes\n        if pool_layer is None:\n            self.pool_layer = nn.Sequential(\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            )\n        else:\n            self.pool_layer = pool_layer\n        if bottleneck is None:\n            self.bottleneck = nn.Identity()\n            self._features_dim = backbone.out_features\n        else:\n            self.bottleneck = bottleneck\n            assert bottleneck_dim > 0\n            self._features_dim = bottleneck_dim\n\n        if head is None:\n            self.head = nn.Linear(self._features_dim, num_classes)\n        else:\n            self.head = head\n        self.finetune = finetune\n\n    @property\n    def features_dim(self) -> int:\n        \"\"\"The dimension of features before the final `head` layer\"\"\"\n        return self._features_dim\n\n    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\"\"\"\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        predictions = self.head(f)\n        if self.training:\n            return predictions, f\n        else:\n            return predictions\n\n    def get_parameters(self, base_lr=1.0) -> List[Dict]:\n        \"\"\"A parameter list which decides 
optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr},\n        ]\n\n        return params\n\n\nclass ImageClassifier(Classifier):\n    pass\n"
  },
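  {
    "path": "examples_sketch/classifier_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for the generic Classifier; not part of the library\nitself. Assumes `tllib.vision.models.resnet18` exposes the `out_features`\nattribute that Classifier relies on.\n\"\"\"\nimport torch\nfrom torch.optim import SGD\nfrom tllib.modules.classifier import Classifier\nfrom tllib.vision.models import resnet18\n\nbackbone = resnet18(pretrained=True)\nclassifier = Classifier(backbone, num_classes=31)\n\n# each parameter group carries its own lr: 0.1 * base_lr for the finetuned\n# backbone, 1.0 * base_lr for the bottleneck and head\noptimizer = SGD(classifier.get_parameters(base_lr=0.01), lr=0.01, momentum=0.9)\n\nclassifier.train()\nx = torch.randn(8, 3, 224, 224)\npredictions, features = classifier(x)  # eval mode returns predictions only\n"
  },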
  {
    "path": "tllib/modules/domain_discriminator.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import List, Dict\nimport torch.nn as nn\n\n__all__ = ['DomainDiscriminator']\n\n\nclass DomainDiscriminator(nn.Sequential):\n    r\"\"\"Domain discriminator model from\n    `Domain-Adversarial Training of Neural Networks (ICML 2015) <https://arxiv.org/abs/1505.07818>`_\n\n    Distinguish whether the input features come from the source domain or the target domain.\n    The source domain label is 1 and the target domain label is 0.\n\n    Args:\n        in_feature (int): dimension of the input feature\n        hidden_size (int): dimension of the hidden features\n        batch_norm (bool): whether use :class:`~torch.nn.BatchNorm1d`.\n            Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True.\n\n    Shape:\n        - Inputs: (minibatch, `in_feature`)\n        - Outputs: :math:`(minibatch, 1)`\n    \"\"\"\n\n    def __init__(self, in_feature: int, hidden_size: int, batch_norm=True, sigmoid=True):\n        if sigmoid:\n            final_layer = nn.Sequential(\n                nn.Linear(hidden_size, 1),\n                nn.Sigmoid()\n            )\n        else:\n            final_layer = nn.Linear(hidden_size, 2)\n        if batch_norm:\n            super(DomainDiscriminator, self).__init__(\n                nn.Linear(in_feature, hidden_size),\n                nn.BatchNorm1d(hidden_size),\n                nn.ReLU(),\n                nn.Linear(hidden_size, hidden_size),\n                nn.BatchNorm1d(hidden_size),\n                nn.ReLU(),\n                final_layer\n            )\n        else:\n            super(DomainDiscriminator, self).__init__(\n                nn.Linear(in_feature, hidden_size),\n                nn.ReLU(inplace=True),\n                nn.Dropout(0.5),\n                nn.Linear(hidden_size, hidden_size),\n                nn.ReLU(inplace=True),\n                nn.Dropout(0.5),\n                final_layer\n            )\n\n    def get_parameters(self) -> List[Dict]:\n        return [{\"params\": self.parameters(), \"lr\": 1.}]\n\n\n"
  },
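  {
    "path": "examples_sketch/domain_adversarial_usage.py",
    "content": "\"\"\"\nHypothetical sketch of a domain-adversarial loss built from DomainDiscriminator\nand GradientReverseLayer; not part of the library itself.\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom tllib.modules.domain_discriminator import DomainDiscriminator\nfrom tllib.modules.grl import GradientReverseLayer\n\ngrl = GradientReverseLayer()\ndiscriminator = DomainDiscriminator(in_feature=256, hidden_size=1024)\nbce = nn.BCELoss()\n\nf_s = torch.randn(32, 256, requires_grad=True)  # source features\nf_t = torch.randn(32, 256, requires_grad=True)  # target features\n\n# the GRL negates gradients, so minimizing this discriminator loss w.r.t. the\n# feature extractor maximizes domain confusion\nd_s = discriminator(grl(f_s))\nd_t = discriminator(grl(f_t))\nloss = bce(d_s, torch.ones_like(d_s)) + bce(d_t, torch.zeros_like(d_t))\nloss.backward()\n"
  },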
  {
    "path": "tllib/modules/entropy.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch\n\n\ndef entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor:\n    r\"\"\"Entropy of prediction.\n    The definition is:\n\n    .. math::\n        entropy(p) = - \\sum_{c=1}^C p_c \\log p_c\n\n    where C is number of classes.\n\n    Args:\n        predictions (tensor): Classifier predictions. Expected to contain raw, normalized scores for each class\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output. Default: ``'mean'``\n\n    Shape:\n        - predictions: :math:`(minibatch, C)` where C means the number of classes.\n        - Output: :math:`(minibatch, )` by default. If :attr:`reduction` is ``'mean'``, then scalar.\n    \"\"\"\n    epsilon = 1e-5\n    H = -predictions * torch.log(predictions + epsilon)\n    H = H.sum(dim=1)\n    if reduction == 'mean':\n        return H.mean()\n    else:\n        return H\n"
  },
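  {
    "path": "examples_sketch/entropy_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for entropy(); not part of the library itself. The\nfunction expects probabilities, so apply softmax to raw logits first.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nfrom tllib.modules.entropy import entropy\n\nlogits = torch.randn(32, 10)\nprobabilities = F.softmax(logits, dim=1)\nh = entropy(probabilities, reduction='mean')  # scalar, e.g. for entropy minimization\n"
  },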
  {
    "path": "tllib/modules/gl.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, Any, Tuple\nimport numpy as np\nimport torch.nn as nn\nfrom torch.autograd import Function\nimport torch\n\n\nclass GradientFunction(Function):\n\n    @staticmethod\n    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:\n        ctx.coeff = coeff\n        output = input * 1.0\n        return output\n\n    @staticmethod\n    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:\n        return grad_output * ctx.coeff, None\n\n\nclass WarmStartGradientLayer(nn.Module):\n    \"\"\"Warm Start Gradient Layer :math:`\\mathcal{R}(x)` with warm start\n\n        The forward and backward behaviours are:\n\n        .. math::\n            \\mathcal{R}(x) = x,\n\n            \\dfrac{ d\\mathcal{R}} {dx} = \\lambda I.\n\n        :math:`\\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:\n\n        .. math::\n            \\lambda = \\dfrac{2(hi-lo)}{1+\\exp(- α \\dfrac{i}{N})} - (hi-lo) + lo\n\n        where :math:`i` is the iteration step.\n\n        Parameters:\n            - **alpha** (float, optional): :math:`α`. Default: 1.0\n            - **lo** (float, optional): Initial value of :math:`\\lambda`. Default: 0.0\n            - **hi** (float, optional): Final value of :math:`\\lambda`. Default: 1.0\n            - **max_iters** (int, optional): :math:`N`. Default: 1000\n            - **auto_step** (bool, optional): If True, increase :math:`i` each time `forward` is called.\n              Otherwise use function `step` to increase :math:`i`. Default: False\n        \"\"\"\n\n    def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,\n                 max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False):\n        super(WarmStartGradientLayer, self).__init__()\n        self.alpha = alpha\n        self.lo = lo\n        self.hi = hi\n        self.iter_num = 0\n        self.max_iters = max_iters\n        self.auto_step = auto_step\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        \"\"\"\"\"\"\n        coeff = np.float(\n            2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))\n            - (self.hi - self.lo) + self.lo\n        )\n        if self.auto_step:\n            self.step()\n        return GradientFunction.apply(input, coeff)\n\n    def step(self):\n        \"\"\"Increase iteration number :math:`i` by 1\"\"\"\n        self.iter_num += 1\n"
  },
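  {
    "path": "examples_sketch/warm_start_gl_usage.py",
    "content": "\"\"\"\nHypothetical sketch of the WarmStartGradientLayer schedule; not part of the\nlibrary itself. The forward pass is the identity; only gradients are scaled.\n\"\"\"\nimport torch\nfrom tllib.modules.gl import WarmStartGradientLayer\n\ngl = WarmStartGradientLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=True)\nx = torch.randn(4, 8, requires_grad=True)\ny = gl(x)           # identity in the forward pass\ny.sum().backward()\n# at iteration 0 lambda equals lo (0.0), so the gradient here is zero; the\n# coefficient approaches hi as more forward passes are taken\nprint(x.grad.abs().max())\n"
  },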
  {
    "path": "tllib/modules/grl.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, Any, Tuple\nimport numpy as np\nimport torch.nn as nn\nfrom torch.autograd import Function\nimport torch\n\n\nclass GradientReverseFunction(Function):\n\n    @staticmethod\n    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:\n        ctx.coeff = coeff\n        output = input * 1.0\n        return output\n\n    @staticmethod\n    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:\n        return grad_output.neg() * ctx.coeff, None\n\n\nclass GradientReverseLayer(nn.Module):\n    def __init__(self):\n        super(GradientReverseLayer, self).__init__()\n\n    def forward(self, *input):\n        return GradientReverseFunction.apply(*input)\n\n\nclass WarmStartGradientReverseLayer(nn.Module):\n    \"\"\"Gradient Reverse Layer :math:`\\mathcal{R}(x)` with warm start\n\n        The forward and backward behaviours are:\n\n        .. math::\n            \\mathcal{R}(x) = x,\n\n            \\dfrac{ d\\mathcal{R}} {dx} = - \\lambda I.\n\n        :math:`\\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:\n\n        .. math::\n            \\lambda = \\dfrac{2(hi-lo)}{1+\\exp(- α \\dfrac{i}{N})} - (hi-lo) + lo\n\n        where :math:`i` is the iteration step.\n\n        Args:\n            alpha (float, optional): :math:`α`. Default: 1.0\n            lo (float, optional): Initial value of :math:`\\lambda`. Default: 0.0\n            hi (float, optional): Final value of :math:`\\lambda`. Default: 1.0\n            max_iters (int, optional): :math:`N`. Default: 1000\n            auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called.\n              Otherwise use function `step` to increase :math:`i`. Default: False\n        \"\"\"\n\n    def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,\n                 max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False):\n        super(WarmStartGradientReverseLayer, self).__init__()\n        self.alpha = alpha\n        self.lo = lo\n        self.hi = hi\n        self.iter_num = 0\n        self.max_iters = max_iters\n        self.auto_step = auto_step\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        \"\"\"\"\"\"\n        coeff = np.float(\n            2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))\n            - (self.hi - self.lo) + self.lo\n        )\n        if self.auto_step:\n            self.step()\n        return GradientReverseFunction.apply(input, coeff)\n\n    def step(self):\n        \"\"\"Increase iteration number :math:`i` by 1\"\"\"\n        self.iter_num += 1\n"
  },
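  {
    "path": "examples_sketch/grl_usage.py",
    "content": "\"\"\"\nHypothetical check that GradientReverseLayer negates gradients; not part of\nthe library itself.\n\"\"\"\nimport torch\nfrom tllib.modules.grl import GradientReverseLayer\n\ngrl = GradientReverseLayer()\nx = torch.ones(3, requires_grad=True)\ny = grl(x) * 2.0    # the forward pass is the identity\ny.sum().backward()\nprint(x.grad)       # tensor([-2., -2., -2.]): the upstream gradient, negated\n"
  },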
  {
    "path": "tllib/modules/kernels.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\n\n\n__all__ = ['GaussianKernel']\n\n\nclass GaussianKernel(nn.Module):\n    r\"\"\"Gaussian Kernel Matrix\n\n    Gaussian Kernel k is defined by\n\n    .. math::\n        k(x_1, x_2) = \\exp \\left( - \\dfrac{\\| x_1 - x_2 \\|^2}{2\\sigma^2} \\right)\n\n    where :math:`x_1, x_2 \\in R^d` are 1-d tensors.\n\n    Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),`\n\n    .. math::\n        K(X)_{i,j} = k(x_i, x_j)\n\n    Also by default, during training this layer keeps running estimates of the\n    mean of L2 distances, which are then used to set hyperparameter  :math:`\\sigma`.\n    Mathematically, the estimation is :math:`\\sigma^2 = \\dfrac{\\alpha}{n^2}\\sum_{i,j} \\| x_i - x_j \\|^2`.\n    If :attr:`track_running_stats` is set to ``False``, this layer then does not\n    keep running estimates, and use a fixed :math:`\\sigma` instead.\n\n    Args:\n        sigma (float, optional): bandwidth :math:`\\sigma`. Default: None\n        track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\\sigma^2`.\n          Otherwise, it won't track such statistics and always uses fix :math:`\\sigma^2`. Default: ``True``\n        alpha (float, optional): :math:`\\alpha` which decides the magnitude of :math:`\\sigma^2` when track_running_stats is set to ``True``\n\n    Inputs:\n        - X (tensor): input group :math:`X`\n\n    Shape:\n        - Inputs: :math:`(minibatch, F)` where F means the dimension of input features.\n        - Outputs: :math:`(minibatch, minibatch)`\n    \"\"\"\n\n    def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,\n                 alpha: Optional[float] = 1.):\n        super(GaussianKernel, self).__init__()\n        assert track_running_stats or sigma is not None\n        self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None\n        self.track_running_stats = track_running_stats\n        self.alpha = alpha\n\n    def forward(self, X: torch.Tensor) -> torch.Tensor:\n        l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)\n\n        if self.track_running_stats:\n            self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach())\n\n        return torch.exp(-l2_distance_square / (2 * self.sigma_square))"
  },
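  {
    "path": "examples_sketch/gaussian_kernel_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for GaussianKernel (e.g. as one kernel inside an\nMMD-style loss); not part of the library itself.\n\"\"\"\nimport torch\nfrom tllib.modules.kernels import GaussianKernel\n\nkernel = GaussianKernel(alpha=1.)  # sigma is estimated from the batch by default\nX = torch.randn(16, 256)\nK = kernel(X)  # (16, 16) kernel matrix\nassert torch.allclose(K.diagonal(), torch.ones(16))  # k(x, x) = 1\n"
  },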
  {
    "path": "tllib/modules/loss.py",
    "content": "import torch.nn as nn\r\nimport torch\r\nimport torch.nn.functional as F\r\n\r\n\r\n# version 1: use torch.autograd\r\nclass LabelSmoothSoftmaxCEV1(nn.Module):\r\n    '''\r\n    Adapted from https://github.com/CoinCheung/pytorch-loss\r\n    '''\r\n\r\n    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-1):\r\n        super(LabelSmoothSoftmaxCEV1, self).__init__()\r\n        self.lb_smooth = lb_smooth\r\n        self.reduction = reduction\r\n        self.lb_ignore = ignore_index\r\n        self.log_softmax = nn.LogSoftmax(dim=1)\r\n\r\n    def forward(self, input, target):\r\n        '''\r\n        Same usage method as nn.CrossEntropyLoss:\r\n            >>> criteria = LabelSmoothSoftmaxCEV1()\r\n            >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half\r\n            >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t\r\n            >>> loss = criteria(logits, lbs)\r\n        '''\r\n        # overcome ignored label\r\n        logits = input.float() # use fp32 to avoid nan\r\n        with torch.no_grad():\r\n            num_classes = logits.size(1)\r\n            label = target.clone().detach()\r\n            ignore = label.eq(self.lb_ignore)\r\n            n_valid = ignore.eq(0).sum()\r\n            label[ignore] = 0\r\n            lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes\r\n            lb_one_hot = torch.empty_like(logits).fill_(\r\n                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()\r\n\r\n        logs = self.log_softmax(logits)\r\n        loss = -torch.sum(logs * lb_one_hot, dim=1)\r\n        loss[ignore] = 0\r\n        if self.reduction == 'mean':\r\n            loss = loss.sum() / n_valid\r\n        if self.reduction == 'sum':\r\n            loss = loss.sum()\r\n\r\n        return loss\r\n\r\n\r\nclass KnowledgeDistillationLoss(nn.Module):\r\n    \"\"\"Knowledge Distillation Loss.\r\n\r\n    Args:\r\n        T (double): Temperature. Default: 1.\r\n        reduction (str, optional): Specifies the reduction to apply to the output:\r\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\r\n          ``'mean'``: the sum of the output will be divided by the number of\r\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'batchmean'``\r\n\r\n    Inputs:\r\n        - y_student (tensor): logits output of the student\r\n        - y_teacher (tensor): logits output of the teacher\r\n\r\n    Shape:\r\n        - y_student: (minibatch, `num_classes`)\r\n        - y_teacher: (minibatch, `num_classes`)\r\n\r\n    \"\"\"\r\n    def __init__(self, T=1., reduction='batchmean'):\r\n        super(KnowledgeDistillationLoss, self).__init__()\r\n        self.T = T\r\n        self.kl = nn.KLDivLoss(reduction=reduction)\r\n\r\n    def forward(self, y_student, y_teacher):\r\n        \"\"\"\"\"\"\r\n        return self.kl(F.log_softmax(y_student / self.T, dim=-1), F.softmax(y_teacher / self.T, dim=-1))\r\n"
  },
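  {
    "path": "examples_sketch/kd_loss_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for KnowledgeDistillationLoss; not part of the\nlibrary itself. The common T^2 gradient rescaling from Hinton et al. (2015) is\nleft to the caller, so it is shown explicitly here.\n\"\"\"\nimport torch\nfrom tllib.modules.loss import KnowledgeDistillationLoss\n\nT = 4.\nkd = KnowledgeDistillationLoss(T=T)\ny_student = torch.randn(32, 10)  # student logits\ny_teacher = torch.randn(32, 10)  # teacher logits\nloss = kd(y_student, y_teacher) * T ** 2\n"
  },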
  {
    "path": "tllib/modules/regressor.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Tuple, Optional, List, Dict\nimport torch.nn as nn\nimport torch\n\n__all__ = ['Regressor']\n\n\nclass Regressor(nn.Module):\n    \"\"\"A generic Regressor class for domain adaptation.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data\n        num_factors (int): Number of factors\n        bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1\n        head (torch.nn.Module, optional): Any classifier head. Use `nn.Linear` by default\n        finetune (bool): Whether finetune the classifier or train from scratch. Default: True\n\n    .. note::\n        The learning rate of this regressor is set 10 times to that of the feature extractor for better accuracy\n        by default. If you have other optimization strategies, please over-ride :meth:`~Regressor.get_parameters`.\n\n    Inputs:\n        - x (tensor): input data fed to `backbone`\n\n    Outputs:\n        - predictions: regressor's predictions\n        - features: features after `bottleneck` layer and before `head` layer\n\n    Shape:\n        - Inputs: (minibatch, *) where * means, any number of additional dimensions\n        - predictions: (minibatch, `num_factors`)\n        - features: (minibatch, `features_dim`)\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_factors: int, bottleneck: Optional[nn.Module] = None,\n                 bottleneck_dim=-1, head: Optional[nn.Module] = None, finetune=True):\n        super(Regressor, self).__init__()\n        self.backbone = backbone\n        self.num_factors = num_factors\n        if bottleneck is None:\n            feature_dim = backbone.out_features\n            self.bottleneck = nn.Sequential(\n                nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1),\n                nn.BatchNorm2d(feature_dim, feature_dim),\n                nn.ReLU(),\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            )\n            self._features_dim = feature_dim\n        else:\n            self.bottleneck = bottleneck\n            assert bottleneck_dim > 0\n            self._features_dim = bottleneck_dim\n\n        if head is None:\n            self.head = nn.Sequential(\n                nn.Linear(self._features_dim, num_factors),\n                nn.Sigmoid()\n            )\n        else:\n            self.head = head\n        self.finetune = finetune\n\n    @property\n    def features_dim(self) -> int:\n        \"\"\"The dimension of features before the final `head` layer\"\"\"\n        return self._features_dim\n\n    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\"\"\"\n        f = self.backbone(x)\n        f = self.bottleneck(f)\n        predictions = self.head(f)\n        if self.training:\n            return predictions, f\n        else:\n            return predictions\n\n    def get_parameters(self, base_lr=1.0) -> List[Dict]:\n        \"\"\"A parameter list which decides optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * 
base_lr},\n            {\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr},\n        ]\n\n        return params\n\n\n"
  },
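  {
    "path": "examples_sketch/regressor_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for the generic Regressor; not part of the library\nitself. The default head ends in a Sigmoid, so targets are expected in [0, 1];\nthe backbone must output a 4-d feature map because the default bottleneck\napplies a convolution before pooling.\n\"\"\"\nimport torch\nfrom tllib.modules.regressor import Regressor\nfrom tllib.vision.models import resnet18\n\nbackbone = resnet18(pretrained=True)\nregressor = Regressor(backbone, num_factors=3)\n\nregressor.train()\nx = torch.randn(8, 3, 224, 224)\npredictions, features = regressor(x)  # (8, 3) and (8, backbone.out_features)\n"
  },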
  {
    "path": "tllib/normalization/__init__.py",
    "content": ""
  },
  {
    "path": "tllib/normalization/afn.py",
    "content": "\"\"\"\nModified from https://github.com/jihanyang/AFN\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom typing import Optional, List, Dict\nimport torch\nimport torch.nn as nn\nimport math\n\nfrom tllib.modules.classifier import Classifier as ClassfierBase\n\n\nclass AdaptiveFeatureNorm(nn.Module):\n    r\"\"\"\n    The `Stepwise Adaptive Feature Norm loss (ICCV 2019) <https://arxiv.org/pdf/1811.07456v2.pdf>`_\n\n    Instead of using restrictive scalar R to match the corresponding feature norm, Stepwise Adaptive Feature Norm\n    is used in order to learn task-specific features with large norms in a progressive manner.\n    We denote parameters of backbone :math:`G` as :math:`\\theta_g`, parameters of bottleneck :math:`F_f` as :math:`\\theta_f`\n    , parameters of classifier head :math:`F_y` as :math:`\\theta_y`, and features extracted from sample :math:`x_i` as\n    :math:`h(x_i;\\theta)`. Full loss is calculated as follows\n\n    .. math::\n        L(\\theta_g,\\theta_f,\\theta_y)=\\frac{1}{n_s}\\sum_{(x_i,y_i)\\in D_s}L_y(x_i,y_i)+\\frac{\\lambda}{n_s+n_t}\n        \\sum_{x_i\\in D_s\\cup D_t}L_d(h(x_i;\\theta_0)+\\Delta_r,h(x_i;\\theta))\\\\\n\n    where :math:`L_y` denotes classification loss, :math:`L_d` denotes norm loss, :math:`\\theta_0` and :math:`\\theta`\n    represent the updated and updating model parameters in the last and current iterations respectively.\n\n    Args:\n        delta (float): positive residual scalar to control the feature norm enlargement.\n\n    Inputs:\n        - f (tensor): feature representations on source or target domain.\n\n    Shape:\n        - f: :math:`(N, F)` where F means the dimension of input features.\n        - Outputs: scalar.\n\n    Examples::\n\n        >>> adaptive_feature_norm = AdaptiveFeatureNorm(delta=1)\n        >>> f_s = torch.randn(32, 1000)\n        >>> f_t = torch.randn(32, 1000)\n        >>> norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)\n    \"\"\"\n\n    def __init__(self, delta):\n        super(AdaptiveFeatureNorm, self).__init__()\n        self.delta = delta\n\n    def forward(self, f: torch.Tensor) -> torch.Tensor:\n        radius = f.norm(p=2, dim=1).detach()\n        assert radius.requires_grad == False\n        radius = radius + self.delta\n        loss = ((f.norm(p=2, dim=1) - radius) ** 2).mean()\n        return loss\n\n\nclass Block(nn.Module):\n    r\"\"\"\n    Basic building block for Image Classifier with structure: FC-BN-ReLU-Dropout.\n    We use :math:`L_2` preserved dropout layers.\n    Given mask probability :math:`p`, input :math:`x_k`, generated mask :math:`a_k`,\n    vanilla dropout layers calculate\n\n    .. math::\n        \\hat{x}_k = a_k\\frac{1}{1-p}x_k\\\\\n\n    While in :math:`L_2` preserved dropout layers\n\n    .. math::\n        \\hat{x}_k = a_k\\frac{1}{\\sqrt{1-p}}x_k\\\\\n\n    Args:\n        in_features (int): Dimension of input features\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1000\n        dropout_p (float, optional): dropout probability. 
Default: 0.5\n    \"\"\"\n\n    def __init__(self, in_features: int, bottleneck_dim: Optional[int] = 1000, dropout_p: Optional[float] = 0.5):\n        super(Block, self).__init__()\n        self.fc = nn.Linear(in_features, bottleneck_dim)\n        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)\n        self.relu = nn.ReLU(inplace=True)\n        self.dropout = nn.Dropout(dropout_p)\n        self.dropout_p = dropout_p\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        f = self.fc(x)\n        f = self.bn(f)\n        f = self.relu(f)\n        f = self.dropout(f)\n        if self.training:\n            f.mul_(math.sqrt(1 - self.dropout_p))\n        return f\n\n\nclass ImageClassifier(ClassifierBase):\n    r\"\"\"\n    ImageClassifier for AFN.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data\n        num_classes (int): Number of classes\n        num_blocks (int, optional): Number of basic blocks. Default: 1\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1000\n        dropout_p (float, optional): dropout probability. Default: 0.5\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, num_blocks: Optional[int] = 1,\n                 bottleneck_dim: Optional[int] = 1000, dropout_p: Optional[float] = 0.5, **kwargs):\n        assert num_blocks >= 1\n        layers = [nn.Sequential(\n            Block(backbone.out_features, bottleneck_dim, dropout_p)\n        )]\n        for _ in range(num_blocks - 1):\n            layers.append(Block(bottleneck_dim, bottleneck_dim, dropout_p))\n        bottleneck = nn.Sequential(*layers)\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n        # init parameters for bottleneck and head\n        for m in self.bottleneck.modules():\n            if isinstance(m, nn.BatchNorm1d):\n                m.weight.data.normal_(1.0, 0.01)\n                m.bias.data.fill_(0)\n            if isinstance(m, nn.Linear):\n                m.weight.data.normal_(0.0, 0.01)\n                m.bias.data.normal_(0.0, 0.01)\n        for m in self.head.modules():\n            if isinstance(m, nn.Linear):\n                m.weight.data.normal_(0.0, 0.01)\n                m.bias.data.normal_(0.0, 0.01)\n\n    def get_parameters(self, base_lr=1.0) -> List[Dict]:\n        params = [\n            {\"params\": self.backbone.parameters()},\n            {\"params\": self.bottleneck.parameters(), \"momentum\": 0.9},\n            {\"params\": self.head.parameters(), \"momentum\": 0.9},\n        ]\n        return params\n"
  },
  {
    "path": "tllib/normalization/ibn.py",
    "content": "\"\"\"\nModified from https://github.com/XingangPan/IBN-Net\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport math\nimport torch\nimport torch.nn as nn\n\n__all__ = ['resnet18_ibn_a', 'resnet18_ibn_b', 'resnet34_ibn_a', 'resnet34_ibn_b', 'resnet50_ibn_a', 'resnet50_ibn_b',\n           'resnet101_ibn_a', 'resnet101_ibn_b']\n\nmodel_urls = {\n    'resnet18_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',\n    'resnet34_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',\n    'resnet50_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',\n    'resnet101_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',\n    'resnet18_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_b-bc2f3c11.pth',\n    'resnet34_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_b-04134c37.pth',\n    'resnet50_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_b-9ca61e85.pth',\n    'resnet101_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_b-c55f6dba.pth',\n}\n\n\nclass InstanceBatchNorm2d(nn.Module):\n    r\"\"\"Instance-Batch Normalization layer from\n    `Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (ECCV 2018)\n    <https://arxiv.org/pdf/1807.09441.pdf>`_.\n\n    Given input feature map :math:`f\\_input` of dimension :math:`(C,H,W)`, we first split :math:`f\\_input` into\n    two parts along `channel` dimension. They are denoted as :math:`f_1` of dimension :math:`(C_1,H,W)` and\n    :math:`f_2` of dimension :math:`(C_2,H,W)`, where :math:`C_1+C_2=C`. Then we pass :math:`f_1` and :math:`f_2`\n    through IN and BN layer, respectively, to get :math:`IN(f_1)` and :math:`BN(f_2)`. 
Finally, we concatenate them along the\n    `channel` dimension to get :math:`f\\_output=concat(IN(f_1), BN(f_2))`.\n\n    Args:\n        planes (int): Number of channels for the input tensor\n        ratio (float): Ratio of instance normalization in the IBN layer\n    \"\"\"\n\n    def __init__(self, planes, ratio=0.5):\n        super(InstanceBatchNorm2d, self).__init__()\n        self.half = int(planes * ratio)\n        self.IN = nn.InstanceNorm2d(self.half, affine=True)\n        self.BN = nn.BatchNorm2d(planes - self.half)\n\n    def forward(self, x):\n        split = torch.split(x, self.half, 1)\n        out1 = self.IN(split[0].contiguous())\n        out2 = self.BN(split[1].contiguous())\n        out = torch.cat((out1, out2), 1)\n        return out\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        if ibn == 'a':\n            self.bn1 = InstanceBatchNorm2d(planes)\n        else:\n            self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.IN = nn.InstanceNorm2d(planes, affine=True) if ibn == 'b' else None\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        if self.IN is not None:\n            out = self.IN(out)\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        if ibn == 'a':\n            self.bn1 = InstanceBatchNorm2d(planes)\n        else:\n            self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.IN = nn.InstanceNorm2d(planes * 4, affine=True) if ibn == 'b' else None\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        if self.IN is not None:\n            out = self.IN(out)\n        out = self.relu(out)\n\n        return out\n\n\nclass IBNNet(nn.Module):\n    r\"\"\"\n    IBNNet without fully connected layer\n    \"\"\"\n\n    def __init__(self, block, layers, 
ibn_cfg=('a', 'a', 'a', None)):\n        self.inplanes = 64\n        super(IBNNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        if ibn_cfg[0] == 'b':\n            self.bn1 = nn.InstanceNorm2d(64, affine=True)\n        else:\n            self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])\n        self._out_features = 512 * block.expansion\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1, ibn=None):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes,\n                            None if ibn == 'b' else ibn,\n                            stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes,\n                                None if (ibn == 'b' and i < blocks - 1) else ibn))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        \"\"\"\"\"\"\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        return x\n\n    @property\n    def out_features(self) -> int:\n        \"\"\"The dimension of output features\"\"\"\n        return self._out_features\n\n\ndef resnet18_ibn_a(pretrained=False):\n    \"\"\"Constructs a ResNet-18-IBN-a model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=BasicBlock,\n                   layers=[2, 2, 2, 2],\n                   ibn_cfg=('a', 'a', 'a', None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet18_ibn_a']), strict=False)\n    return model\n\n\ndef resnet34_ibn_a(pretrained=False):\n    \"\"\"Constructs a ResNet-34-IBN-a model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=BasicBlock,\n                   layers=[3, 4, 6, 3],\n                   ibn_cfg=('a', 'a', 'a', None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet34_ibn_a']), strict=False)\n    return model\n\n\ndef resnet50_ibn_a(pretrained=False):\n    
\"\"\"Constructs a ResNet-50-IBN-a model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=Bottleneck,\n                   layers=[3, 4, 6, 3],\n                   ibn_cfg=('a', 'a', 'a', None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_a']), strict=False)\n    return model\n\n\ndef resnet101_ibn_a(pretrained=False):\n    \"\"\"Constructs a ResNet-101-IBN-a model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=Bottleneck,\n                   layers=[3, 4, 23, 3],\n                   ibn_cfg=('a', 'a', 'a', None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet101_ibn_a']), strict=False)\n    return model\n\n\ndef resnet18_ibn_b(pretrained=False):\n    \"\"\"Constructs a ResNet-18-IBN-b model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=BasicBlock,\n                   layers=[2, 2, 2, 2],\n                   ibn_cfg=('b', 'b', None, None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet18_ibn_b']), strict=False)\n    return model\n\n\ndef resnet34_ibn_b(pretrained=False):\n    \"\"\"Constructs a ResNet-34-IBN-b model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=BasicBlock,\n                   layers=[3, 4, 6, 3],\n                   ibn_cfg=('b', 'b', None, None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet34_ibn_b']), strict=False)\n    return model\n\n\ndef resnet50_ibn_b(pretrained=False):\n    \"\"\"Constructs a ResNet-50-IBN-b model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=Bottleneck,\n                   layers=[3, 4, 6, 3],\n                   ibn_cfg=('b', 'b', None, None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_b']), strict=False)\n    return model\n\n\ndef resnet101_ibn_b(pretrained=False):\n    \"\"\"Constructs a ResNet-101-IBN-b model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = IBNNet(block=Bottleneck,\n                   layers=[3, 4, 23, 3],\n                   ibn_cfg=('b', 'b', None, None))\n    if pretrained:\n        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet101_ibn_b']), strict=False)\n    return model\n"
  },
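  {
    "path": "examples_sketch/ibn_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for the IBN-Net constructors; not part of the\nlibrary itself. The models return feature maps without a classification head.\n\"\"\"\nimport torch\nfrom tllib.normalization.ibn import resnet50_ibn_a\n\nmodel = resnet50_ibn_a(pretrained=False)\nx = torch.randn(2, 3, 224, 224)\nfeatures = model(x)        # (2, 2048, 7, 7)\nprint(model.out_features)  # 2048\n"
  },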
  {
    "path": "tllib/normalization/mixstyle/__init__.py",
    "content": "\"\"\"\nModified from https://github.com/KaiyangZhou/mixstyle-release\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport torch\nimport torch.nn as nn\n\n\nclass MixStyle(nn.Module):\n    r\"\"\"MixStyle module from `DOMAIN GENERALIZATION WITH MIXSTYLE (ICLR 2021) <https://arxiv.org/pdf/2104.02008v1.pdf>`_.\n    Given input :math:`x`, we first compute mean :math:`\\mu(x)` and standard deviation :math:`\\sigma(x)` across spatial\n    dimension. Then we permute :math:`x` and get :math:`\\tilde{x}`, corresponding mean :math:`\\mu(\\tilde{x})` and\n    standard deviation :math:`\\sigma(\\tilde{x})`. `MixUp` is performed using mean and standard deviation\n\n    .. math::\n        \\gamma_{mix} = \\lambda\\sigma(x) + (1-\\lambda)\\sigma(\\tilde{x})\n\n    .. math::\n        \\beta_{mix} = \\lambda\\mu(x) + (1-\\lambda)\\mu(\\tilde{x})\n\n    where :math:`\\lambda` is instance-wise weight sampled from `Beta distribution`. MixStyle is then\n\n    .. math::\n        MixStyle(x) = \\gamma_{mix}\\frac{x-\\mu(x)}{\\sigma(x)} + \\beta_{mix}\n\n    Args:\n          p (float): probability of using MixStyle.\n          alpha (float): parameter of the `Beta distribution`.\n          eps (float): scaling parameter to avoid numerical issues.\n    \"\"\"\n\n    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):\n        super().__init__()\n        self.p = p\n        self.beta = torch.distributions.Beta(alpha, alpha)\n        self.eps = eps\n        self.alpha = alpha\n\n    def forward(self, x):\n        if not self.training:\n            return x\n\n        if random.random() > self.p:\n            return x\n\n        batch_size = x.size(0)\n\n        mu = x.mean(dim=[2, 3], keepdim=True)\n        var = x.var(dim=[2, 3], keepdim=True)\n        sigma = (var + self.eps).sqrt()\n        mu, sigma = mu.detach(), sigma.detach()\n        x_normed = (x - mu) / sigma\n\n        interpolation = self.beta.sample((batch_size, 1, 1, 1))\n        interpolation = interpolation.to(x.device)\n\n        # split into two halves and swap the order\n        perm = torch.arange(batch_size - 1, -1, -1)  # inverse index\n        perm_b, perm_a = perm.chunk(2)\n        perm_b = perm_b[torch.randperm(batch_size // 2)]\n        perm_a = perm_a[torch.randperm(batch_size // 2)]\n        perm = torch.cat([perm_b, perm_a], 0)\n\n        mu_perm, sigma_perm = mu[perm], sigma[perm]\n        mu_mix = mu * interpolation + mu_perm * (1 - interpolation)\n        sigma_mix = sigma * interpolation + sigma_perm * (1 - interpolation)\n\n        return x_normed * sigma_mix + mu_mix\n"
  },
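  {
    "path": "examples_sketch/mixstyle_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for the MixStyle module; not part of the library\nitself. MixStyle is the identity in eval mode and, with probability 1 - p, in\ntrain mode as well; an even batch size is assumed by the half-and-half\npermutation.\n\"\"\"\nimport torch\nfrom tllib.normalization.mixstyle import MixStyle\n\nmixstyle = MixStyle(p=0.5, alpha=0.1)\nx = torch.randn(8, 64, 32, 32)  # an intermediate feature map\n\nmixstyle.train()\nout = mixstyle(x)  # style statistics possibly mixed across the batch\n\nmixstyle.eval()\nassert torch.equal(mixstyle(x), x)  # identity at inference time\n"
  },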
  {
    "path": "tllib/normalization/mixstyle/resnet.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom . import MixStyle\nfrom tllib.vision.models.reid.resnet import ReidResNet\nfrom tllib.vision.models.resnet import ResNet, load_state_dict_from_url, model_urls, BasicBlock, Bottleneck\n\n__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101']\n\n\ndef _resnet_with_mix_style(arch, block, layers, pretrained, progress, mix_layers=None, mix_p=0.5, mix_alpha=0.1,\n                           resnet_class=ResNet, **kwargs):\n    \"\"\"Construct `ResNet` with MixStyle modules. Given any resnet architecture **resnet_class** that contains conv1,\n    bn1, relu, maxpool, layer1-4, this function define a new class that inherits from **resnet_class** and inserts\n    MixStyle module during forward pass. Although MixStyle Module can be inserted anywhere, original paper finds it\n    better to place MixStyle after layer1-3. Our implementation follows this idea, but you are free to modify this\n    function to try other possibilities.\n\n    Args:\n        arch (str): resnet architecture (resnet50 for example)\n        block (class): class of resnet block\n        layers (list): depth list of each block\n        pretrained (bool): if True, load imagenet pre-trained model parameters\n        progress (bool): whether or not to display a progress bar to stderr\n        mix_layers (list): layers to insert MixStyle module after\n        mix_p (float): probability to activate MixStyle during forward pass\n        mix_alpha (float): parameter alpha for beta distribution\n        resnet_class (class): base resnet class to inherit from\n    \"\"\"\n\n    if mix_layers is None:\n        mix_layers = []\n\n    available_resnet_class = [ResNet, ReidResNet]\n    assert resnet_class in available_resnet_class\n\n    class ResNetWithMixStyleModule(resnet_class):\n        def __init__(self, mix_layers, mix_p=0.5, mix_alpha=0.1, *args, **kwargs):\n            super(ResNetWithMixStyleModule, self).__init__(*args, **kwargs)\n            self.mixStyleModule = MixStyle(p=mix_p, alpha=mix_alpha)\n            for layer in mix_layers:\n                assert layer in ['layer1', 'layer2', 'layer3']\n            self.apply_layers = mix_layers\n\n        def forward(self, x):\n            x = self.conv1(x)\n            x = self.bn1(x)\n            # turn on relu activation here **except for** reid tasks\n            if resnet_class != ReidResNet:\n                x = self.relu(x)\n            x = self.maxpool(x)\n\n            x = self.layer1(x)\n            if 'layer1' in self.apply_layers:\n                x = self.mixStyleModule(x)\n            x = self.layer2(x)\n            if 'layer2' in self.apply_layers:\n                x = self.mixStyleModule(x)\n            x = self.layer3(x)\n            if 'layer3' in self.apply_layers:\n                x = self.mixStyleModule(x)\n            x = self.layer4(x)\n\n            return x\n\n    model = ResNetWithMixStyleModule(mix_layers=mix_layers, mix_p=mix_p, mix_alpha=mix_alpha, block=block,\n                                     layers=layers, **kwargs)\n    if pretrained:\n        model_dict = model.state_dict()\n        pretrained_dict = load_state_dict_from_url(model_urls[arch],\n                                                   progress=progress)\n        # remove keys from pretrained dict that doesn't appear in model dict\n        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n        model.load_state_dict(pretrained_dict, strict=False)\n    return 
model\n\n\ndef resnet18(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructs a ResNet-18 model with MixStyle.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet_with_mix_style('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                                  **kwargs)\n\n\ndef resnet34(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructs a ResNet-34 model with MixStyle.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet_with_mix_style('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                                  **kwargs)\n\n\ndef resnet50(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructs a ResNet-50 model with MixStyle.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet_with_mix_style('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                                  **kwargs)\n\n\ndef resnet101(pretrained=False, progress=True, **kwargs):\n    \"\"\"Constructs a ResNet-101 model with MixStyle.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet_with_mix_style('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                                  **kwargs)\n"
  },
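  {
    "path": "examples_sketch/mixstyle_resnet_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for the MixStyle ResNet constructors; not part of\nthe library itself. `mix_layers` chooses where MixStyle is applied.\n\"\"\"\nimport torch\nfrom tllib.normalization.mixstyle.resnet import resnet50\n\nmodel = resnet50(pretrained=False, mix_layers=['layer1', 'layer2'], mix_p=0.5, mix_alpha=0.1)\nmodel.train()\nx = torch.randn(8, 3, 224, 224)\nfeatures = model(x)  # (8, 2048, 7, 7), styles mixed after layer1 and layer2\n"
  },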
  {
    "path": "tllib/normalization/mixstyle/sampler.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport random\nimport copy\nfrom torch.utils.data.dataset import ConcatDataset\nfrom torch.utils.data.sampler import Sampler\n\n\nclass RandomDomainMultiInstanceSampler(Sampler):\n    r\"\"\"Randomly sample :math:`N` domains, then randomly select :math:`P` instances in each domain, for each instance,\n    randomly select :math:`K` images to form a mini-batch of size :math:`N\\times P\\times K`.\n\n    Args:\n        dataset (ConcatDataset): dataset that contains data from multiple domains\n        batch_size (int): mini-batch size (:math:`N\\times P\\times K` here)\n        n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here)\n        num_instances (int): number of instances to select in each domain (:math:`K` here)\n    \"\"\"\n\n    def __init__(self, dataset, batch_size, n_domains_per_batch, num_instances):\n        super(Sampler, self).__init__()\n        self.dataset = dataset\n        self.sample_idxes_per_domain = {}\n        for idx, (_, _, domain_id) in enumerate(self.dataset):\n            if domain_id not in self.sample_idxes_per_domain:\n                self.sample_idxes_per_domain[domain_id] = []\n            self.sample_idxes_per_domain[domain_id].append(idx)\n        self.n_domains_in_dataset = len(self.sample_idxes_per_domain)\n        self.n_domains_per_batch = n_domains_per_batch\n        assert self.n_domains_in_dataset >= self.n_domains_per_batch\n\n        assert batch_size % n_domains_per_batch == 0\n        self.batch_size_per_domain = batch_size // n_domains_per_batch\n\n        assert self.batch_size_per_domain % num_instances == 0\n        self.num_instances = num_instances\n        self.num_classes_per_domain = self.batch_size_per_domain // num_instances\n        self.length = len(list(self.__iter__()))\n\n    def __iter__(self):\n        sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain)\n        domain_idxes = [idx for idx in range(self.n_domains_in_dataset)]\n        final_idxes = []\n        stop_flag = False\n        while not stop_flag:\n            selected_domains = random.sample(domain_idxes, self.n_domains_per_batch)\n\n            for domain in selected_domains:\n                sample_idxes = sample_idxes_per_domain[domain]\n                selected_idxes = self.sample_multi_instances(sample_idxes)\n                final_idxes.extend(selected_idxes)\n\n                for idx in selected_idxes:\n                    sample_idxes_per_domain[domain].remove(idx)\n\n                remaining_size = len(sample_idxes_per_domain[domain])\n                if remaining_size < self.batch_size_per_domain:\n                    stop_flag = True\n\n        return iter(final_idxes)\n\n    def sample_multi_instances(self, sample_idxes):\n        idxes_per_cls = {}\n        for idx in sample_idxes:\n            _, cls, _ = self.dataset[idx]\n            if cls not in idxes_per_cls:\n                idxes_per_cls[cls] = []\n            idxes_per_cls[cls].append(idx)\n\n        cls_list = [cls for cls in idxes_per_cls if len(idxes_per_cls[cls]) >= self.num_instances]\n        if len(cls_list) < self.num_classes_per_domain:\n            return random.sample(sample_idxes, self.batch_size_per_domain)\n\n        selected_idxes = []\n        selected_classes = random.sample(cls_list, self.num_classes_per_domain)\n        for cls in selected_classes:\n            selected_idxes.extend(random.sample(idxes_per_cls[cls], 
self.num_instances))\n        return selected_idxes\n\n    def __len__(self):\n        return self.length\n"
  },
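  {
    "path": "examples_sketch/sampler_usage.py",
    "content": "\"\"\"\nHypothetical usage sketch for RandomDomainMultiInstanceSampler; not part of\nthe library itself. The sampler requires a dataset whose items are\n(image, class, domain_id) triples, so a toy dataset is built here.\n\"\"\"\nimport torch\nfrom torch.utils.data import ConcatDataset, DataLoader, Dataset\nfrom tllib.normalization.mixstyle.sampler import RandomDomainMultiInstanceSampler\n\n\nclass DummyDomain(Dataset):\n    \"\"\"Toy dataset yielding (image, class, domain_id) triples.\"\"\"\n\n    def __init__(self, domain_id, num_classes=8, samples_per_class=16):\n        self.items = [(torch.randn(3, 32, 32), cls, domain_id)\n                      for cls in range(num_classes)\n                      for _ in range(samples_per_class)]\n\n    def __getitem__(self, idx):\n        return self.items[idx]\n\n    def __len__(self):\n        return len(self.items)\n\n\ndataset = ConcatDataset([DummyDomain(0), DummyDomain(1), DummyDomain(2)])\n# each mini-batch: 2 domains x 4 classes x 4 instances = 32 samples\nsampler = RandomDomainMultiInstanceSampler(dataset, batch_size=32,\n                                           n_domains_per_batch=2, num_instances=4)\nloader = DataLoader(dataset, batch_size=32, sampler=sampler, drop_last=True)\n"
  },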
  {
    "path": "tllib/normalization/stochnorm.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch.nn.parameter import Parameter\n\n__all__ = ['StochNorm1d', 'StochNorm2d', 'convert_model']\n\n\nclass _StochNorm(nn.Module):\n\n    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, p=0.5):\n        super(_StochNorm, self).__init__()\n        self.num_features = num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.affine = affine\n        self.track_running_stats = track_running_stats\n        self.p = p\n        if self.affine:\n            self.weight = Parameter(torch.Tensor(num_features))\n            self.bias = Parameter(torch.Tensor(num_features))\n        else:\n            self.register_parameter('weight', None)\n            self.register_parameter('bias', None)\n\n        if self.track_running_stats:\n            self.register_buffer('running_mean', torch.zeros(num_features))\n            self.register_buffer('running_var', torch.ones(num_features))\n        else:\n            self.register_parameter('running_mean', None)\n            self.register_parameter('running_var', None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.track_running_stats:\n            self.running_mean.zero_()\n            self.running_var.fill_(1)\n        if self.affine:\n            self.weight.data.uniform_()\n            self.bias.data.zero_()\n\n    def _check_input_dim(self, input):\n        return NotImplemented\n\n    def forward(self, input):\n        self._check_input_dim(input)\n\n        if self.training:\n            z_0 = F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                False, self.momentum, self.eps)\n\n            z_1 = F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                True, self.momentum, self.eps)\n\n            if input.dim() == 2:\n                s = torch.from_numpy(\n                    np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(1,\n                                                                                      self.num_features)).float().cuda()\n            elif input.dim() == 3:\n                s = torch.from_numpy(\n                    np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(1, self.num_features,\n                                                                                      1)).float().cuda()\n            elif input.dim() == 4:\n                s = torch.from_numpy(\n                    np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(1, self.num_features, 1,\n                                                                                      1)).float().cuda()\n            else:\n                raise BaseException()\n\n            z = (1 - s) * z_0 + s * z_1\n        else:\n            z = F.batch_norm(\n                input, self.running_mean, self.running_var, self.weight, self.bias,\n                False, self.momentum, self.eps)\n\n        return z\n\n\nclass StochNorm1d(_StochNorm):\n    r\"\"\"Applies Stochastic Normalization over a 2D or 3D input (a mini-batch of 1D inputs with optional additional channel dimension)\n\n    Stochastic  Normalization is proposed in `Stochastic Normalization (NIPS 2020) 
<https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf>`_\n\n    .. math::\n\n        \\hat{x}_{i,0} = \\frac{x_i - \\tilde{\\mu}}{ \\sqrt{\\tilde{\\sigma} + \\epsilon}}\n\n        \\hat{x}_{i,1} = \\frac{x_i - \\mu}{ \\sqrt{\\sigma + \\epsilon}}\n\n        \\hat{x}_i = (1-s)\\cdot \\hat{x}_{i,0} + s\\cdot \\hat{x}_{i,1}\n\n         y_i = \\gamma \\hat{x}_i + \\beta\n\n    where :math:`\\mu` and :math:`\\sigma` are mean and variance of current mini-batch data.\n\n    :math:`\\tilde{\\mu}` and :math:`\\tilde{\\sigma}` are current moving statistics of training data.\n\n    :math:`s` is a branch-selection variable generated from a Bernoulli distribution, where :math:`P(s=1)=p`.\n\n\n    During training, there are two normalization branches. One uses mean and\n    variance of current mini-batch data, while the other uses current moving\n    statistics of the training data as usual batch normalization.\n\n    During evaluation, the moving statistics is used for normalization.\n\n\n    Args:\n        num_features (int): :math:`c` from an expected input of size :math:`(b, c, l)` or  :math:`l` from an expected input of size :math:`(b, l)`.\n        eps (float): A value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum (float): The value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine (bool): A boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n        track_running_stats (bool): A boolean value that when set to True, this module tracks\n         the running mean and variance, and when set to False, this module does not\n         track such statistics, and initializes statistics buffers running_mean and\n         running_var as None. When these buffers are None, this module always uses\n         batch statistics in both training and eval modes. Default: True\n         p (float): The probability to choose the second branch (usual BN). Default: 0.5\n\n    Shape:\n        - Input: :math:`(b, l)` or :math:`(b, c, l)`\n        - Output: :math:`(b, l)` or :math:`(b, c, l)` (same shape as input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 2 and input.dim() != 3:\n            raise ValueError('expected 2D or 3D input (got {}D input)'\n                             .format(input.dim()))\n\n\nclass StochNorm2d(_StochNorm):\n    r\"\"\"\n    Applies Stochastic  Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension)\n\n    Stochastic  Normalization is proposed in `Stochastic Normalization (NIPS 2020) <https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf>`_\n\n    .. math::\n\n        \\hat{x}_{i,0} = \\frac{x_i - \\tilde{\\mu}}{ \\sqrt{\\tilde{\\sigma} + \\epsilon}}\n\n        \\hat{x}_{i,1} = \\frac{x_i - \\mu}{ \\sqrt{\\sigma + \\epsilon}}\n\n        \\hat{x}_i = (1-s)\\cdot \\hat{x}_{i,0} + s\\cdot \\hat{x}_{i,1}\n\n         y_i = \\gamma \\hat{x}_i + \\beta\n\n    where :math:`\\mu` and :math:`\\sigma` are mean and variance of current mini-batch data.\n\n    :math:`\\tilde{\\mu}` and :math:`\\tilde{\\sigma}` are current moving statistics of training data.\n\n    :math:`s` is a branch-selection variable generated from a Bernoulli distribution, where :math:`P(s=1)=p`.\n\n\n    During training, there are two normalization branches. 
One uses mean and\n    variance of current mini-batch data, while the other uses current moving\n    statistics of the training data as usual batch normalization.\n\n    During evaluation, the moving statistics is used for normalization.\n\n\n    Args:\n        num_features (int): :math:`c` from an expected input of size :math:`(b, c, h, w)`.\n        eps (float): A value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum (float): The value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine (bool): A boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. Default: ``True``\n        track_running_stats (bool): A boolean value that when set to True, this module tracks\n         the running mean and variance, and when set to False, this module does not\n         track such statistics, and initializes statistics buffers running_mean and\n         running_var as None. When these buffers are None, this module always uses\n         batch statistics in both training and eval modes. Default: True\n         p (float): The probability to choose the second branch (usual BN). Default: 0.5\n\n    Shape:\n        - Input: :math:`(b, c, h, w)`\n        - Output: :math:`(b, c, h, w)` (same shape as input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 4:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(input.dim()))\n\n\nclass StochNorm3d(_StochNorm):\n    r\"\"\"\n    Applies Stochastic  Normalization over a 5D input (a mini-batch of 3D inputs with additional channel dimension)\n\n    Stochastic  Normalization is proposed in `Stochastic Normalization (NIPS 2020) <https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf>`_\n\n    .. math::\n\n        \\hat{x}_{i,0} = \\frac{x_i - \\tilde{\\mu}}{ \\sqrt{\\tilde{\\sigma} + \\epsilon}}\n\n        \\hat{x}_{i,1} = \\frac{x_i - \\mu}{ \\sqrt{\\sigma + \\epsilon}}\n\n        \\hat{x}_i = (1-s)\\cdot \\hat{x}_{i,0} + s\\cdot \\hat{x}_{i,1}\n\n         y_i = \\gamma \\hat{x}_i + \\beta\n\n    where :math:`\\mu` and :math:`\\sigma` are mean and variance of current mini-batch data.\n\n    :math:`\\tilde{\\mu}` and :math:`\\tilde{\\sigma}` are current moving statistics of training data.\n\n    :math:`s` is a branch-selection variable generated from a Bernoulli distribution, where :math:`P(s=1)=p`.\n\n\n    During training, there are two normalization branches. One uses mean and\n    variance of current mini-batch data, while the other uses current moving\n    statistics of the training data as usual batch normalization.\n\n    During evaluation, the moving statistics is used for normalization.\n\n\n    Args:\n        num_features (int): :math:`c` from an expected input of size :math:`(b, c, d, h, w)`\n        eps (float): A value added to the denominator for numerical stability.\n            Default: 1e-5\n        momentum (float): The value used for the running_mean and running_var\n            computation. Default: 0.1\n        affine (bool): A boolean value that when set to ``True``, gives the layer learnable\n            affine parameters. 
Default: ``True``\n        track_running_stats (bool): A boolean value that when set to True, this module tracks\n         the running mean and variance, and when set to False, this module does not\n         track such statistics, and initializes statistics buffers running_mean and\n         running_var as None. When these buffers are None, this module always uses\n         batch statistics in both training and eval modes. Default: True\n         p (float): The probability to choose the second branch (usual BN). Default: 0.5\n\n    Shape:\n        - Input: :math:`(b, c, d, h, w)`\n        - Output: :math:`(b, c, d, h, w)` (same shape as input)\n    \"\"\"\n\n    def _check_input_dim(self, input):\n        if input.dim() != 5:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(input.dim()))\n\n\ndef convert_model(module, p):\n    \"\"\"\n    Traverses the input module and its child recursively and replaces all\n    instance of BatchNorm to StochNorm.\n\n    Args:\n        module (torch.nn.Module): The input module needs to be convert to StochNorm model.\n        p (float): The hyper-parameter for StochNorm layer.\n\n    Returns:\n         The module converted to StochNorm version.\n    \"\"\"\n\n    mod = module\n    for pth_module, stoch_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,\n                                         torch.nn.modules.batchnorm.BatchNorm2d,\n                                         torch.nn.modules.batchnorm.BatchNorm3d],\n                                        [StochNorm1d,\n                                         StochNorm2d,\n                                         StochNorm3d]):\n        if isinstance(module, pth_module):\n            mod = stoch_module(module.num_features, module.eps, module.momentum, module.affine, p)\n            mod.running_mean = module.running_mean\n            mod.running_var = module.running_var\n\n            if module.affine:\n                mod.weight.data = module.weight.data.clone().detach()\n                mod.bias.data = module.bias.data.clone().detach()\n\n    for name, child in module.named_children():\n        mod.add_module(name, convert_model(child, p))\n\n    return mod\n"
  },
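A minimal usage sketch for `convert_model` (the torchvision ResNet backbone is only an illustration; any module containing BatchNorm layers works the same way):

    import torchvision.models as models
    from tllib.normalization.stochnorm import convert_model

    model = models.resnet50(pretrained=True)
    # Recursively swap every BatchNorm1d/2d/3d for the matching StochNorm layer,
    # inheriting the pretrained running statistics and affine parameters.
    model = convert_model(model, p=0.5)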
  {
    "path": "tllib/ranking/__init__.py",
    "content": "from .logme import log_maximum_evidence\nfrom .nce import negative_conditional_entropy\nfrom .leep import log_expected_empirical_prediction\nfrom .hscore import h_score\n\n__all__ = ['log_maximum_evidence', 'negative_conditional_entropy', 'log_expected_empirical_prediction', 'h_score']"
  },
  {
    "path": "tllib/ranking/hscore.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\nimport numpy as np\nfrom sklearn.covariance import LedoitWolf\n\n__all__ = ['h_score', 'regularized_h_score']\n\n\ndef h_score(features: np.ndarray, labels: np.ndarray):\n    r\"\"\"\n    H-score in `An Information-theoretic Approach to Transferability in Task Transfer Learning (ICIP 2019) \n    <http://yangli-feasibility.com/home/media/icip-19.pdf>`_.\n    \n    The H-Score :math:`\\mathcal{H}` can be described as:\n\n    .. math::\n        \\mathcal{H}=\\operatorname{tr}\\left(\\operatorname{cov}(f)^{-1} \\operatorname{cov}\\left(\\mathbb{E}[f \\mid y]\\right)\\right)\n    \n    where :math:`f` is the features extracted by the model to be ranked, :math:`y` is the groud-truth label vector\n\n    Args:\n        features (np.ndarray):features extracted by pre-trained model.\n        labels (np.ndarray):  groud-truth labels.\n\n    Shape:\n        - features: (N, F), with number of samples N and feature dimension F.\n        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.\n        - score: scalar.\n    \"\"\"\n    f = features\n    y = labels\n\n    covf = np.cov(f, rowvar=False)\n    C = int(y.max() + 1)\n    g = np.zeros_like(f)\n\n    for i in range(C):\n        Ef_i = np.mean(f[y == i, :], axis=0)\n        g[y == i] = Ef_i\n\n    covg = np.cov(g, rowvar=False)\n    score = np.trace(np.dot(np.linalg.pinv(covf, rcond=1e-15), covg))\n\n    return score\n\n\ndef regularized_h_score(features: np.ndarray, labels: np.ndarray):\n    r\"\"\"\n    Regularized H-score in `Newer is not always better: Rethinking transferability metrics, their peculiarities, stability and performance (NeurIPS 2021) \n    <https://openreview.net/pdf?id=iz_Wwmfquno>`_.\n    \n    The  regularized H-Score :math:`\\mathcal{H}_{\\alpha}` can be described as:\n\n    .. math::\n        \\mathcal{H}_{\\alpha}=\\operatorname{tr}\\left(\\operatorname{cov}_{\\alpha}(f)^{-1}\\left(1-\\alpha \\right)\\operatorname{cov}\\left(\\mathbb{E}[f \\mid y]\\right)\\right)\n    \n    where :math:`f` is the features extracted by the model to be ranked, :math:`y` is the groud-truth label vector and :math:`\\operatorname{cov}_{\\alpha}` the  Ledoit-Wolf \n    covariance estimator with shrinkage parameter :math:`\\alpha`\n    Args:\n        features (np.ndarray):features extracted by pre-trained model.\n        labels (np.ndarray):  groud-truth labels.\n\n    Shape:\n        - features: (N, F), with number of samples N and feature dimension F.\n        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.\n        - score: scalar.\n    \"\"\"\n    f = features.astype('float64')\n    f = f - np.mean(f, axis=0, keepdims=True)  # Center the features for correct Ledoit-Wolf Estimation\n    y = labels\n\n    C = int(y.max() + 1)\n    g = np.zeros_like(f)\n\n    cov = LedoitWolf(assume_centered=False).fit(f)\n    alpha = cov.shrinkage_\n    covf_alpha = cov.covariance_\n\n    for i in range(C):\n        Ef_i = np.mean(f[y == i, :], axis=0)\n        g[y == i] = Ef_i\n\n    covg = np.cov(g, rowvar=False)\n    score = np.trace(np.dot(np.linalg.pinv(covf_alpha, rcond=1e-15), (1 - alpha) * covg))\n\n    return score\n"
  },
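A quick usage sketch for both scores (random arrays stand in for real penultimate-layer features):

    import numpy as np
    from tllib.ranking.hscore import h_score, regularized_h_score

    features = np.random.randn(1000, 512)          # (N, F) features from a frozen encoder
    labels = np.random.randint(0, 10, size=1000)   # (N,) target labels in [0, C_t)
    print(h_score(features, labels))               # higher score suggests better transferability
    print(regularized_h_score(features, labels))   # shrinkage-based variant, more stable when F is large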
  {
    "path": "tllib/ranking/leep.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\n\nimport numpy as np\n\n__all__ = ['log_expected_empirical_prediction']\n\n\ndef log_expected_empirical_prediction(predictions: np.ndarray, labels: np.ndarray):\n    r\"\"\"\n    Log Expected Empirical Prediction in `LEEP: A New Measure to\n    Evaluate Transferability of Learned Representations (ICML 2020)\n    <http://proceedings.mlr.press/v119/nguyen20b/nguyen20b.pdf>`_.\n    \n    The LEEP :math:`\\mathcal{T}` can be described as:\n\n    .. math::\n        \\mathcal{T}=\\mathbb{E}\\log \\left(\\sum_{z \\in \\mathcal{C}_s} \\hat{P}\\left(y \\mid z\\right) \\theta\\left(y \\right)_{z}\\right)\n\n    where :math:`\\theta\\left(y\\right)_{z}` is the predictions of pre-trained model on source category, :math:`\\hat{P}\\left(y \\mid z\\right)` is the empirical conditional distribution estimated by prediction and ground-truth label.\n\n    Args:\n        predictions (np.ndarray): predictions of pre-trained model.\n        labels (np.ndarray): groud-truth labels.\n\n    Shape: \n        - predictions: (N, :math:`C_s`), with number of samples N and source class number :math:`C_s`.\n        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.\n        - score: scalar\n    \"\"\"\n    N, C_s = predictions.shape\n    labels = labels.reshape(-1)\n    C_t = int(np.max(labels) + 1)\n\n    normalized_prob = predictions / float(N)\n    joint = np.zeros((C_t, C_s), dtype=float)  # placeholder for joint distribution over (y, z)\n\n    for i in range(C_t):\n        this_class = normalized_prob[labels == i]\n        row = np.sum(this_class, axis=0)\n        joint[i] = row\n\n    p_target_given_source = (joint / joint.sum(axis=0, keepdims=True)).T  # P(y | z)\n    empirical_prediction = predictions @ p_target_given_source\n    empirical_prob = np.array([predict[label] for predict, label in zip(empirical_prediction, labels)])\n    score = np.mean(np.log(empirical_prob))\n\n    return score\n"
  },
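A minimal sketch of calling LEEP (the softmax here merely fabricates a valid probability matrix; real inputs come from the source head of the pre-trained model):

    import numpy as np
    from tllib.ranking.leep import log_expected_empirical_prediction

    rng = np.random.default_rng(0)
    logits = rng.standard_normal((1000, 100))                       # (N, C_s) source logits
    predictions = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    labels = rng.integers(0, 10, size=1000)                         # (N,) target labels
    score = log_expected_empirical_prediction(predictions, labels)  # higher is better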
  {
    "path": "tllib/ranking/logme.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\nimport numpy as np\nfrom numba import njit\n\n__all__ = ['log_maximum_evidence']\n\n\ndef log_maximum_evidence(features: np.ndarray, targets: np.ndarray, regression=False, return_weights=False):\n    r\"\"\"\n    Log Maximum Evidence in `LogME: Practical Assessment of Pre-trained Models\n    for Transfer Learning (ICML 2021) <https://arxiv.org/pdf/2102.11005.pdf>`_.\n    \n    Args:\n        features (np.ndarray): feature matrix from pre-trained model.\n        targets (np.ndarray): targets labels/values.\n        regression (bool, optional): whether to apply in regression setting. (Default: False)\n        return_weights (bool, optional): whether to return bayesian weight. (Default: False)\n\n    Shape:\n        - features: (N, F) with element in [0, :math:`C_t`) and feature dimension F, where :math:`C_t` denotes the number of target class\n        - targets: (N, ) or (N, C), with C regression-labels.\n        - weights: (F, :math:`C_t`).\n        - score: scalar.\n    \"\"\"\n    f = features.astype(np.float64)\n    y = targets\n    if regression:\n        y = targets.astype(np.float64)\n\n    fh = f\n    f = f.transpose()\n    D, N = f.shape\n    v, s, vh = np.linalg.svd(f @ fh, full_matrices=True)\n\n    evidences = []\n    weights = []\n    if regression:\n        C = y.shape[1]\n        for i in range(C):\n            y_ = y[:, i]\n            evidence, weight = each_evidence(y_, f, fh, v, s, vh, N, D)\n            evidences.append(evidence)\n            weights.append(weight)\n    else:\n        C = int(y.max() + 1)\n        for i in range(C):\n            y_ = (y == i).astype(np.float64)\n            evidence, weight = each_evidence(y_, f, fh, v, s, vh, N, D)\n            evidences.append(evidence)\n            weights.append(weight)\n\n    score = np.mean(evidences)\n    weights = np.vstack(weights)\n\n    if return_weights:\n        return score, weights\n    else:\n        return score\n\n\n@njit\ndef each_evidence(y_, f, fh, v, s, vh, N, D):\n    \"\"\"\n    compute the maximum evidence for each class\n    \"\"\"\n    alpha = 1.0\n    beta = 1.0\n    lam = alpha / beta\n    tmp = (vh @ (f @ y_))\n\n    for _ in range(11):\n        # should converge after at most 10 steps\n        # typically converge after two or three steps\n        gamma = (s / (s + lam)).sum()\n        m = v @ (tmp * beta / (alpha + beta * s))\n        alpha_de = (m * m).sum()\n        alpha = gamma / alpha_de\n        beta_de = ((y_ - fh @ m) ** 2).sum()\n        beta = (N - gamma) / beta_de\n        new_lam = alpha / beta\n        if np.abs(new_lam - lam) / lam < 0.01:\n            break\n        lam = new_lam\n\n    evidence = D / 2.0 * np.log(alpha) \\\n               + N / 2.0 * np.log(beta) \\\n               - 0.5 * np.sum(np.log(alpha + beta * s)) \\\n               - beta / 2.0 * beta_de \\\n               - alpha / 2.0 * alpha_de \\\n               - N / 2.0 * np.log(2 * np.pi)\n\n    return evidence / N, m\n"
  },
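A short usage sketch (classification by default; pass regression=True for regression targets):

    import numpy as np
    from tllib.ranking.logme import log_maximum_evidence

    features = np.random.randn(1000, 512)         # (N, F) features from a frozen encoder
    labels = np.random.randint(0, 10, size=1000)  # (N,) target labels
    score = log_maximum_evidence(features, labels)
    # For a 3-dimensional regression target instead:
    # score = log_maximum_evidence(features, np.random.randn(1000, 3), regression=True)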
  {
    "path": "tllib/ranking/nce.py",
    "content": "\"\"\"\n@author: Yong Liu\n@contact: liuyong1095556447@163.com\n\"\"\"\nimport numpy as np\n\n__all__ = ['negative_conditional_entropy']\n\n\ndef negative_conditional_entropy(source_labels: np.ndarray, target_labels: np.ndarray):\n    r\"\"\"\n    Negative Conditional Entropy in `Transferability and Hardness of Supervised \n    Classification Tasks (ICCV 2019) <https://arxiv.org/pdf/1908.08142v1.pdf>`_.\n    \n    The NCE :math:`\\mathcal{H}` can be described as:\n\n    .. math::\n        \\mathcal{H}=-\\sum_{y \\in \\mathcal{C}_t} \\sum_{z \\in \\mathcal{C}_s} \\hat{P}(y, z) \\log \\frac{\\hat{P}(y, z)}{\\hat{P}(z)}\n\n    where :math:`\\hat{P}(z)` is the empirical distribution and :math:`\\hat{P}\\left(y \\mid z\\right)` is the empirical\n    conditional distribution estimated by source and target label.\n\n    Args:\n        source_labels (np.ndarray): predicted source labels.\n        target_labels (np.ndarray): groud-truth target labels.\n\n    Shape:\n        - source_labels: (N, ) elements in [0, :math:`C_s`), with source class number :math:`C_s`.\n        - target_labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.\n    \"\"\"\n    C_t = int(np.max(target_labels) + 1)\n    C_s = int(np.max(source_labels) + 1)\n    N = len(source_labels)\n\n    joint = np.zeros((C_t, C_s), dtype=float)  # placeholder for the joint distribution, shape [C_t, C_s]\n    for s, t in zip(source_labels, target_labels):\n        s = int(s)\n        t = int(t)\n        joint[t, s] += 1.0 / N\n    p_z = joint.sum(axis=0, keepdims=True)\n\n    p_target_given_source = (joint / p_z).T  # P(y | z), shape [C_s, C_t]\n    mask = p_z.reshape(-1) != 0  # valid Z, shape [C_s]\n    p_target_given_source = p_target_given_source[mask] + 1e-20  # remove NaN where p(z) = 0, add 1e-20 to avoid log (0)\n    entropy_y_given_z = np.sum(- p_target_given_source * np.log(p_target_given_source), axis=1, keepdims=True)\n    conditional_entropy = np.sum(entropy_y_given_z * p_z.reshape((-1, 1))[mask])\n\n    return -conditional_entropy\n"
  },
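A usage sketch; the source labels are typically the argmax of the pre-trained model's predictions on the target data:

    import numpy as np
    from tllib.ranking.nce import negative_conditional_entropy

    source_labels = np.random.randint(0, 100, size=1000)  # (N,) predicted source labels
    target_labels = np.random.randint(0, 10, size=1000)   # (N,) ground-truth target labels
    score = negative_conditional_entropy(source_labels, target_labels)  # higher is better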
  {
    "path": "tllib/ranking/transrate.py",
    "content": "\"\"\"\n@author: Louis Fouquet\n@contact: louisfouquet75@gmail.com\n\"\"\"\nimport numpy as np\n\n__all__ = ['transrate']\n\n\ndef coding_rate(features: np.ndarray, eps=1e-4):\n    f = features\n    n, d = f.shape\n    (_, rate) = np.linalg.slogdet((np.eye(d) + 1 / (n * eps) * f.transpose() @ f))\n    return 0.5 * rate\n\n\ndef transrate(features: np.ndarray, labels: np.ndarray, eps=1e-4):\n    r\"\"\"\n    TransRate in `Frustratingly easy transferability estimation (ICML 2022) \n    <https://proceedings.mlr.press/v162/huang22d/huang22d.pdf>`_.\n    \n    The TransRate :math:`TrR` can be described as:\n\n    .. math::\n        TrR= R\\left(f, \\espilon \\right) - R\\left(f, \\espilon \\mid y \\right) \n    \n    where :math:`f` is the features extracted by the model to be ranked, :math:`y` is the groud-truth label vector, \n    :math:`R` is the coding rate with distortion rate :math:`\\epsilon`\n\n    Args:\n        features (np.ndarray):features extracted by pre-trained model.\n        labels (np.ndarray):  groud-truth labels.\n        eps (float, optional): distortion rare (Default: 1e-4)\n\n    Shape:\n        - features: (N, F), with number of samples N and feature dimension F.\n        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.\n        - score: scalar.\n    \"\"\"\n    f = features\n    y = labels\n    f = f - np.mean(f, axis=0, keepdims=True)\n    Rf = coding_rate(f, eps)\n    Rfy = 0.0\n    C = int(y.max() + 1)\n    for i in range(C):\n        Rfy += coding_rate(f[(y == i).flatten()], eps)\n    return Rf - Rfy / C\n"
  },
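A usage sketch mirroring the other ranking metrics in this package:

    import numpy as np
    from tllib.ranking.transrate import transrate

    features = np.random.randn(1000, 512)          # (N, F) features from a frozen encoder
    labels = np.random.randint(0, 10, size=1000)   # (N,) target labels
    score = transrate(features, labels, eps=1e-4)  # higher is better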
  {
    "path": "tllib/regularization/__init__.py",
    "content": "from .bss import *\nfrom .co_tuning import *\nfrom .delta import *\nfrom .bi_tuning import *\nfrom .knowledge_distillation import *\n\n__all__ = ['bss', 'co_tuning', 'delta', 'bi_tuning', 'knowledge_distillation']"
  },
  {
    "path": "tllib/regularization/bi_tuning.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom torch.nn.functional import normalize\nfrom tllib.modules.classifier import Classifier as ClassifierBase\n\n\nclass Classifier(ClassifierBase):\n    \"\"\"Classifier class for Bi-Tuning.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data\n        num_classes (int): Number of classes\n        projection_dim (int, optional): Dimension of the projector head. Default: 128\n        finetune (bool): Whether finetune the classifier or train from scratch. Default: True\n\n    .. note::\n        The learning rate of this classifier is set 10 times to that of the feature extractor for better accuracy\n        by default. If you have other optimization strategies, please over-ride :meth:`~Classifier.get_parameters`.\n\n    Inputs:\n        - x (tensor): input data fed to `backbone`\n\n    Outputs:\n        In the training mode,\n            - y: classifier's predictions\n            - z: projector's predictions\n            - hn: normalized features after `bottleneck` layer and before `head` layer\n        In the eval mode,\n            - y: classifier's predictions\n\n    Shape:\n        - Inputs: (minibatch, *) where * means, any number of additional dimensions\n        - y: (minibatch, `num_classes`)\n        - z: (minibatch, `projection_dim`)\n        - hn: (minibatch, `features_dim`)\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=128, finetune=True, pool_layer=None):\n        head = nn.Linear(backbone.out_features, num_classes)\n        head.weight.data.normal_(0, 0.01)\n        head.bias.data.fill_(0.0)\n        super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune,\n                                         pool_layer=pool_layer)\n        self.projector = nn.Linear(backbone.out_features, projection_dim)\n        self.projection_dim = projection_dim\n\n    def forward(self, x: torch.Tensor):\n        batch_size = x.shape[0]\n        h = self.backbone(x)\n        h = self.pool_layer(h)\n        h = self.bottleneck(h)\n        y = self.head(h)\n        z = normalize(self.projector(h), dim=1)\n        hn = torch.cat([h, torch.ones(batch_size, 1, dtype=torch.float).to(h.device)], dim=1)\n        hn = normalize(hn, dim=1)\n        if self.training:\n            return y, z, hn\n        else:\n            return y\n\n    def get_parameters(self, base_lr=1.0):\n        \"\"\"A parameter list which decides optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.projector.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n        ]\n\n        return params\n\n\nclass BiTuning(nn.Module):\n    \"\"\"\n    Bi-Tuning Module in `Bi-tuning of Pre-trained Representations <https://arxiv.org/abs/2011.06182?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29>`_.\n\n    Args:\n        encoder_q (Classifier): Query encoder.\n        encoder_k (Classifier): Key encoder.\n        num_classes 
(int): Number of classes\n        K (int): Queue size. Default: 40\n        m (float): Momentum coefficient. Default: 0.999\n        T (float): Temperature. Default: 0.07\n\n    Inputs:\n        - im_q (tensor): input data fed to `encoder_q`\n        - im_k (tensor): input data fed to `encoder_k`\n        - labels (tensor): classification labels of input data\n\n    Outputs: y_q, logits_z, logits_y, labels_c\n        - y_q: query classifier's predictions\n        - logits_z: projector's predictions on both positive and negative samples\n        - logits_y: classifier's predictions on both positive and negative samples\n        - labels_c: contrastive labels\n\n    Shape:\n        - im_q, im_k: (minibatch, *) where * means, any number of additional dimensions\n        - labels: (minibatch, )\n        - y_q: (minibatch, `num_classes`)\n        - logits_z: (minibatch, 1 + `num_classes` x `K`, `projection_dim`)\n        - logits_y: (minibatch, 1 + `num_classes` x `K`, `num_classes`)\n        - labels_c: (minibatch, 1 + `num_classes` x `K`)\n    \"\"\"\n\n    def __init__(self, encoder_q: Classifier, encoder_k: Classifier, num_classes, K=40, m=0.999, T=0.07):\n        super(BiTuning, self).__init__()\n        self.K = K\n        self.m = m\n        self.T = T\n        self.num_classes = num_classes\n\n        # create the encoders\n        # num_classes is the output fc dimension\n        self.encoder_q = encoder_q\n        self.encoder_k = encoder_k\n\n        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):\n            param_k.data.copy_(param_q.data)  # initialize\n            param_k.requires_grad = False  # not update by gradient\n\n        # create the queue\n        self.register_buffer(\"queue_h\", torch.randn(encoder_q.features_dim + 1, num_classes, K))\n        self.register_buffer(\"queue_z\", torch.randn(encoder_q.projection_dim, num_classes, K))\n        self.queue_h = normalize(self.queue_h, dim=0)\n        self.queue_z = normalize(self.queue_z, dim=0)\n\n        self.register_buffer(\"queue_ptr\", torch.zeros(num_classes, dtype=torch.long))\n\n    @torch.no_grad()\n    def _momentum_update_key_encoder(self):\n        \"\"\"\n        Momentum update of the key encoder\n        \"\"\"\n        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):\n            param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m)\n\n    @torch.no_grad()\n    def _dequeue_and_enqueue(self, h, z, label):\n        batch_size = h.shape[0]\n        assert self.K % batch_size == 0  # for simplicity\n\n        ptr = int(self.queue_ptr[label])\n        # replace the keys at ptr (dequeue and enqueue)\n        self.queue_h[:, label, ptr: ptr + batch_size] = h.T\n        self.queue_z[:, label, ptr: ptr + batch_size] = z.T\n\n        # move pointer\n        self.queue_ptr[label] = (ptr + batch_size) % self.K\n\n    def forward(self, im_q, im_k, labels):\n        batch_size = im_q.size(0)\n        device = im_q.device\n        # compute query features\n        y_q, z_q, h_q = self.encoder_q(im_q)\n\n        # compute key features\n        with torch.no_grad():  # no gradient to keys\n            self._momentum_update_key_encoder()  # update the key encoder\n            y_k, z_k, h_k = self.encoder_k(im_k)\n\n        # compute logits for projection z\n        # current positive logits: Nx1\n        logits_z_cur = torch.einsum('nc,nc->n', [z_q, z_k]).unsqueeze(-1)\n        queue_z = self.queue_z.clone().detach().to(device)\n        # positive logits: N x K\n        logits_z_pos = torch.Tensor([]).to(device)\n        # negative logits: N x ((C-1) x K)\n        logits_z_neg = torch.Tensor([]).to(device)\n\n        for i in range(batch_size):\n            c = labels[i]\n            pos_samples = queue_z[:, c, :]  # D x K\n            neg_samples = torch.cat([queue_z[:, 0: c, :], queue_z[:, c + 1:, :]], dim=1).flatten(\n                start_dim=1)  # D x ((C-1)xK)\n            ith_pos = torch.einsum('nc,ck->nk', [z_q[i: i + 1], pos_samples])  # 1 x D\n            ith_neg = torch.einsum('nc,ck->nk', [z_q[i: i + 1], neg_samples])  # 1 x ((C-1)xK)\n            logits_z_pos = torch.cat((logits_z_pos, ith_pos), dim=0)\n            logits_z_neg = torch.cat((logits_z_neg, ith_neg), dim=0)\n\n            self._dequeue_and_enqueue(h_k[i:i + 1], z_k[i:i + 1], labels[i])\n\n        logits_z = torch.cat([logits_z_cur, logits_z_pos, logits_z_neg], dim=1)  # Nx(1+C*K)\n\n        # apply temperature\n        logits_z /= self.T\n        logits_z = nn.LogSoftmax(dim=1)(logits_z)\n\n        # compute logits for classification y\n        w = torch.cat([self.encoder_q.head.weight.data, self.encoder_q.head.bias.data.unsqueeze(-1)], dim=1)\n        w = normalize(w, dim=1)  # C x F\n\n        # current positive logits: Nx1\n        logits_y_cur = torch.einsum('nk,kc->nc', [h_q, w.T])  # N x C\n        queue_y = self.queue_h.clone().detach().to(device).flatten(start_dim=1).T  # (C * K) x F\n        logits_y_queue = torch.einsum('nk,kc->nc', [queue_y, w.T]).reshape(self.num_classes, -1,\n                                                                           self.num_classes)  # C x K x C\n\n        logits_y = torch.Tensor([]).to(device)\n\n        for i in range(batch_size):\n            c = labels[i]\n            # calculate the ith sample in the batch\n            cur_sample = logits_y_cur[i:i + 1, c]  # 1\n            pos_samples = logits_y_queue[c, :, c]  # K\n            neg_samples = torch.cat([logits_y_queue[0: c, :, c], logits_y_queue[c + 1:, :, c]], dim=0).view(\n                -1)  # (C-1)*K\n\n            ith = torch.cat([cur_sample, pos_samples, neg_samples])  # 1+C*K\n            logits_y = torch.cat([logits_y, ith.unsqueeze(dim=0)], dim=0)\n\n        logits_y /= self.T\n        logits_y = nn.LogSoftmax(dim=1)(logits_y)\n\n        # contrastive labels\n        labels_c = torch.zeros([batch_size, self.K * self.num_classes + 
1]).to(device)\n        labels_c[:, 0:self.K + 1].fill_(1.0 / (self.K + 1))\n        return y_q, logits_z, logits_y, labels_c\n"
  },
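A schematic sketch of wiring the pieces together. `TinyBackbone` is a hypothetical stand-in for illustration; any tllib backbone exposing `out_features` and returning (N, C, H, W) feature maps works the same way:

    import torch
    import torch.nn as nn
    from tllib.regularization.bi_tuning import Classifier, BiTuning

    class TinyBackbone(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 64, 3, padding=1)
            self.out_features = 64  # attribute expected by Classifier

        def forward(self, x):
            return self.conv(x)

    model = BiTuning(Classifier(TinyBackbone(), num_classes=10),
                     Classifier(TinyBackbone(), num_classes=10), num_classes=10, K=40)
    im_q, im_k = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
    labels = torch.randint(0, 10, (8,))
    y_q, logits_z, logits_y, labels_c = model(im_q, im_k, labels)
    # Training objective: cross-entropy on y_q, plus KL terms pulling the
    # log-probabilities logits_z / logits_y towards the soft labels labels_c.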
  {
    "path": "tllib/regularization/bss.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport torch\nimport torch.nn as nn\n\n__all__ = ['BatchSpectralShrinkage']\n\n\nclass BatchSpectralShrinkage(nn.Module):\n    r\"\"\"\n    The regularization term in `Catastrophic Forgetting Meets Negative Transfer:\n    Batch Spectral Shrinkage for Safe Transfer Learning (NIPS 2019) <https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf>`_.\n\n\n    The BSS regularization of feature matrix :math:`F` can be described as:\n\n    .. math::\n        L_{bss}(F) = \\sum_{i=1}^{k} \\sigma_{-i}^2 ,\n\n    where :math:`k` is the number of singular values to be penalized, :math:`\\sigma_{-i}` is the :math:`i`-th smallest singular value of feature matrix :math:`F`.\n\n    All the singular values of feature matrix :math:`F` are computed by `SVD`:\n\n    .. math::\n        F = U\\Sigma V^T,\n\n    where the main diagonal elements of the singular value matrix :math:`\\Sigma` is :math:`[\\sigma_1, \\sigma_2, ..., \\sigma_b]`.\n\n\n    Args:\n        k (int):  The number of singular values to be penalized. Default: 1\n\n    Shape:\n        - Input: :math:`(b, |\\mathcal{f}|)` where :math:`b` is the batch size and :math:`|\\mathcal{f}|` is feature dimension.\n        - Output: scalar.\n\n    \"\"\"\n    def __init__(self, k=1):\n        super(BatchSpectralShrinkage, self).__init__()\n        self.k = k\n\n    def forward(self, feature):\n        result = 0\n        u, s, v = torch.svd(feature.t())\n        num = s.size(0)\n        for i in range(self.k):\n            result += torch.pow(s[num-1-i], 2)\n        return result\n"
  },
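A minimal sketch of adding the penalty to a fine-tuning objective (`trade_off` is a hypothetical hyper-parameter, e.g. 1e-3):

    import torch
    from tllib.regularization.bss import BatchSpectralShrinkage

    bss = BatchSpectralShrinkage(k=1)
    features = torch.randn(48, 256, requires_grad=True)  # (batch, feature_dim), e.g. bottleneck output
    penalty = bss(features)  # sum of the squared k smallest singular values
    # total_loss = cls_loss + trade_off * penalty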
  {
    "path": "tllib/regularization/co_tuning.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nfrom typing import Tuple, Optional, List, Dict\nimport os\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\nimport tqdm\nfrom .lwf import Classifier as ClassifierBase\n\n__all__ = ['Classifier', 'CoTuningLoss', 'Relationship']\n\n\nclass CoTuningLoss(nn.Module):\n    \"\"\"\n    The Co-Tuning loss in `Co-Tuning for Transfer Learning (NIPS 2020)\n    <http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf>`_.\n\n    Inputs:\n        - input: p(y_s) predicted by source classifier.\n        - target: p(y_s|y_t), where y_t is the ground truth class label in target dataset.\n\n    Shape:\n        - input:  (b, N_p), where b is the batch size and N_p is the number of classes in source dataset\n        - target: (b, N_p), where b is the batch size and N_p is the number of classes in source dataset\n        - Outputs: scalar.\n    \"\"\"\n\n    def __init__(self):\n        super(CoTuningLoss, self).__init__()\n\n    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n        y = - target * F.log_softmax(input, dim=-1)\n        y = torch.mean(torch.sum(y, dim=-1))\n        return y\n\n\nclass Relationship(object):\n    \"\"\"Learns the category relationship p(y_s|y_t) between source dataset and target dataset.\n\n    Args:\n        data_loader (torch.utils.data.DataLoader): A data loader of target dataset.\n        classifier (torch.nn.Module): A classifier for Co-Tuning.\n        device (torch.nn.Module): The device to run classifier.\n        cache (str, optional): Path to find and save the relationship file.\n\n    \"\"\"\n    def __init__(self, data_loader, classifier, device, cache=None):\n        super(Relationship, self).__init__()\n        self.data_loader = data_loader\n        self.classifier = classifier\n        self.device = device\n        if cache is None or not os.path.exists(cache):\n            source_predictions, target_labels = self.collect_labels()\n            self.relationship = self.get_category_relationship(source_predictions, target_labels)\n            if cache is not None:\n                np.save(cache, self.relationship)\n        else:\n            self.relationship = np.load(cache)\n\n    def __getitem__(self, category):\n        return self.relationship[category]\n\n    def collect_labels(self):\n        \"\"\"\n        Collects predictions of target dataset by source model and corresponding ground truth class labels.\n\n        Returns:\n            - source_probabilities, [N, N_p], where N_p is the number of classes in source dataset\n            - target_labels, [N], where 0 <= each number < N_t, and N_t is the number of classes in target dataset\n        \"\"\"\n\n        print(\"Collecting labels to calculate relationship\")\n        source_predictions = []\n        target_labels = []\n\n        self.classifier.eval()\n        with torch.no_grad():\n            for i, (x, label) in enumerate(tqdm.tqdm(self.data_loader)):\n                x = x.to(self.device)\n                y_s = self.classifier(x)\n\n                source_predictions.append(F.softmax(y_s, dim=1).detach().cpu().numpy())\n                target_labels.append(label)\n\n        return np.concatenate(source_predictions, 0), np.concatenate(target_labels, 0)\n\n    def get_category_relationship(self, source_probabilities, target_labels):\n        \"\"\"\n        The direct approach of learning category relationship p(y_s 
| y_t).\n\n        Args:\n            source_probabilities (numpy.array): [N, N_p], where N_p is the number of classes in source dataset\n            target_labels (numpy.array): [N], where 0 <= each number < N_t, and N_t is the number of classes in target dataset\n\n        Returns:\n            Conditional probability, [N_c, N_p] matrix representing the conditional probability p(pre-trained class | target_class)\n        \"\"\"\n        N_t = np.max(target_labels) + 1  # the number of target classes\n        conditional = []\n        for i in range(N_t):\n            this_class = source_probabilities[target_labels == i]\n            average = np.mean(this_class, axis=0, keepdims=True)\n            conditional.append(average)\n        return np.concatenate(conditional)\n\n\nclass Classifier(ClassifierBase):\n    \"\"\"A Classifier used in `Co-Tuning for Transfer Learning (NIPS 2020)\n    <http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf>`_..\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data.\n        num_classes (int): Number of classes.\n        head_source (torch.nn.Module): Classifier head of source model.\n        head_target (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default\n        finetune (bool): Whether finetune the classifier or train from scratch. Default: True\n\n\n    Inputs:\n        - x (tensor): input data fed to backbone\n\n    Outputs:\n        - y_s: predictions of source classifier head\n        - y_t: predictions of target classifier head\n\n    Shape:\n        - Inputs: (b, *) where b is the batch size and * means any number of additional dimensions\n        - y_s: (b, N), where b is the batch size and N is the number of classes\n        - y_t: (b, N), where b is the batch size and N is the number of classes\n\n    \"\"\"\n    def __init__(self, backbone: nn.Module, num_classes: int,  head_source,  **kwargs):\n        super(Classifier, self).__init__(backbone, num_classes, head_source, **kwargs)\n\n    def get_parameters(self, base_lr=1.0) -> List[Dict]:\n        \"\"\"A parameter list which decides optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.head_source.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head_target.parameters(), \"lr\": 1.0 * base_lr},\n        ]\n        return params\n"
  },
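A schematic training-step sketch; `train_target_loader`, `classifier`, `device`, `x`, `labels` and `trade_off` are hypothetical placeholders for the user's own objects:

    import torch
    import torch.nn.functional as F
    from tllib.regularization.co_tuning import CoTuningLoss, Relationship

    # One-off: estimate p(y_s | y_t) on the target training set, optionally cached on disk.
    relationship = Relationship(train_target_loader, classifier, device, cache='relationship.npy')
    cotuning_loss = CoTuningLoss()

    # Inside the training loop (the classifier in training mode returns both heads):
    y_s, y_t = classifier(x)
    y_s_target = torch.from_numpy(relationship[labels.cpu().numpy()]).to(device).float()
    loss = F.cross_entropy(y_t, labels) + trade_off * cotuning_loss(y_s, y_s_target)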
  {
    "path": "tllib/regularization/delta.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nimport functools\nfrom collections import OrderedDict\n\n\nclass L2Regularization(nn.Module):\n    r\"\"\"The L2 regularization of parameters :math:`w` can be described as:\n\n    .. math::\n        {\\Omega} (w) = \\dfrac{1}{2}  \\Vert w\\Vert_2^2 ,\n\n    Args:\n        model (torch.nn.Module):  The model to apply L2 penalty.\n\n    Shape:\n        - Output: scalar.\n    \"\"\"\n    def __init__(self, model: nn.Module):\n        super(L2Regularization, self).__init__()\n        self.model = model\n\n    def forward(self):\n        output = 0.0\n        for param in self.model.parameters():\n            output += 0.5 * torch.norm(param) ** 2\n        return output\n\n\nclass SPRegularization(nn.Module):\n    r\"\"\"\n    The SP (Starting Point) regularization from `Explicit inductive bias for transfer learning with convolutional networks\n    (ICML 2018) <https://arxiv.org/abs/1802.01483>`_\n\n    The SP regularization of parameters :math:`w` can be described as:\n\n    .. math::\n        {\\Omega} (w) = \\dfrac{1}{2}  \\Vert w-w_0\\Vert_2^2 ,\n\n    where :math:`w_0` is the parameter vector of the model pretrained on the source problem, acting as the starting point (SP) in fine-tuning.\n\n\n    Args:\n        source_model (torch.nn.Module):  The source (starting point) model.\n        target_model (torch.nn.Module):  The target (fine-tuning) model.\n\n    Shape:\n        - Output: scalar.\n    \"\"\"\n    def __init__(self, source_model: nn.Module, target_model: nn.Module):\n        super(SPRegularization, self).__init__()\n        self.target_model = target_model\n        self.source_weight = {}\n        for name, param in source_model.named_parameters():\n            self.source_weight[name] = param.detach()\n\n    def forward(self):\n        output = 0.0\n        for name, param in self.target_model.named_parameters():\n            output += 0.5 * torch.norm(param - self.source_weight[name]) ** 2\n        return output\n\n\nclass BehavioralRegularization(nn.Module):\n    r\"\"\"\n    The behavioral regularization from `DELTA:DEep Learning Transfer using Feature Map with Attention\n    for convolutional networks (ICLR 2019) <https://openreview.net/pdf?id=rkgbwsAcYm>`_\n\n    It can be described as:\n\n    .. 
math::\n        {\\Omega} (w) = \\sum_{j=1}^{N}   \\Vert FM_j(w, \\boldsymbol x)-FM_j(w^0, \\boldsymbol x)\\Vert_2^2 ,\n\n    where :math:`w^0` is the parameter vector of the model pretrained on the source problem, acting as the starting point (SP) in fine-tuning,\n    :math:`FM_j(w, \\boldsymbol x)` is feature maps generated from the :math:`j`-th layer of the model parameterized with :math:`w`, given the input :math:`\\boldsymbol x`.\n\n\n    Inputs:\n        layer_outputs_source (OrderedDict):  The dictionary for source model, where the keys are layer names and the values are feature maps correspondingly.\n\n        layer_outputs_target (OrderedDict):  The dictionary for target model, where the keys are layer names and the values are feature maps correspondingly.\n\n    Shape:\n        - Output: scalar.\n\n    \"\"\"\n    def __init__(self):\n        super(BehavioralRegularization, self).__init__()\n\n    def forward(self, layer_outputs_source, layer_outputs_target):\n        output = 0.0\n        for fm_src, fm_tgt in zip(layer_outputs_source.values(), layer_outputs_target.values()):\n            output += 0.5 * (torch.norm(fm_tgt - fm_src.detach()) ** 2)\n        return output\n\n\nclass AttentionBehavioralRegularization(nn.Module):\n    r\"\"\"\n    The behavioral regularization with attention from `DELTA:DEep Learning Transfer using Feature Map with Attention\n    for convolutional networks (ICLR 2019) <https://openreview.net/pdf?id=rkgbwsAcYm>`_\n\n    It can be described as:\n\n    .. math::\n        {\\Omega} (w) = \\sum_{j=1}^{N}  W_j(w) \\Vert FM_j(w, \\boldsymbol x)-FM_j(w^0, \\boldsymbol x)\\Vert_2^2 ,\n\n    where\n    :math:`w^0` is the parameter vector of the model pretrained on the source problem, acting as the starting point (SP) in fine-tuning.\n    :math:`FM_j(w, \\boldsymbol x)` is feature maps generated from the :math:`j`-th layer of the model parameterized with :math:`w`, given the input :math:`\\boldsymbol x`.\n    :math:`W_j(w)` is the channel attention of the :math:`j`-th layer of the model parameterized with :math:`w`.\n\n    Args:\n        channel_attention (list): The channel attentions of feature maps generated by each selected layer. 
For the layer with C channels, the channel attention is a tensor of shape [C].\n\n    Inputs:\n        layer_outputs_source (OrderedDict):  The dictionary for source model, where the keys are layer names and the values are feature maps correspondingly.\n\n        layer_outputs_target (OrderedDict):  The dictionary for target model, where the keys are layer names and the values are feature maps correspondingly.\n\n    Shape:\n        - Output: scalar.\n\n    \"\"\"\n    def __init__(self, channel_attention):\n        super(AttentionBehavioralRegularization, self).__init__()\n        self.channel_attention = channel_attention\n\n    def forward(self, layer_outputs_source, layer_outputs_target):\n        output = 0.0\n        for i, (fm_src, fm_tgt) in enumerate(zip(layer_outputs_source.values(), layer_outputs_target.values())):\n            b, c, h, w = fm_src.shape\n            fm_src = fm_src.reshape(b, c, h * w)\n            fm_tgt = fm_tgt.reshape(b, c, h * w)\n\n            distance = torch.norm(fm_tgt - fm_src.detach(), 2, 2)\n            distance = c * torch.mul(self.channel_attention[i], distance ** 2) / (h * w)\n            output += 0.5 * torch.sum(distance)\n\n        return output\n\n\ndef get_attribute(obj, attr, *args):\n    def _getattr(obj, attr):\n        return getattr(obj, attr, *args)\n    return functools.reduce(_getattr, [obj] + attr.split('.'))\n\n\nclass IntermediateLayerGetter:\n    r\"\"\"\n    Wraps a model to get intermediate output values of selected layers.\n\n    Args:\n       model (torch.nn.Module): The model to collect intermediate layer feature maps.\n       return_layers (list): The names of selected modules to return the output.\n       keep_output (bool): If True, `model_output` contains the final model's output, else return None. Default: True\n\n    Returns:\n       - An OrderedDict of intermediate outputs. The keys are selected layer names in `return_layers` and the values are the feature map outputs. The order is the same as `return_layers`.\n       - The model's final output. If `keep_output` is False, return None.\n\n    \"\"\"\n    def __init__(self, model, return_layers, keep_output=True):\n        self._model = model\n        self.return_layers = return_layers\n        self.keep_output = keep_output\n\n    def __call__(self, *args, **kwargs):\n        ret = OrderedDict()\n        handles = []\n        for name in self.return_layers:\n            layer = get_attribute(self._model, name)\n            def hook(module, input, output, name=name):\n                ret[name] = output\n            try:\n                h = layer.register_forward_hook(hook)\n            except AttributeError as e:\n                raise AttributeError(f'Module {name} not found')\n            handles.append(h)\n\n        if self.keep_output:\n            output = self._model(*args, **kwargs)\n        else:\n            self._model(*args, **kwargs)\n            output = None\n\n        for h in handles:\n            h.remove()\n\n        return ret, output\n"
  },
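A schematic sketch of the DELTA recipe; `source_model`, `target_model`, `x`, `labels`, `task_loss` and `trade_off` are hypothetical placeholders, and the layer names follow torchvision's ResNet:

    from tllib.regularization.delta import IntermediateLayerGetter, BehavioralRegularization

    return_layers = ['layer1', 'layer2', 'layer3', 'layer4']
    source_getter = IntermediateLayerGetter(source_model, return_layers)  # frozen pretrained copy
    target_getter = IntermediateLayerGetter(target_model, return_layers)  # model being fine-tuned
    behavioral_reg = BehavioralRegularization()

    # Inside the training loop:
    fm_source, _ = source_getter(x)
    fm_target, output = target_getter(x)
    loss = task_loss(output, labels) + trade_off * behavioral_reg(fm_source, fm_target)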
  {
    "path": "tllib/regularization/knowledge_distillation.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass KnowledgeDistillationLoss(nn.Module):\n    \"\"\"Knowledge Distillation Loss.\n\n    Args:\n        T (double): Temperature. Default: 1.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'batchmean'``\n\n    Inputs:\n        - y_student (tensor): logits output of the student\n        - y_teacher (tensor): logits output of the teacher\n\n    Shape:\n        - y_student: (minibatch, `num_classes`)\n        - y_teacher: (minibatch, `num_classes`)\n\n    \"\"\"\n    def __init__(self, T=1., reduction='batchmean'):\n        super(KnowledgeDistillationLoss, self).__init__()\n        self.T = T\n        self.kl = nn.KLDivLoss(reduction=reduction)\n\n    def forward(self, y_student, y_teacher):\n        \"\"\"\"\"\"\n        return self.kl(F.log_softmax(y_student / self.T, dim=-1), F.softmax(y_teacher / self.T, dim=-1))\n"
  },
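A small runnable sketch; scaling the distillation term by T**2 (the common recipe from Hinton et al.) is left to the caller:

    import torch
    from tllib.regularization.knowledge_distillation import KnowledgeDistillationLoss

    kd = KnowledgeDistillationLoss(T=4.)
    y_student = torch.randn(32, 10)  # student logits
    y_teacher = torch.randn(32, 10)  # teacher logits
    loss = kd(y_student, y_teacher)  # KL(softmax(teacher/T) || softmax(student/T)), batch-averaged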
  {
    "path": "tllib/regularization/lwf.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, List, Dict\nimport torch\nimport torch.nn as nn\nimport tqdm\n\n\ndef collect_pretrain_labels(data_loader, classifier, device):\n    source_predictions = []\n\n    classifier.eval()\n    with torch.no_grad():\n        for i, (x, label) in enumerate(tqdm.tqdm(data_loader)):\n            x = x.to(device)\n            y_s = classifier(x)\n            source_predictions.append(y_s.detach().cpu())\n    return torch.cat(source_predictions, dim=0)\n\n\nclass Classifier(nn.Module):\n    \"\"\"A Classifier used in `Learning Without Forgetting (ECCV 2016)\n    <https://arxiv.org/abs/1606.09282>`_..\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data.\n        num_classes (int): Number of classes.\n        head_source (torch.nn.Module): Classifier head of source model.\n        head_target (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default\n        finetune (bool): Whether finetune the classifier or train from scratch. Default: True\n\n\n    Inputs:\n        - x (tensor): input data fed to backbone\n\n    Outputs:\n        - y_s: predictions of source classifier head\n        - y_t: predictions of target classifier head\n\n    Shape:\n        - Inputs: (b, *) where b is the batch size and * means any number of additional dimensions\n        - y_s: (b, N), where b is the batch size and N is the number of classes\n        - y_t: (b, N), where b is the batch size and N is the number of classes\n\n    \"\"\"\n    def __init__(self, backbone: nn.Module, num_classes: int,  head_source,\n                 head_target: Optional[nn.Module] = None, bottleneck: Optional[nn.Module] = None,\n                 bottleneck_dim: Optional[int] = -1,  finetune=True, pool_layer=None):\n        super(Classifier, self).__init__()\n        self.backbone = backbone\n        self.num_classes = num_classes\n        if pool_layer is None:\n            self.pool_layer = nn.Sequential(\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            )\n        else:\n            self.pool_layer = pool_layer\n        if bottleneck is None:\n            self.bottleneck = nn.Identity()\n            self._features_dim = backbone.out_features\n        else:\n            self.bottleneck = bottleneck\n            assert bottleneck_dim > 0\n            self._features_dim = bottleneck_dim\n\n        self.head_source = head_source\n        if head_target is None:\n            self.head_target = nn.Linear(self._features_dim, num_classes)\n        else:\n            self.head_target = head_target\n        self.finetune = finetune\n\n    @property\n    def features_dim(self) -> int:\n        \"\"\"The dimension of features before the final `head` layer\"\"\"\n        return self._features_dim\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\"\"\"\n        f = self.backbone(x)\n        f = self.pool_layer(f)\n        y_s = self.head_source(f)\n        y_t = self.head_target(self.bottleneck(f))\n        if self.training:\n            return y_s, y_t\n        else:\n            return y_t\n\n    def get_parameters(self, base_lr=1.0) -> List[Dict]:\n        \"\"\"A parameter list which decides optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr 
if self.finetune else 1.0 * base_lr},\n            # {\"params\": self.head_source.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head_target.parameters(), \"lr\": 1.0 * base_lr},\n        ]\n        return params\n"
  },
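A schematic sketch of the LwF recipe built from the pieces above; `train_loader`, `classifier`, `device`, `x`, `labels`, `batch_indices` and `trade_off` are hypothetical placeholders, and the choice of consistency term (distillation, MSE, ...) is up to the user:

    import torch.nn.functional as F
    from tllib.regularization.lwf import collect_pretrain_labels

    # One-off: record the source head's outputs on the target data before fine-tuning.
    recorded_y_s = collect_pretrain_labels(train_loader, classifier, device)

    # Inside the training loop: supervise the target head, and keep the source head
    # close to its recorded outputs.
    y_s, y_t = classifier(x)
    loss = F.cross_entropy(y_t, labels) + trade_off * F.mse_loss(y_s, recorded_y_s[batch_indices].to(device))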
  {
    "path": "tllib/reweight/__init__.py",
    "content": ""
  },
  {
    "path": "tllib/reweight/groupdro.py",
    "content": "\"\"\"\nModified from https://github.com/facebookresearch/DomainBed\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch\n\n\nclass AutomaticUpdateDomainWeightModule(object):\n    r\"\"\"\n    Maintaining group weight based on loss history of all domains according\n    to `Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case\n    Generalization (ICLR 2020) <https://arxiv.org/pdf/1911.08731.pdf>`_.\n\n    Suppose we have :math:`N` domains. During each iteration, we first calculate unweighted loss among all\n    domains, resulting in :math:`loss\\in R^N`. Then we update domain weight by\n\n    .. math::\n        w_k = w_k * \\text{exp}(loss_k ^{\\eta}), \\forall k \\in [1, N]\n\n    where :math:`\\eta` is the hyper parameter which ensures smoother change of weight.\n    As :math:`w \\in R^N` denotes a distribution, we `normalize`\n    :math:`w` by its sum. At last, weighted loss is calculated as our objective\n\n    .. math::\n        objective = \\sum_{k=1}^N w_k * loss_k\n\n    Args:\n        num_domains (int): The number of source domains.\n        eta (float): Hyper parameter eta.\n        device (torch.device): The device to run on.\n    \"\"\"\n\n    def __init__(self, num_domains: int, eta: float, device):\n        self.domain_weight = torch.ones(num_domains).to(device) / num_domains\n        self.eta = eta\n\n    def get_domain_weight(self, sampled_domain_idxes):\n        \"\"\"Get domain weight to calculate final objective.\n\n        Inputs:\n            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch\n\n        Shape:\n            - sampled_domain_idxes: :math:`(D, )` where D means the number of sampled domains in current mini-batch\n            - Outputs: :math:`(D, )`\n        \"\"\"\n        domain_weight = self.domain_weight[sampled_domain_idxes]\n        domain_weight = domain_weight / domain_weight.sum()\n        return domain_weight\n\n    def update(self, sampled_domain_losses: torch.Tensor, sampled_domain_idxes):\n        \"\"\"Update domain weight using loss of current mini-batch.\n\n        Inputs:\n            - sampled_domain_losses (tensor): loss of among sampled domains in current mini-batch\n            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch\n\n        Shape:\n            - sampled_domain_losses: :math:`(D, )` where D means the number of sampled domains in current mini-batch\n            - sampled_domain_idxes: :math:`(D, )`\n        \"\"\"\n        sampled_domain_losses = sampled_domain_losses.detach()\n\n        for loss, idx in zip(sampled_domain_losses, sampled_domain_idxes):\n            self.domain_weight[idx] *= (self.eta * loss).exp()\n"
  },
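A runnable sketch of one GroupDRO step (the loss values are fabricated):

    import torch
    from tllib.reweight.groupdro import AutomaticUpdateDomainWeightModule

    module = AutomaticUpdateDomainWeightModule(num_domains=4, eta=1e-2, device=torch.device('cpu'))
    sampled_idxes = [0, 2]                # domains present in this mini-batch
    losses = torch.tensor([0.8, 1.3])     # unweighted loss of each sampled domain
    module.update(losses, sampled_idxes)  # exponential re-weighting, w_k *= exp(eta * loss_k)
    weight = module.get_domain_weight(sampled_idxes)
    objective = (weight * losses).sum()   # weighted loss to back-propagate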
  {
    "path": "tllib/reweight/iwan.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom typing import Optional, List, Dict\nimport torch\nimport torch.nn as nn\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\n\n\nclass ImportanceWeightModule(object):\n    r\"\"\"\n    Calculating class weight based on the output of discriminator.\n    Introduced by `Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018) <https://arxiv.org/abs/1803.09210>`_\n\n    Args:\n        discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of features.\n            Its input shape is :math:`(N, F)` and output shape is :math:`(N, 1)`\n        partial_classes_index (list[int], optional): The index of partial classes. Note that this parameter is \\\n            just for debugging, since in real-world dataset, we have no access to the index of partial classes. \\\n            Default: None.\n\n    Examples::\n\n        >>> domain_discriminator = DomainDiscriminator(1024, 1024)\n        >>> importance_weight_module = ImportanceWeightModule(domain_discriminator)\n        >>> num_iterations = 10000\n        >>> for _ in range(num_iterations):\n        >>>     # feature from source domain\n        >>>     f_s = torch.randn(32, 1024)\n        >>>     # importance weights for source instance\n        >>>     w_s = importance_weight_module.get_importance_weight(f_s)\n    \"\"\"\n\n    def __init__(self, discriminator: nn.Module, partial_classes_index: Optional[List[int]] = None):\n        self.discriminator = discriminator\n        self.partial_classes_index = partial_classes_index\n\n    def get_importance_weight(self, feature):\n        \"\"\"\n        Get importance weights for each instance.\n\n        Args:\n            feature (tensor): feature from source domain, in shape :math:`(N, F)`\n\n        Returns:\n            instance weight in shape :math:`(N, 1)`\n        \"\"\"\n        weight = 1. - self.discriminator(feature)\n        weight = weight / (weight.mean() + 1e-5)\n        weight = weight.detach()\n        return weight\n\n    def get_partial_classes_weight(self, weights: torch.Tensor, labels: torch.Tensor):\n        \"\"\"\n        Get class weight averaged on the partial classes and non-partial classes respectively.\n\n        Args:\n            weights (tensor): instance weight in shape :math:`(N, 1)`\n            labels (tensor): ground truth labels in shape :math:`(N, 1)`\n\n        .. warning::\n            This function is just for debugging, since in real-world dataset, we have no access to the index of \\\n            partial classes and this function will throw an error when `partial_classes_index` is None.\n        \"\"\"\n        assert self.partial_classes_index is not None\n\n        weights = weights.squeeze()\n        is_partial = torch.Tensor([label in self.partial_classes_index for label in labels]).to(weights.device)\n        if is_partial.sum() > 0:\n            partial_classes_weight = (weights * is_partial).sum() / is_partial.sum()\n        else:\n            partial_classes_weight = torch.tensor(0)\n\n        not_partial = 1. 
- is_partial\n        if not_partial.sum() > 0:\n            not_partial_classes_weight = (weights * not_partial).sum() / not_partial.sum()\n        else:\n            not_partial_classes_weight = torch.tensor(0)\n        return partial_classes_weight, not_partial_classes_weight\n\n\nclass ImageClassifier(ClassifierBase):\n    r\"\"\"The Image Classifier for `Importance Weighted Adversarial Nets for Partial Domain Adaptation <https://arxiv.org/abs/1803.09210>`_\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n"
  },
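A minimal sketch of how the importance weights are typically combined with the source classification loss (the sigmoid discriminator, class count, and shapes below are illustrative stand-ins, not the library's `DomainDiscriminator`):

    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> from tllib.reweight.iwan import ImportanceWeightModule
    >>> toy_discriminator = nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())  # outputs domain probability in (0, 1)
    >>> module = ImportanceWeightModule(toy_discriminator)
    >>> f_s = torch.randn(32, 1024)                # source features
    >>> y_s = torch.randn(32, 10)                  # source logits
    >>> labels_s = torch.randint(10, (32,))
    >>> w_s = module.get_importance_weight(f_s)    # (32, 1), detached from the graph
    >>> cls_loss = (w_s.squeeze() * F.cross_entropy(y_s, labels_s, reduction='none')).mean()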
  {
    "path": "tllib/reweight/pada.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, List, Tuple\n\nfrom torch.utils.data.dataloader import DataLoader\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\n\n\nclass AutomaticUpdateClassWeightModule(object):\n    r\"\"\"\n    Calculating class weight based on the output of classifier. See ``ClassWeightModule`` about the details of the calculation.\n    Every N iterations, the class weight is updated automatically.\n\n    Args:\n        update_steps (int): N, the number of iterations to update class weight.\n        data_loader (torch.utils.data.DataLoader): The data loader from which we can collect classification outputs.\n        classifier (torch.nn.Module): Classifier.\n        num_classes (int): Number of classes.\n        device (torch.device): The device to run classifier.\n        temperature (float, optional): T, temperature in ClassWeightModule. Default: 0.1\n        partial_classes_index (list[int], optional): The index of partial classes. Note that this parameter is \\\n          just for debugging, since in real-world dataset, we have no access to the index of partial classes. \\\n          Default: None.\n\n    Examples::\n\n        >>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)\n        >>> num_iterations = 10000\n        >>> for _ in range(num_iterations):\n        >>>     class_weight_module.step()\n        >>>     # weight for F.cross_entropy\n        >>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()\n        >>>     # weight for tllib.alignment.dann.DomainAdversarialLoss\n        >>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss()\n    \"\"\"\n\n    def __init__(self, update_steps: int, data_loader: DataLoader,\n                 classifier: nn.Module, num_classes: int,\n                 device: torch.device, temperature: Optional[float] = 0.1,\n                 partial_classes_index: Optional[List[int]] = None):\n        self.update_steps = update_steps\n        self.data_loader = data_loader\n        self.classifier = classifier\n        self.device = device\n        self.class_weight_module = ClassWeightModule(temperature)\n        self.class_weight = torch.ones(num_classes).to(device)\n        self.num_steps = 0\n        self.partial_classes_index = partial_classes_index\n        if partial_classes_index is not None:\n            self.non_partial_classes_index = [c for c in range(num_classes) if c not in partial_classes_index]\n\n    def step(self):\n        self.num_steps += 1\n        if self.num_steps % self.update_steps == 0:\n            all_outputs = collect_classification_results(self.data_loader, self.classifier, self.device)\n            self.class_weight = self.class_weight_module(all_outputs)\n\n    def get_class_weight_for_cross_entropy_loss(self):\n        \"\"\"\n        Outputs: weight for F.cross_entropy\n\n        Shape: :math:`(C, )` where C means the number of classes.\n        \"\"\"\n        return self.class_weight\n\n    def get_class_weight_for_adversarial_loss(self, source_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Outputs:\n            - w_s: source weight for :py:class:`~tllib.alignment.dann.DomainAdversarialLoss`\n            - w_t: target weight for :py:class:`~tllib.alignment.dann.DomainAdversarialLoss`\n\n        Shape:\n            - w_s: :math:`(minibatch, )`\n            - w_t: :math:`(minibatch, )`\n  
      \"\"\"\n        class_weight_adv_source = self.class_weight[source_labels]\n        class_weight_adv_target = torch.ones_like(class_weight_adv_source) * class_weight_adv_source.mean()\n        return class_weight_adv_source, class_weight_adv_target\n\n    def get_partial_classes_weight(self):\n        \"\"\"\n        Get class weight averaged on the partial classes and non-partial classes respectively.\n\n        .. warning::\n\n            This function is just for debugging, since in real-world dataset, we have no access to the index of \\\n            partial classes and this function will throw an error when `partial_classes_index` is None.\n        \"\"\"\n        assert self.partial_classes_index is not None\n        return torch.mean(self.class_weight[self.partial_classes_index]), torch.mean(\n            self.class_weight[self.non_partial_classes_index])\n\n\nclass ClassWeightModule(nn.Module):\n    r\"\"\"\n    Calculating class weight based on the output of classifier.\n    Introduced by `Partial Adversarial Domain Adaptation (ECCV 2018) <https://arxiv.org/abs/1808.04205>`_\n\n    Given classification logits outputs :math:`\\{\\hat{y}_i\\}_{i=1}^n`, where :math:`n` is the dataset size,\n    the weight indicating the contribution of each class to the training can be calculated as\n    follows\n\n    .. math::\n        \\mathcal{\\gamma} = \\dfrac{1}{n} \\sum_{i=1}^{n}\\text{softmax}( \\hat{y}_i / T),\n\n    where :math:`\\mathcal{\\gamma}` is a :math:`|\\mathcal{C}|`-dimensional weight vector quantifying the contribution\n    of each class and T is a hyper-parameters called temperature.\n\n    In practice, it's possible that some of the weights are very small, thus, we normalize weight :math:`\\mathcal{\\gamma}`\n    by dividing its largest element, i.e. :math:`\\mathcal{\\gamma} \\leftarrow \\mathcal{\\gamma} / max(\\mathcal{\\gamma})`\n\n    Args:\n        temperature (float, optional): hyper-parameters :math:`T`. Default: 0.1\n\n    Shape:\n        - Inputs: (minibatch, :math:`|\\mathcal{C}|`)\n        - Outputs: (:math:`|\\mathcal{C}|`,)\n    \"\"\"\n\n    def __init__(self, temperature: Optional[float] = 0.1):\n        super(ClassWeightModule, self).__init__()\n        self.temperature = temperature\n\n    def forward(self, outputs: torch.Tensor):\n        outputs.detach_()\n        softmax_outputs = F.softmax(outputs / self.temperature, dim=1)\n        class_weight = torch.mean(softmax_outputs, dim=0)\n        class_weight = class_weight / torch.max(class_weight)\n        class_weight = class_weight.view(-1)\n        return class_weight\n\n\ndef collect_classification_results(data_loader: DataLoader, classifier: nn.Module,\n                                   device: torch.device) -> torch.Tensor:\n    \"\"\"\n    Fetch data from `data_loader`, and then use `classifier` to collect classification results\n\n    Args:\n        data_loader (torch.utils.data.DataLoader): Data loader.\n        classifier (torch.nn.Module): A classifier.\n        device (torch.device)\n\n    Returns:\n        Classification results in shape (len(data_loader), :math:`|\\mathcal{C}|`).\n    \"\"\"\n    training = classifier.training\n    classifier.eval()\n    all_outputs = []\n    with torch.no_grad():\n        for i, (images, target) in enumerate(data_loader):\n            images = images.to(device)\n            output = classifier(images)\n            all_outputs.append(output)\n    classifier.train(training)\n    return torch.cat(all_outputs, dim=0)\n"
  },
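A minimal sketch of computing the PADA class weight directly with `ClassWeightModule`, outside the automatic update module (the 31-class setting and collected-output shape are illustrative):

    >>> import torch
    >>> from tllib.reweight.pada import ClassWeightModule
    >>> module = ClassWeightModule(temperature=0.1)
    >>> target_outputs = torch.randn(1000, 31)   # logits collected over the target set, e.g. via collect_classification_results
    >>> gamma = module(target_outputs)           # (31,), largest element normalized to 1
    >>> # gamma can then serve as the weight argument of F.cross_entropy on source data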
  {
    "path": "tllib/self_training/__init__.py",
    "content": ""
  },
  {
    "path": "tllib/self_training/cc_loss.py",
    "content": "\"\"\"\n@author: Ying Jin\n@contact: sherryying003@gmail.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\nfrom ..modules.entropy import entropy\n\n\n__all__ = ['CCConsistency']\n\n\nclass CCConsistency(nn.Module):\n    r\"\"\"\n    CC Loss attach class confusion consistency to MCC.\n\n    Args:\n        temperature (float) : The temperature for rescaling, the prediction will shrink to vanilla softmax if\n          temperature is 1.0.\n        thr (float): The confidence threshold.\n\n    .. note::\n        Make sure that temperature is larger than 0. Confidence threshold is larger than 0, smaller than 1.0.\n\n    Inputs: g_t\n        - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`\n        - g_t_strong (tensor): unnormalized classifier predictions on target domain, with strong data augmentation, :math:`g^t_{strong}`\n\n    Shape:\n        - g_t, g_t_strong: :math:`(minibatch, C)` where C means the number of classes.\n        - Output: scalar.\n\n    Examples::\n        >>> temperature = 2.0\n        >>> loss = CCConsistency(temperature)\n        >>> # logits output from target domain\n        >>> g_t = torch.randn(batch_size, num_classes)\n        >>> g_t_strong = torch.randn(batch_size, num_classes)\n        >>> output = loss(g_t, g_t_strong)\n    \"\"\"\n\n    def __init__(self, temperature: float, thr=0.7):\n        super(CCConsistency, self).__init__()\n        self.temperature = temperature\n        self.thr = thr\n\n    def forward(self, logits: torch.Tensor, logits_strong: torch.Tensor) -> torch.Tensor:\n        batch_size, num_classes = logits.shape\n        logits = logits.detach()\n\n        prediction_thr = F.softmax(logits / self.temperature, dim=1)\n        max_probs, max_idx = torch.max(prediction_thr, dim=-1)\n        mask_binary = max_probs.ge(self.thr)  ### 0.7 for DomainNet, 0.95 for other datasets\n        mask = mask_binary.float().detach()\n\n        if mask.sum() == 0:\n            return 0, 0\n        else:\n            logits = logits[mask_binary]\n            logits_strong = logits_strong[mask_binary]\n\n            predictions = F.softmax(logits / self.temperature, dim=1)  # batch_size x num_classes\n            entropy_weight = entropy(predictions).detach()\n            entropy_weight = 1 + torch.exp(-entropy_weight)\n            entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1)   # batch_size x 1\n            class_confusion_matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions) # num_classes x num_classes\n            class_confusion_matrix = class_confusion_matrix / torch.sum(class_confusion_matrix, dim=1)\n\n            predictions_stong = F.softmax(logits_strong / self.temperature, dim=1)\n            entropy_weight_strong = entropy(predictions_stong).detach()\n            entropy_weight_strong = 1 + torch.exp(-entropy_weight_strong)\n            entropy_weight_strong = (batch_size * entropy_weight_strong / torch.sum(entropy_weight_strong)).unsqueeze(dim=1)   # batch_size x 1\n            class_confusion_matrix_strong = torch.mm((predictions_stong * entropy_weight_strong).transpose(1, 0), predictions_stong)  # num_classes x num_classes\n            class_confusion_matrix_strong = class_confusion_matrix_strong / torch.sum(class_confusion_matrix_strong, dim=1)\n\n            consistency_loss = 
((class_confusion_matrix - class_confusion_matrix_strong) ** 2).sum()  / num_classes * mask.sum() / batch_size\n            #mcc_loss = (torch.sum(class_confusion_matrix) - torch.trace(class_confusion_matrix)) / num_classes\n            return consistency_loss, mask.sum()/batch_size"
  },
  {
    "path": "tllib/self_training/dst.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tllib.modules.grl import WarmStartGradientReverseLayer\nfrom tllib.modules.classifier import Classifier\n\n\nclass ImageClassifier(Classifier):\n    r\"\"\"\n    Classifier with non-linear pseudo head :math:`h_{\\text{pseudo}}` and worst-case estimation head\n    :math:`h_{\\text{worst}}` from `Debiased Self-Training for Semi-Supervised Learning <https://arxiv.org/abs/2202.07136>`_.\n    Both heads are directly connected to the feature extractor :math:`\\psi`. We implement end-to-end adversarial\n    training procedure between :math:`\\psi` and :math:`h_{\\text{worst}}` by introducing a gradient reverse layer.\n    Note that both heads can be safely discarded during inference, and thus will introduce no inference cost.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data\n        num_classes (int): Number of classes\n        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer.\n        width (int, optional): Hidden dimension of the non-linear pseudo head and worst-case estimation head.\n\n    Inputs:\n        - x (tensor): input data fed to `backbone`\n\n    Outputs:\n        - outputs: predictions of the main head :math:`h`\n        - outputs_adv: predictions of the worst-case estimation head :math:`h_{\\text{worst}}`\n        - outputs_pseudo: predictions of the pseudo head :math:`h_{\\text{pseudo}}`\n\n    Shape:\n        - Inputs: (minibatch, *) where * means, any number of additional dimensions\n        - outputs, outputs_adv, outputs_pseudo: (minibatch, `num_classes`)\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, width=2048, **kwargs):\n        bottleneck = nn.Sequential(\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(0.5)\n        )\n        bottleneck[0].weight.data.normal_(0, 0.005)\n        bottleneck[0].bias.data.fill_(0.1)\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n        self.pseudo_head = nn.Sequential(\n            nn.Linear(self.features_dim, width),\n            nn.ReLU(),\n            nn.Dropout(0.5),\n            nn.Linear(width, self.num_classes)\n        )\n        self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)\n        self.adv_head = nn.Sequential(\n            nn.Linear(self.features_dim, width),\n            nn.ReLU(),\n            nn.Dropout(0.5),\n            nn.Linear(width, self.num_classes)\n        )\n\n    def forward(self, x: torch.Tensor):\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        f_adv = self.grl_layer(f)\n        outputs_adv = self.adv_head(f_adv)\n        outputs = self.head(f)\n        outputs_pseudo = self.pseudo_head(f)\n        if self.training:\n            return outputs, outputs_adv, outputs_pseudo\n        else:\n            return outputs\n\n    def get_parameters(self, base_lr=1.0):\n        \"\"\"A parameter list which decides optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            
{\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.pseudo_head.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.adv_head.parameters(), \"lr\": 1.0 * base_lr}\n        ]\n\n        return params\n\n    def step(self):\n        self.grl_layer.step()\n\n\ndef shift_log(x, offset=1e-6):\n    \"\"\"\n    First shift, then calculate log for numerical stability.\n    \"\"\"\n\n    return torch.log(torch.clamp(x + offset, max=1.))\n\n\nclass WorstCaseEstimationLoss(nn.Module):\n    r\"\"\"\n    Worst-case Estimation loss from `Debiased Self-Training for Semi-Supervised Learning <https://arxiv.org/abs/2202.07136>`_\n    that forces the worst possible head :math:`h_{\\text{worst}}` to predict correctly on all labeled samples\n    :math:`\\mathcal{L}` while making as many mistakes as possible on unlabeled data :math:`\\mathcal{U}`. In the\n    classification task, it is defined as:\n\n    .. math::\n        loss(\\mathcal{L}, \\mathcal{U}) =\n        \\eta' \\mathbb{E}_{y^l, y_{adv}^l \\sim\\hat{\\mathcal{L}}} -\\log\\left(\\frac{\\exp(y_{adv}^l[h_{y^l}])}{\\sum_j \\exp(y_{adv}^l[j])}\\right) +\n        \\mathbb{E}_{y^u, y_{adv}^u \\sim\\hat{\\mathcal{U}}} -\\log\\left(1-\\frac{\\exp(y_{adv}^u[h_{y^u}])}{\\sum_j \\exp(y_{adv}^u[j])}\\right),\n\n    where :math:`y^l` and :math:`y^u` are logits output by the main head :math:`h` on labeled data and unlabeled data,\n    respectively. :math:`y_{adv}^l` and :math:`y_{adv}^u` are logits output by the worst-case estimation\n    head :math:`h_{\\text{worst}}`. :math:`h_y` refers to the predicted label when the logits output is :math:`y`.\n\n    Args:\n        eta_prime (float): the trade-off hyper parameter :math:`\\eta'`.\n\n    Inputs:\n        - y_l: logits output :math:`y^l` by the main head on labeled data\n        - y_l_adv: logits output :math:`y^l_{adv}` by the worst-case estimation head on labeled data\n        - y_u: logits output :math:`y^u` by the main head on unlabeled data\n        - y_u_adv: logits output :math:`y^u_{adv}` by the worst-case estimation head on unlabeled data\n\n    Shape:\n        - Inputs: :math:`(minibatch, C)` where C denotes the number of classes.\n        - Output: scalar.\n\n    \"\"\"\n\n    def __init__(self, eta_prime):\n        super(WorstCaseEstimationLoss, self).__init__()\n        self.eta_prime = eta_prime\n\n    def forward(self, y_l, y_l_adv, y_u, y_u_adv):\n        _, prediction_l = y_l.max(dim=1)\n        loss_l = self.eta_prime * F.cross_entropy(y_l_adv, prediction_l)\n\n        _, prediction_u = y_u.max(dim=1)\n        loss_u = F.nll_loss(shift_log(1. - F.softmax(y_u_adv, dim=1)), prediction_u)\n\n        return loss_l + loss_u\n"
  },
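A minimal sketch of evaluating `WorstCaseEstimationLoss` on one mini-batch (class count, batch sizes, and the `eta_prime` value are illustrative). In `ImageClassifier` the worst-case head sits behind the gradient reverse layer, whose coefficient is advanced by calling `classifier.step()` once per iteration:

    >>> import torch
    >>> from tllib.self_training.dst import WorstCaseEstimationLoss
    >>> wce = WorstCaseEstimationLoss(eta_prime=2.0)
    >>> y_l, y_l_adv = torch.randn(32, 10), torch.randn(32, 10)   # main / worst-case head on labeled data
    >>> y_u, y_u_adv = torch.randn(96, 10), torch.randn(96, 10)   # main / worst-case head on unlabeled data
    >>> loss = wce(y_l, y_l_adv, y_u, y_u_adv)                    # scalar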
  {
    "path": "tllib/self_training/flexmatch.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom collections import Counter\n\nimport torch\n\n\nclass DynamicThresholdingModule(object):\n    r\"\"\"\n    Dynamic thresholding module from `FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling\n    <https://arxiv.org/abs/2110.08263>`_. At time :math:`t`, for each category :math:`c`,\n    the learning status :math:`\\sigma_t(c)` is estimated by the number of samples whose predictions fall into this class\n    and above a threshold (e.g. 0.95). Then, FlexMatch normalizes :math:`\\sigma_t(c)` to make its range between 0 and 1\n\n    .. math::\n        \\beta_t(c) = \\frac{\\sigma_t(c)}{\\underset{c'}{\\text{max}}~\\sigma_t(c')}.\n\n    The dynamic threshold is formulated as\n\n    .. math::\n        \\mathcal{T}_t(c) = \\mathcal{M}(\\beta_t(c)) \\cdot \\tau,\n\n    where \\tau denotes the pre-defined threshold (e.g. 0.95), :math:`\\mathcal{M}` denotes a (possibly non-linear)\n    mapping function.\n\n    Args:\n        threshold (float): The pre-defined confidence threshold\n        warmup (bool): Whether perform threshold warm-up. If True, the number of unlabeled data that have not been\n            used will be considered when normalizing :math:`\\sigma_t(c)`\n        mapping_func (callable): An increasing mapping function. For example, this function can be (1) concave\n            :math:`\\mathcal{M}(x)=\\text{ln}(x+1)/\\text{ln}2`, (2) linear :math:`\\mathcal{M}(x)=x`,\n            and (3) convex :math:`\\mathcal{M}(x)=2/2-x`\n        num_classes (int): Number of classes\n        n_unlabeled_samples (int): Size of the unlabeled dataset\n        device (torch.device): Device\n\n    \"\"\"\n\n    def __init__(self, threshold, warmup, mapping_func, num_classes, n_unlabeled_samples, device):\n        self.threshold = threshold\n        self.warmup = warmup\n        self.mapping_func = mapping_func\n        self.num_classes = num_classes\n        self.n_unlabeled_samples = n_unlabeled_samples\n        self.net_outputs = torch.zeros(n_unlabeled_samples, dtype=torch.long).to(device)\n        self.net_outputs.fill_(-1)\n        self.device = device\n\n    def get_threshold(self, pseudo_labels):\n        \"\"\"Calculate and return dynamic threshold\"\"\"\n        pseudo_counter = Counter(self.net_outputs.tolist())\n        if max(pseudo_counter.values()) == self.n_unlabeled_samples:\n            # In the early stage of training, the network does not output pseudo labels with high confidence.\n            # In this case, the learning status of all categories is simply zero.\n            status = torch.zeros(self.num_classes).to(self.device)\n        else:\n            if not self.warmup and -1 in pseudo_counter.keys():\n                pseudo_counter.pop(-1)\n            max_num = max(pseudo_counter.values())\n            # estimate learning status\n            status = [\n                pseudo_counter[c] / max_num for c in range(self.num_classes)\n            ]\n            status = torch.FloatTensor(status).to(self.device)\n        # calculate dynamic threshold\n        dynamic_threshold = self.threshold * self.mapping_func(status[pseudo_labels])\n        return dynamic_threshold\n\n    def update(self, idxes, selected_mask, pseudo_labels):\n        \"\"\"Update the learning status\n\n        Args:\n            idxes (tensor): Indexes of corresponding samples\n            selected_mask (tensor): A binary mask, a value of 1 indicates the prediction for this sample will be 
updated\n            pseudo_labels (tensor): Network predictions\n\n        \"\"\"\n        if idxes[selected_mask == 1].nelement() != 0:\n            self.net_outputs[idxes[selected_mask == 1]] = pseudo_labels[selected_mask == 1]\n"
  },
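A minimal sketch of one selection step with `DynamicThresholdingModule`, using the linear mapping function (dataset size, class count, and batch size are illustrative):

    >>> import torch
    >>> from tllib.self_training.flexmatch import DynamicThresholdingModule
    >>> module = DynamicThresholdingModule(threshold=0.95, warmup=True, mapping_func=lambda x: x,
    ...                                    num_classes=10, n_unlabeled_samples=1000, device=torch.device('cpu'))
    >>> logits_u = torch.randn(64, 10)                      # predictions on unlabeled samples
    >>> idxes = torch.randint(1000, (64,))                  # dataset indexes of this mini-batch
    >>> confidence, pseudo_labels = logits_u.softmax(dim=1).max(dim=1)
    >>> thresholds = module.get_threshold(pseudo_labels)    # per-sample dynamic thresholds
    >>> mask = (confidence >= thresholds).float()
    >>> module.update(idxes, mask, pseudo_labels)           # record which predictions were confident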
  {
    "path": "tllib/self_training/mcc.py",
    "content": "\"\"\"\n@author: Ying Jin\n@contact: sherryying003@gmail.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\nfrom ..modules.entropy import entropy\n\n\n__all__ = ['MinimumClassConfusionLoss', 'ImageClassifier']\n\n\nclass MinimumClassConfusionLoss(nn.Module):\n    r\"\"\"\n    Minimum Class Confusion loss minimizes the class confusion in the target predictions.\n\n    You can see more details in `Minimum Class Confusion for Versatile Domain Adaptation (ECCV 2020) <https://arxiv.org/abs/1912.03699>`_\n\n    Args:\n        temperature (float) : The temperature for rescaling, the prediction will shrink to vanilla softmax if\n          temperature is 1.0.\n\n    .. note::\n        Make sure that temperature is larger than 0.\n\n    Inputs: g_t\n        - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`\n\n    Shape:\n        - g_t: :math:`(minibatch, C)` where C means the number of classes.\n        - Output: scalar.\n\n    Examples::\n        >>> temperature = 2.0\n        >>> loss = MinimumClassConfusionLoss(temperature)\n        >>> # logits output from target domain\n        >>> g_t = torch.randn(batch_size, num_classes)\n        >>> output = loss(g_t)\n\n    MCC can also serve as a regularizer for existing methods.\n    Examples::\n        >>> from tllib.modules.domain_discriminator import DomainDiscriminator\n        >>> num_classes = 2\n        >>> feature_dim = 1024\n        >>> batch_size = 10\n        >>> temperature = 2.0\n        >>> discriminator = DomainDiscriminator(in_feature=feature_dim, hidden_size=1024)\n        >>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')\n        >>> mcc_loss = MinimumClassConfusionLoss(temperature)\n        >>> # features from source domain and target domain\n        >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)\n        >>> # logits output from source domain adn target domain\n        >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)\n        >>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t)\n    \"\"\"\n\n    def __init__(self, temperature: float):\n        super(MinimumClassConfusionLoss, self).__init__()\n        self.temperature = temperature\n\n    def forward(self, logits: torch.Tensor) -> torch.Tensor:\n        batch_size, num_classes = logits.shape\n        predictions = F.softmax(logits / self.temperature, dim=1)  # batch_size x num_classes\n        entropy_weight = entropy(predictions).detach()\n        entropy_weight = 1 + torch.exp(-entropy_weight)\n        entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1)  # batch_size x 1\n        class_confusion_matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions) # num_classes x num_classes\n        class_confusion_matrix = class_confusion_matrix / torch.sum(class_confusion_matrix, dim=1)\n        mcc_loss = (torch.sum(class_confusion_matrix) - torch.trace(class_confusion_matrix)) / num_classes\n        return mcc_loss\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            
nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n"
  },
  {
    "path": "tllib/self_training/mean_teacher.py",
    "content": "import copy\nfrom typing import Optional\nimport torch\n\n\ndef set_requires_grad(net, requires_grad=False):\n    \"\"\"\n    Set requires_grad=False for all the parameters to avoid unnecessary computations\n    \"\"\"\n    for param in net.parameters():\n        param.requires_grad = requires_grad\n\n\nclass EMATeacher(object):\n    r\"\"\"\n    Exponential moving average model from `Mean teachers are better role models: Weight-averaged consistency targets\n    improve semi-supervised deep learning results (NIPS 2017) <https://arxiv.org/abs/1703.01780>`_\n\n    We use :math:`\\theta_t'` to denote parameters of the teacher model at training step t, use :math:`\\theta_t` to\n    denote parameters of the student model at training step t. Given decay factor :math:`\\alpha`,\n    we update the teacher model in an exponential moving average manner\n\n    .. math::\n        \\theta_t'=\\alpha \\theta_{t-1}' + (1-\\alpha)\\theta_t\n\n    Args:\n        model (torch.nn.Module): the student model\n        alpha (float): decay factor for EMA.\n\n    Inputs:\n        x (tensor): input tensor\n\n    Examples::\n\n        >>> classifier = ImageClassifier(backbone, num_classes=31, bottleneck_dim=256).to(device)\n        >>> # initialize teacher model\n        >>> teacher = EMATeacher(classifier, 0.9)\n        >>> num_iterations = 1000\n        >>> for _ in range(num_iterations):\n        >>>     # x denotes input of one mini-batch\n        >>>     # you can get teacher model's output by teacher(x)\n        >>>     y_teacher = teacher(x)\n        >>>     # when you want to update teacher, you should call teacher.update()\n        >>>     teacher.update()\n    \"\"\"\n\n    def __init__(self, model, alpha):\n        self.model = model\n        self.alpha = alpha\n        self.teacher = copy.deepcopy(model)\n        set_requires_grad(self.teacher, False)\n\n    def set_alpha(self, alpha: float):\n        assert alpha >= 0\n        self.alpha = alpha\n\n    def update(self):\n        for teacher_param, param in zip(self.teacher.parameters(), self.model.parameters()):\n            teacher_param.data = self.alpha * teacher_param + (1 - self.alpha) * param\n\n    def __call__(self, x: torch.Tensor):\n        return self.teacher(x)\n\n    def train(self, mode: Optional[bool] = True):\n        self.teacher.train(mode)\n\n    def eval(self):\n        self.train(False)\n\n    def state_dict(self):\n        return self.teacher.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self.teacher.load_state_dict(state_dict)\n\n    @property\n    def module(self):\n        return self.teacher.module\n\n\ndef update_bn(model, ema_model):\n    \"\"\"\n    Replace batch normalization statistics of the teacher model with that ot the student model\n    \"\"\"\n    for m2, m1 in zip(ema_model.named_modules(), model.named_modules()):\n        if ('bn' in m2[0]) and ('bn' in m1[0]):\n            bn2, bn1 = m2[1].state_dict(), m1[1].state_dict()\n            bn2['running_mean'].data.copy_(bn1['running_mean'].data)\n            bn2['running_var'].data.copy_(bn1['running_var'].data)\n            bn2['num_batches_tracked'].data.copy_(bn1['num_batches_tracked'].data)\n"
  },
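A minimal sketch of pairing `EMATeacher` with `update_bn` (the tiny `Net` below is a hypothetical stand-in; note that `update_bn` matches submodules by name, so the batch-norm attribute must contain 'bn', and `teacher.teacher` is the internal EMA copy):

    >>> import torch
    >>> import torch.nn as nn
    >>> from tllib.self_training.mean_teacher import EMATeacher, update_bn
    >>> class Net(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.fc = nn.Linear(8, 8)
    ...         self.bn = nn.BatchNorm1d(8)   # name contains 'bn' so update_bn can find it
    ...     def forward(self, x):
    ...         return self.bn(self.fc(x))
    >>> student = Net()
    >>> teacher = EMATeacher(student, alpha=0.99)
    >>> _ = student(torch.randn(4, 8))         # a forward pass updates the student's BN statistics
    >>> teacher.update()                       # theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t
    >>> update_bn(student, teacher.teacher)    # copy BN running statistics into the teacher copy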
  {
    "path": "tllib/self_training/pi_model.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom typing import Callable, Optional\nimport numpy as np\nimport torch\nfrom torch import nn as nn\n\n\ndef sigmoid_warm_up(current_epoch, warm_up_epochs: int):\n    \"\"\"Exponential warm up function from `Temporal Ensembling for Semi-Supervised Learning\n    (ICLR 2017) <https://arxiv.org/abs/1610.02242>`_.\n    \"\"\"\n    assert warm_up_epochs >= 0\n    if warm_up_epochs == 0:\n        return 1.0\n    else:\n        current_epoch = np.clip(current_epoch, 0.0, warm_up_epochs)\n        process = 1.0 - current_epoch / warm_up_epochs\n        return float(np.exp(-5.0 * process * process))\n\n\nclass ConsistencyLoss(nn.Module):\n    r\"\"\"\n    Consistency loss between two predictions. Given distance measure :math:`D`, predictions :math:`p_1, p_2`,\n    binary mask :math:`mask`, the consistency loss is\n\n    .. math::\n        D(p_1, p_2) * mask\n\n    Args:\n        distance_measure (callable): Distance measure function.\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Inputs:\n        - p1: the first prediction\n        - p2: the second prediction\n        - mask: binary mask. Default: 1. (use all samples when calculating loss)\n\n    Shape:\n        - p1, p2: :math:`(N, C)` where C means the number of classes.\n        - mask: :math:`(N, )` where N means mini-batch size.\n    \"\"\"\n\n    def __init__(self, distance_measure: Callable, reduction: Optional[str] = 'mean'):\n        super(ConsistencyLoss, self).__init__()\n        self.distance_measure = distance_measure\n        self.reduction = reduction\n\n    def forward(self, p1: torch.Tensor, p2: torch.Tensor, mask=1.):\n        cons_loss = self.distance_measure(p1, p2)\n        cons_loss = cons_loss * mask\n        if self.reduction == 'mean':\n            return cons_loss.mean()\n        elif self.reduction == 'sum':\n            return cons_loss.sum()\n        else:\n            return cons_loss\n\n\nclass L2ConsistencyLoss(ConsistencyLoss):\n    r\"\"\"\n    L2 consistency loss. Given two predictions :math:`p_1, p_2` and binary mask :math:`mask`, the\n    L2 consistency loss is\n\n    .. math::\n        \\text{MSELoss}(p_1, p_2) * mask\n\n    \"\"\"\n\n    def __init__(self, reduction: Optional[str] = 'mean'):\n        def l2_distance(p1: torch.Tensor, p2: torch.Tensor):\n            return ((p1 - p2) ** 2).sum(dim=1)\n\n        super(L2ConsistencyLoss, self).__init__(l2_distance, reduction)\n"
  },
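A minimal sketch of combining `L2ConsistencyLoss` with the `sigmoid_warm_up` ramp between two perturbed predictions (batch size, class count, and epoch numbers are illustrative):

    >>> import torch
    >>> from tllib.self_training.pi_model import L2ConsistencyLoss, sigmoid_warm_up
    >>> consistency = L2ConsistencyLoss()
    >>> p1 = torch.randn(32, 10).softmax(dim=1)    # prediction under one perturbation
    >>> p2 = torch.randn(32, 10).softmax(dim=1)    # prediction under another perturbation
    >>> ramp = sigmoid_warm_up(current_epoch=3, warm_up_epochs=10)   # scalar in [0, 1]
    >>> loss = ramp * consistency(p1, p2)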
  {
    "path": "tllib/self_training/pseudo_label.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\n\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ConfidenceBasedSelfTrainingLoss(nn.Module):\n    \"\"\"\n    Self training loss that adopts confidence threshold to select reliable pseudo labels from\n    `Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks (ICML 2013)\n    <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf>`_.\n\n    Args:\n        threshold (float): Confidence threshold.\n\n    Inputs:\n        - y: unnormalized classifier predictions.\n        - y_target: unnormalized classifier predictions which will used for generating pseudo labels.\n\n    Returns:\n         A tuple, including\n            - self_training_loss: self training loss with pseudo labels.\n            - mask: binary mask that indicates which samples are retained (whose confidence is above the threshold).\n            - pseudo_labels: generated pseudo labels.\n\n    Shape:\n        - y, y_target: :math:`(minibatch, C)` where C means the number of classes.\n        - self_training_loss: scalar.\n        - mask, pseudo_labels :math:`(minibatch, )`.\n\n    \"\"\"\n\n    def __init__(self, threshold: float):\n        super(ConfidenceBasedSelfTrainingLoss, self).__init__()\n        self.threshold = threshold\n\n    def forward(self, y, y_target):\n        confidence, pseudo_labels = F.softmax(y_target.detach(), dim=1).max(dim=1)\n        mask = (confidence > self.threshold).float()\n        self_training_loss = (F.cross_entropy(y, pseudo_labels, reduction='none') * mask).mean()\n\n        return self_training_loss, mask, pseudo_labels\n"
  },
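A minimal sketch of `ConfidenceBasedSelfTrainingLoss` on one mini-batch, in the common weak/strong augmentation setup (threshold, batch size, and class count are illustrative):

    >>> import torch
    >>> from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
    >>> criterion = ConfidenceBasedSelfTrainingLoss(threshold=0.95)
    >>> y = torch.randn(32, 10)           # predictions to be trained, e.g. on strongly augmented samples
    >>> y_target = torch.randn(32, 10)    # predictions used to generate pseudo labels, e.g. on weakly augmented samples
    >>> loss, mask, pseudo_labels = criterion(y, y_target)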
  {
    "path": "tllib/self_training/self_ensemble.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tllib.modules.classifier import Classifier as ClassifierBase\n\n\nclass ClassBalanceLoss(nn.Module):\n    r\"\"\"\n    Class balance loss that penalises the network for making predictions that exhibit large class imbalance.\n    Given predictions :math:`p` with dimension :math:`(N, C)`, we first calculate\n    the mini-batch mean per-class probability :math:`p_{mean}` with dimension :math:`(C, )`, where\n\n    .. math::\n        p_{mean}^j = \\frac{1}{N} \\sum_{i=1}^N p_i^j\n\n    Then we calculate binary cross entropy loss between :math:`p_{mean}` and uniform probability vector :math:`u` with\n    the same dimension where :math:`u^j` = :math:`\\frac{1}{C}`\n\n    .. math::\n        loss = \\text{BCELoss}(p_{mean}, u)\n\n    Args:\n        num_classes (int): Number of classes\n\n    Inputs:\n        - p (tensor): predictions from classifier\n\n    Shape:\n        - p: :math:`(N, C)` where C means the number of classes.\n    \"\"\"\n\n    def __init__(self, num_classes):\n        super(ClassBalanceLoss, self).__init__()\n        self.uniform_distribution = torch.ones(num_classes) / num_classes\n\n    def forward(self, p: torch.Tensor):\n        return F.binary_cross_entropy(p.mean(dim=0), self.uniform_distribution.to(p.device))\n\n\nclass ImageClassifier(ClassifierBase):\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):\n        bottleneck = nn.Sequential(\n            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n            # nn.Flatten(),\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU()\n        )\n        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)\n"
  },
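A minimal sketch of `ClassBalanceLoss` (batch size and class count are illustrative; the input must already be probabilities, e.g. softmax outputs):

    >>> import torch
    >>> from tllib.self_training.self_ensemble import ClassBalanceLoss
    >>> balance = ClassBalanceLoss(num_classes=10)
    >>> p = torch.randn(32, 10).softmax(dim=1)   # class probabilities on a mini-batch
    >>> loss = balance(p)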
  {
    "path": "tllib/self_training/self_tuning.py",
    "content": "\"\"\"\nAdapted from https://github.com/thuml/Self-Tuning/tree/master\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch\nimport torch.nn as nn\nfrom torch.nn.functional import normalize\nfrom tllib.modules.classifier import Classifier as ClassifierBase\n\n\nclass Classifier(ClassifierBase):\n    \"\"\"Classifier class for Self-Tuning.\n\n    Args:\n        backbone (torch.nn.Module): Any backbone to extract 2-d features from data\n        num_classes (int): Number of classes.\n        projection_dim (int, optional): Dimension of the projector head. Default: 128\n        finetune (bool): Whether finetune the classifier or train from scratch. Default: True\n\n    Inputs:\n        - x (tensor): input data fed to `backbone`\n\n    Outputs:\n        In the training mode,\n            - h: projections\n            - y: classifier's predictions\n        In the eval mode,\n            - y: classifier's predictions\n\n    Shape:\n        - Inputs: (minibatch, *) where * means, any number of additional dimensions\n        - y: (minibatch, `num_classes`)\n        - h: (minibatch, `projection_dim`)\n\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=1024, bottleneck_dim=1024, finetune=True,\n                 pool_layer=None):\n        bottleneck = nn.Sequential(\n            nn.Linear(backbone.out_features, bottleneck_dim),\n            nn.BatchNorm1d(bottleneck_dim),\n            nn.ReLU(),\n            nn.Dropout(0.5)\n        )\n        bottleneck[0].weight.data.normal_(0, 0.005)\n        bottleneck[0].bias.data.fill_(0.1)\n        head = nn.Linear(1024, num_classes)\n        super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune,\n                                         pool_layer=pool_layer, bottleneck=bottleneck, bottleneck_dim=bottleneck_dim)\n        self.projector = nn.Linear(1024, projection_dim)\n        self.projection_dim = projection_dim\n\n    def forward(self, x: torch.Tensor):\n        f = self.pool_layer(self.backbone(x))\n        f = self.bottleneck(f)\n        # projections\n        h = self.projector(f)\n        h = normalize(h, dim=1)\n        # predictions\n        predictions = self.head(f)\n        if self.training:\n            return h, predictions\n        else:\n            return predictions\n\n    def get_parameters(self, base_lr=1.0):\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.projector.parameters(), \"lr\": 0.1 * base_lr if self.finetune else 1.0 * base_lr},\n        ]\n\n        return params\n\n\nclass SelfTuning(nn.Module):\n    r\"\"\"Self-Tuning module in `Self-Tuning for Data-Efficient Deep Learning (self-tuning, ICML 2021)\n    <http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf>`_.\n\n    Args:\n        encoder_q (Classifier): Query encoder.\n        encoder_k (Classifier): Key encoder.\n        num_classes (int): Number of classes\n        K (int): Queue size. Default: 32\n        m (float): Momentum coefficient. Default: 0.999\n        T (float): Temperature. 
Default: 0.07\n\n    Inputs:\n        - im_q (tensor): input data fed to `encoder_q`\n        - im_k (tensor): input data fed to `encoder_k`\n        - labels (tensor): classification labels of input data\n\n    Outputs: pgc_logits, pgc_labels, y_q\n        - pgc_logits: projector's predictions on both positive and negative samples\n        - pgc_labels: contrastive labels\n        - y_q: query classifier's predictions\n\n    Shape:\n        - im_q, im_k: (minibatch, *) where * means, any number of additional dimensions\n        - labels: (minibatch, )\n        - y_q: (minibatch, `num_classes`)\n        - pgc_logits: (minibatch, 1 + `num_classes` :math:`\\times` `K`, `projection_dim`)\n        - pgc_labels: (minibatch, 1 + `num_classes` :math:`\\times` `K`)\n    \"\"\"\n\n    def __init__(self, encoder_q, encoder_k, num_classes, K=32, m=0.999, T=0.07):\n        super(SelfTuning, self).__init__()\n        self.K = K\n        self.m = m\n        self.T = T\n        self.num_classes = num_classes\n\n        # create the encoders\n        # num_classes is the output fc dimension\n        self.encoder_q = encoder_q\n        self.encoder_k = encoder_k\n\n        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):\n            param_k.data.copy_(param_q.data)\n            param_k.requires_grad = False\n\n        # create the queue\n        self.register_buffer(\"queue_list\", torch.randn(encoder_q.projection_dim, K * self.num_classes))\n        self.queue_list = normalize(self.queue_list, dim=0)\n        self.register_buffer(\"queue_ptr\", torch.zeros(self.num_classes, dtype=torch.long))\n\n    @torch.no_grad()\n    def _momentum_update_key_encoder(self):\n        \"\"\"\n        Momentum update of the key encoder\n        \"\"\"\n        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):\n            param_k.data = param_k.data * self.m + param_q.data * (1. 
- self.m)\n\n    @torch.no_grad()\n    def _dequeue_and_enqueue(self, h, label):\n        # gather keys before updating queue\n        batch_size = h.shape[0]\n        ptr = int(self.queue_ptr[label])\n        real_ptr = ptr + label * self.K\n        # replace the keys at ptr (dequeue and enqueue)\n        self.queue_list[:, real_ptr:real_ptr + batch_size] = h.T\n\n        # move pointer\n        ptr = (ptr + batch_size) % self.K\n        self.queue_ptr[label] = ptr\n\n    def forward(self, im_q, im_k, labels):\n        batch_size = im_q.size(0)\n        device = im_q.device\n\n        # compute query features\n        h_q, y_q = self.encoder_q(im_q)  # queries: h_q (N x projection_dim)\n\n        # compute key features\n        with torch.no_grad():  # no gradient to keys\n            self._momentum_update_key_encoder()  # update the key encoder\n            h_k, _ = self.encoder_k(im_k)  # keys: h_k (N x projection_dim)\n\n        # compute logits\n        # positive logits: Nx1\n        logits_pos = torch.einsum('nl,nl->n', [h_q, h_k]).unsqueeze(-1)  # Einstein sum is more intuitive\n\n        # cur_queue_list: queue_size * class_num\n        cur_queue_list = self.queue_list.clone().detach()\n\n        logits_neg_list = torch.Tensor([]).to(device)\n        logits_pos_list = torch.Tensor([]).to(device)\n\n        for i in range(batch_size):\n            neg_sample = torch.cat([cur_queue_list[:, 0:labels[i] * self.K],\n                                    cur_queue_list[:, (labels[i] + 1) * self.K:]],\n                                   dim=1)\n            pos_sample = cur_queue_list[:, labels[i] * self.K: (labels[i] + 1) * self.K]\n            ith_neg = torch.einsum('nl,lk->nk', [h_q[i:i + 1], neg_sample])\n            ith_pos = torch.einsum('nl,lk->nk', [h_q[i:i + 1], pos_sample])\n            logits_neg_list = torch.cat((logits_neg_list, ith_neg), dim=0)\n            logits_pos_list = torch.cat((logits_pos_list, ith_pos), dim=0)\n            self._dequeue_and_enqueue(h_k[i:i + 1], labels[i])\n\n        # logits: 1 + queue_size + queue_size * (class_num - 1)\n        pgc_logits = torch.cat([logits_pos, logits_pos_list, logits_neg_list], dim=1)\n        pgc_logits = nn.LogSoftmax(dim=1)(pgc_logits / self.T)\n\n        pgc_labels = torch.zeros([batch_size, 1 + self.K * self.num_classes]).to(device)\n        pgc_labels[:, 0:self.K + 1].fill_(1.0 / (self.K + 1))\n        return pgc_logits, pgc_labels, y_q\n"
  },
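A minimal sketch of driving `SelfTuning` for one step (the `ToyBackbone` below is a hypothetical stand-in exposing `out_features`, as the tllib backbones do; sizes are illustrative, and the default pool layer is assumed to reduce the 4-d backbone output to 2-d):

    >>> import copy
    >>> import torch
    >>> import torch.nn as nn
    >>> from tllib.self_training.self_tuning import Classifier, SelfTuning
    >>> class ToyBackbone(nn.Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.conv = nn.Conv2d(3, 32, 3)
    ...         self.out_features = 32
    ...     def forward(self, x):
    ...         return self.conv(x)
    >>> encoder_q = Classifier(ToyBackbone(), num_classes=10)
    >>> encoder_k = copy.deepcopy(encoder_q)     # key encoder starts as a copy of the query encoder
    >>> self_tuning = SelfTuning(encoder_q, encoder_k, num_classes=10, K=32)
    >>> im_q, im_k = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
    >>> labels = torch.randint(10, (8,))
    >>> pgc_logits, pgc_labels, y_q = self_tuning(im_q, im_k, labels)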
  {
    "path": "tllib/self_training/uda.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\n\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass StrongWeakConsistencyLoss(nn.Module):\n    \"\"\"\n    Consistency loss between strong and weak augmented samples from `Unsupervised Data Augmentation for\n    Consistency Training (NIPS 2020) <https://arxiv.org/pdf/1904.12848v4.pdf>`_.\n\n    Args:\n        threshold (float): Confidence threshold.\n        temperature (float): Temperature.\n\n    Inputs:\n        - y_strong: unnormalized classifier predictions on strong augmented samples.\n        - y: unnormalized classifier predictions on weak augmented samples.\n\n    Shape:\n        - y, y_strong: :math:`(minibatch, C)` where C means the number of classes.\n        - Output: scalar.\n\n    \"\"\"\n\n    def __init__(self, threshold: float, temperature: float):\n        super(StrongWeakConsistencyLoss, self).__init__()\n        self.threshold = threshold\n        self.temperature = temperature\n\n    def forward(self, y_strong, y):\n        confidence, _ = F.softmax(y.detach(), dim=1).max(dim=1)\n        mask = (confidence > self.threshold).float()\n        log_prob = F.log_softmax(y_strong / self.temperature, dim=1)\n        con_loss = (F.kl_div(log_prob, F.softmax(y.detach(), dim=1), reduction='none').sum(dim=1))\n        con_loss = (con_loss * mask).sum() / max(mask.sum(), 1)\n\n        return con_loss\n"
  },
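A minimal sketch of `StrongWeakConsistencyLoss` on one unlabeled mini-batch (the threshold and temperature values, batch size, and class count are illustrative):

    >>> import torch
    >>> from tllib.self_training.uda import StrongWeakConsistencyLoss
    >>> criterion = StrongWeakConsistencyLoss(threshold=0.8, temperature=0.4)
    >>> y = torch.randn(32, 10)          # logits on weakly augmented samples (pseudo-label source)
    >>> y_strong = torch.randn(32, 10)   # logits on strongly augmented samples
    >>> loss = criterion(y_strong, y)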
  {
    "path": "tllib/translation/__init__.py",
    "content": ""
  },
  {
    "path": "tllib/translation/cycada.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\nfrom torch import Tensor\n\n\nclass SemanticConsistency(nn.Module):\n    \"\"\"\n    Semantic consistency loss is introduced by\n    `CyCADA: Cycle-Consistent Adversarial Domain Adaptation (ICML 2018) <https://arxiv.org/abs/1711.03213>`_\n\n    This helps to prevent label flipping during image translation.\n\n    Args:\n        ignore_index (tuple, optional): Specifies target values that are ignored\n            and do not contribute to the input gradient. When :attr:`size_average` is\n            ``True``, the loss is averaged over non-ignored targets. Default: ().\n        reduction (string, optional): Specifies the reduction to apply to the output:\n            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will\n            be applied, ``'mean'``: the weighted mean of the output is taken,\n            ``'sum'``: the output will be summed. Note: :attr:`size_average`\n            and :attr:`reduce` are in the process of being deprecated, and in\n            the meantime, specifying either of those two args will override\n            :attr:`reduction`. Default: ``'mean'``\n\n    Shape:\n        - Input: :math:`(N, C)` where `C = number of classes`, or\n          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \\geq 1`\n          in the case of `K`-dimensional loss.\n        - Target: :math:`(N)` where each value is :math:`0 \\leq \\text{targets}[i] \\leq C-1`, or\n          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \\geq 1` in the case of\n          K-dimensional loss.\n        - Output: scalar.\n          If :attr:`reduction` is ``'none'``, then the same size as the target:\n          :math:`(N)`, or\n          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \\geq 1` in the case\n          of K-dimensional loss.\n\n    Examples::\n\n        >>> loss = SemanticConsistency()\n        >>> input = torch.randn(3, 5, requires_grad=True)\n        >>> target = torch.empty(3, dtype=torch.long).random_(5)\n        >>> output = loss(input, target)\n        >>> output.backward()\n    \"\"\"\n    def __init__(self, ignore_index=(), reduction='mean'):\n        super(SemanticConsistency, self).__init__()\n        self.ignore_index = ignore_index\n        self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)\n\n    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n        for class_idx in self.ignore_index:\n            target[target == class_idx] = -1\n        return self.loss(input, target)\n"
  },
  {
    "path": "tllib/translation/cyclegan/__init__.py",
    "content": "from . import discriminator\nfrom . import generator\nfrom . import loss\nfrom . import transform\n\nfrom .discriminator import *\nfrom .generator import *\nfrom .loss import *\nfrom .transform import *\n"
  },
  {
    "path": "tllib/translation/cyclegan/discriminator.py",
    "content": "\"\"\"\nModified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom .util import get_norm_layer, init_weights\n\n\nclass NLayerDiscriminator(nn.Module):\n    \"\"\"Construct a PatchGAN discriminator\n\n    Args:\n        input_nc (int): the number of channels in input images.\n        ndf (int): the number of filters in the last conv layer. Default: 64\n        n_layers (int): the number of conv layers in the discriminator. Default: 3\n        norm_layer (torch.nn.Module): normalization layer. Default: :class:`nn.BatchNorm2d`\n    \"\"\"\n\n    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):\n        super(NLayerDiscriminator, self).__init__()\n        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        kw = 4\n        padw = 1\n        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, n_layers):  # gradually increase the number of filters\n            nf_mult_prev = nf_mult\n            nf_mult = min(2 ** n, 8)\n            sequence += [\n                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n                norm_layer(ndf * nf_mult),\n                nn.LeakyReLU(0.2, True)\n            ]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2 ** n_layers, 8)\n        sequence += [\n            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n            norm_layer(ndf * nf_mult),\n            nn.LeakyReLU(0.2, True)\n        ]\n\n        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map\n        self.model = nn.Sequential(*sequence)\n\n    def forward(self, input):\n        return self.model(input)\n\n\nclass PixelDiscriminator(nn.Module):\n    \"\"\"Construct a 1x1 PatchGAN discriminator (pixelGAN)\n\n    Args:\n        input_nc (int): the number of channels in input images.\n        ndf (int): the number of filters in the last conv layer. Default: 64\n        norm_layer (torch.nn.Module): normalization layer. 
Default: :class:`nn.BatchNorm2d`\n    \"\"\"\n\n    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):\n        super(PixelDiscriminator, self).__init__()\n        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        self.net = [\n            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),\n            nn.LeakyReLU(0.2, True),\n            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),\n            norm_layer(ndf * 2),\n            nn.LeakyReLU(0.2, True),\n            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]\n\n        self.net = nn.Sequential(*self.net)\n\n    def forward(self, input):\n        return self.net(input)\n\n\ndef patch(ndf, input_nc=3, norm='batch', n_layers=3, init_type='normal', init_gain=0.02):\n    \"\"\"\n    PatchGAN classifier described in the original pix2pix paper.\n    It can classify whether 70×70 overlapping patches are real or fake.\n    Such a patch-level discriminator architecture has fewer parameters\n    than a full-image discriminator and can work on arbitrarily-sized images\n    in a fully convolutional fashion.\n\n    Args:\n        ndf (int): the number of filters in the first conv layer\n        input_nc (int): the number of channels in input images. Default: 3\n        norm (str): the type of normalization layers used in the network. Default: 'batch'\n        n_layers (int): the number of conv layers in the discriminator. Default: 3\n        init_type (str): the name of the initialization method. Choices include: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'\n        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02\n    \"\"\"\n    norm_layer = get_norm_layer(norm_type=norm)\n    net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers, norm_layer=norm_layer)\n    init_weights(net, init_type, init_gain=init_gain)\n    return net\n\n\ndef pixel(ndf, input_nc=3, norm='batch', init_type='normal', init_gain=0.02):\n    \"\"\"\n    The 1x1 PixelGAN discriminator classifies whether a pixel is real or fake.\n    It encourages greater color diversity but has no effect on spatial statistics.\n\n    Args:\n        ndf (int): the number of filters in the first conv layer\n        input_nc (int): the number of channels in input images. Default: 3\n        norm (str): the type of normalization layers used in the network. Default: 'batch'\n        init_type (str): the name of the initialization method. Choices include: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'\n        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02\n    \"\"\"\n    norm_layer = get_norm_layer(norm_type=norm)\n    net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)\n    init_weights(net, init_type, init_gain=init_gain)\n    return net\n"
  },
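A minimal sketch of instantiating the PatchGAN discriminator via the `patch` helper and checking its patch-level output map (batch and image sizes are illustrative):

    >>> import torch
    >>> from tllib.translation.cyclegan.discriminator import patch
    >>> netD = patch(ndf=64)                  # 3-layer 70x70 PatchGAN with batch normalization
    >>> x = torch.randn(4, 3, 256, 256)
    >>> out = netD(x)                         # (4, 1, 30, 30): one real/fake score per patch
    >>> assert out.shape == (4, 1, 30, 30)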
  {
    "path": "tllib/translation/cyclegan/generator.py",
    "content": "\"\"\"\nModified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport functools\nfrom .util import get_norm_layer, init_weights\n\n\nclass ResnetBlock(nn.Module):\n    \"\"\"Define a Resnet block\"\"\"\n\n    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):\n        \"\"\"Initialize the Resnet block\n\n        A resnet block is a conv block with skip connections\n        We construct a conv block with build_conv_block function,\n        and implement skip connections in <forward> function.\n        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf\n        \"\"\"\n        super(ResnetBlock, self).__init__()\n        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)\n\n    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):\n        \"\"\"Construct a convolutional block.\n\n        Args:\n            dim (int): the number of channels in the conv layer.\n            padding_type (str): the name of padding layer: reflect | replicate | zero\n            norm_layer (torch.nn.Module): normalization layer\n            use_dropout (bool): if use dropout layers.\n            use_bias (bool): if the conv layer uses bias or not\n\n        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))\n        \"\"\"\n        conv_block = []\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n\n        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]\n        if use_dropout:\n            conv_block += [nn.Dropout(0.5)]\n\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]\n\n        return nn.Sequential(*conv_block)\n\n    def forward(self, x):\n        \"\"\"Forward function (with skip connections)\"\"\"\n        out = x + self.conv_block(x)  # add skip connections\n        return out\n\n\nclass ResnetGenerator(nn.Module):\n    \"\"\"Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.\n\n    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)\n\n    \"\"\"\n\n    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):\n        \"\"\"Construct a Resnet-based generator\n\n        Args:\n            input_nc (int): the number of channels in input images\n            output_nc (int): the number of channels in output images\n            ngf (int): the number of filters in the last conv layer\n            norm_layer (torch.nn.Module): normalization layer\n   
         use_dropout (bool): if use dropout layers\n            n_blocks (int): the number of ResNet blocks\n            padding_type (str): the name of padding layer in conv layers: reflect | replicate | zero\n        \"\"\"\n        assert(n_blocks >= 0)\n        super(ResnetGenerator, self).__init__()\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        model = [nn.ReflectionPad2d(3),\n                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),\n                 norm_layer(ngf),\n                 nn.ReLU(True)]\n\n        n_downsampling = 2\n        for i in range(n_downsampling):  # add downsampling layers\n            mult = 2 ** i\n            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),\n                      norm_layer(ngf * mult * 2),\n                      nn.ReLU(True)]\n\n        mult = 2 ** n_downsampling\n        for i in range(n_blocks):       # add ResNet blocks\n\n            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]\n\n        for i in range(n_downsampling):  # add upsampling layers\n            mult = 2 ** (n_downsampling - i)\n            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),\n                                         kernel_size=3, stride=2,\n                                         padding=1, output_padding=1,\n                                         bias=use_bias),\n                      norm_layer(int(ngf * mult / 2)),\n                      nn.ReLU(True)]\n        model += [nn.ReflectionPad2d(3)]\n        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]\n        model += [nn.Tanh()]\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input):\n        \"\"\"Standard forward\"\"\"\n        return self.model(input)\n\n\nclass UnetGenerator(nn.Module):\n    \"\"\"Create a Unet-based generator\"\"\"\n\n    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):\n        \"\"\"Construct a Unet generator\n        Args:\n            input_nc (int): the number of channels in input images\n            output_nc (int): the number of channels in output images\n            num_downs (int): the number of downsamplings in UNet. 
For example, if |num_downs| == 7, an\n                image of size 128x128 will become of size 1x1 at the bottleneck\n            ngf (int): the number of filters in the last conv layer\n            norm_layer (torch.nn.Module): normalization layer\n            use_dropout (bool): if use dropout layers\n\n        We construct the U-Net from the innermost layer to the outermost layer.\n        It is a recursive process.\n        \"\"\"\n        super(UnetGenerator, self).__init__()\n        # construct unet structure\n        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer\n        for i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters\n            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)\n        # gradually reduce the number of filters from ngf * 8 to ngf\n        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer\n\n    def forward(self, input):\n        \"\"\"Standard forward\"\"\"\n        return self.model(input)\n\n\nclass UnetSkipConnectionBlock(nn.Module):\n    \"\"\"Defines the Unet submodule with skip connection.\n        X -------------------identity----------------------\n        |-- downsampling -- |submodule| -- upsampling --|\n    \"\"\"\n\n    def __init__(self, outer_nc, inner_nc, input_nc=None,\n                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):\n        \"\"\"Construct a Unet submodule with skip connections.\n\n        Args:\n            outer_nc (int): the number of filters in the outer conv layer\n            inner_nc (int): the number of filters in the inner conv layer\n            input_nc (int): the number of channels in input images/features\n            submodule (UnetSkipConnectionBlock): previously defined submodules\n            outermost (bool): if this module is the outermost module\n            innermost (bool): if this module is the innermost module\n            norm_layer (torch.nn.Module): normalization layer\n            use_dropout (bool): if use dropout layers.\n        \"\"\"\n        super(UnetSkipConnectionBlock, self).__init__()\n        self.outermost = outermost\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n        if input_nc is None:\n            input_nc = outer_nc\n        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,\n                             stride=2, padding=1, bias=use_bias)\n        downrelu = nn.LeakyReLU(0.2, True)\n        downnorm = norm_layer(inner_nc)\n        uprelu = nn.ReLU(True)\n        upnorm = norm_layer(outer_nc)\n\n        if outermost:\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1)\n            down = [downconv]\n            up = 
[uprelu, upconv, nn.Tanh()]\n            model = down + [submodule] + up\n        elif innermost:\n            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1, bias=use_bias)\n            down = [downrelu, downconv]\n            up = [uprelu, upconv, upnorm]\n            model = down + up\n        else:\n            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n                                        kernel_size=4, stride=2,\n                                        padding=1, bias=use_bias)\n            down = [downrelu, downconv, downnorm]\n            up = [uprelu, upconv, upnorm]\n\n            if use_dropout:\n                model = down + [submodule] + up + [nn.Dropout(0.5)]\n            else:\n                model = down + [submodule] + up\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, x):\n        if self.outermost:\n            return self.model(x)\n        else:   # add skip connections\n            return torch.cat([x, self.model(x)], 1)\n\n\ndef resnet_9(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,\n                       init_type='normal', init_gain=0.02):\n    \"\"\"\n    Resnet-based generator with 9 Resnet blocks.\n\n    Args:\n        ngf (int): the number of filters in the last conv layer\n        input_nc (int): the number of channels in input images. Default: 3\n        output_nc (int): the number of channels in output images. Default: 3\n        norm (str): the type of normalization layers used in the network. Default: 'batch'\n        use_dropout (bool): whether to use dropout. Default: False\n        init_type (str): the name of the initialization method. Choices include: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'\n        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02\n    \"\"\"\n    norm_layer = get_norm_layer(norm_type=norm)\n    net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)\n    init_weights(net, init_type, init_gain)\n    return net\n\n\ndef resnet_6(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,\n                       init_type='normal', init_gain=0.02):\n    \"\"\"\n    Resnet-based generator with 6 Resnet blocks.\n\n    Args:\n        ngf (int): the number of filters in the last conv layer\n        input_nc (int): the number of channels in input images. Default: 3\n        output_nc (int): the number of channels in output images. Default: 3\n        norm (str): the type of normalization layers used in the network. Default: 'batch'\n        use_dropout (bool): whether to use dropout. Default: False\n        init_type (str): the name of the initialization method. Choices include: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'\n        init_gain (float): scaling factor for normal, xavier and orthogonal. 
Default: 0.02\n    \"\"\"\n    norm_layer = get_norm_layer(norm_type=norm)\n    net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)\n    init_weights(net, init_type, init_gain)\n    return net\n\n\ndef unet_256(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,\n             init_type='normal', init_gain=0.02):\n    \"\"\"\n    `U-Net <https://arxiv.org/abs/1505.04597>`_ generator for 256x256 input images.\n    The size of the input image should be a multiple of 256.\n\n    Args:\n        ngf (int): the number of filters in the last conv layer\n        input_nc (int): the number of channels in input images. Default: 3\n        output_nc (int): the number of channels in output images. Default: 3\n        norm (str): the type of normalization layers used in the network. Default: 'batch'\n        use_dropout (bool): whether to use dropout. Default: False\n        init_type (str): the name of the initialization method. Choices include: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'\n        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02\n\n    \"\"\"\n    norm_layer = get_norm_layer(norm_type=norm)\n    net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n    init_weights(net, init_type, init_gain)\n    return net\n\n\ndef unet_128(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,\n             init_type='normal', init_gain=0.02):\n    \"\"\"\n    `U-Net <https://arxiv.org/abs/1505.04597>`_ generator for 128x128 input images.\n    The size of the input image should be a multiple of 128.\n\n    Args:\n        ngf (int): the number of filters in the last conv layer\n        input_nc (int): the number of channels in input images. Default: 3\n        output_nc (int): the number of channels in output images. Default: 3\n        norm (str): the type of normalization layers used in the network. Default: 'batch'\n        use_dropout (bool): whether to use dropout. Default: False\n        init_type (str): the name of the initialization method. Choices include: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'\n        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02\n\n    \"\"\"\n    norm_layer = get_norm_layer(norm_type=norm)\n    net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n    init_weights(net, init_type, init_gain)\n    return net\n\n\ndef unet_32(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,\n             init_type='normal', init_gain=0.02):\n    \"\"\"\n    `U-Net <https://arxiv.org/abs/1505.04597>`_ generator for 32x32 input images.\n\n    Args:\n        ngf (int): the number of filters in the last conv layer\n        input_nc (int): the number of channels in input images. Default: 3\n        output_nc (int): the number of channels in output images. Default: 3\n        norm (str): the type of normalization layers used in the network. Default: 'batch'\n        use_dropout (bool): whether to use dropout. Default: False\n        init_type (str): the name of the initialization method. Choices include: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'\n        init_gain (float): scaling factor for normal, xavier and orthogonal. 
Default: 0.02\n\n    \"\"\"\n    norm_layer = get_norm_layer(norm_type=norm)\n    net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n    init_weights(net, init_type, init_gain)\n    return net"
  },
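For orientation, a minimal usage sketch of the generator factories defined above (not part of the repository). The module path `tllib.translation.cyclegan.generator` is an assumption inferred from the neighbouring file paths:

import torch
from tllib.translation.cyclegan.generator import resnet_9  # assumed module path

# 9-block ResNet generator with instance norm, the usual CycleGAN configuration
netG = resnet_9(ngf=64, input_nc=3, output_nc=3, norm='instance')

x = torch.rand(2, 3, 256, 256) * 2 - 1  # fake batch of images normalized to [-1, 1]
with torch.no_grad():
    y = netG(x)
print(y.shape)  # torch.Size([2, 3, 256, 256]); the final Tanh keeps outputs in [-1, 1]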
  {
    "path": "tllib/translation/cyclegan/loss.py",
    "content": "\"\"\"\nModified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\nimport torch\n\n\nclass LeastSquaresGenerativeAdversarialLoss(nn.Module):\n    \"\"\"\n    Loss for `Least Squares Generative Adversarial Network (LSGAN) <https://arxiv.org/abs/1611.04076>`_\n\n    Args:\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Inputs:\n        - prediction (tensor): unnormalized discriminator predictions\n        - real (bool): if the ground truth label is for real images or fake images. Default: true\n\n    .. warning::\n        Do not use sigmoid as the last layer of Discriminator.\n\n    \"\"\"\n    def __init__(self, reduction='mean'):\n        super(LeastSquaresGenerativeAdversarialLoss, self).__init__()\n        self.mse_loss = nn.MSELoss(reduction=reduction)\n\n    def forward(self, prediction, real=True):\n        if real:\n            label = torch.ones_like(prediction)\n        else:\n            label = torch.zeros_like(prediction)\n        return self.mse_loss(prediction, label)\n\n\nclass VanillaGenerativeAdversarialLoss(nn.Module):\n    \"\"\"\n    Loss for `Vanilla Generative Adversarial Network <https://arxiv.org/abs/1406.2661>`_\n\n    Args:\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Inputs:\n        - prediction (tensor): unnormalized discriminator predictions\n        - real (bool): if the ground truth label is for real images or fake images. Default: true\n\n    .. warning::\n        Do not use sigmoid as the last layer of Discriminator.\n\n    \"\"\"\n    def __init__(self, reduction='mean'):\n        super(VanillaGenerativeAdversarialLoss, self).__init__()\n        self.bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)\n\n    def forward(self, prediction, real=True):\n        if real:\n            label = torch.ones_like(prediction)\n        else:\n            label = torch.zeros_like(prediction)\n        return self.bce_loss(prediction, label)\n\n\nclass WassersteinGenerativeAdversarialLoss(nn.Module):\n    \"\"\"\n    Loss for `Wasserstein Generative Adversarial Network <https://arxiv.org/abs/1701.07875>`_\n\n    Args:\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``\n\n    Inputs:\n        - prediction (tensor): unnormalized discriminator predictions\n        - real (bool): if the ground truth label is for real images or fake images. Default: true\n\n    .. 
warning::\n        Do not use sigmoid as the last layer of Discriminator.\n\n    \"\"\"\n    def __init__(self, reduction='mean'):\n        super(WassersteinGenerativeAdversarialLoss, self).__init__()\n        # the Wasserstein critic loss is a plain (negated) mean, so ``reduction`` is\n        # accepted only for interface consistency with the other GAN losses\n\n    def forward(self, prediction, real=True):\n        if real:\n            return -prediction.mean()\n        else:\n            return prediction.mean()"
  },
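A sketch of how these criteria are typically wired into adversarial updates (the toy discriminator and random tensors below are stand-ins, not library code); per the warning above, the discriminator ends without a sigmoid:

import torch
import torch.nn as nn
from tllib.translation.cyclegan.loss import LeastSquaresGenerativeAdversarialLoss

gan_loss = LeastSquaresGenerativeAdversarialLoss()
netD = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)  # toy patch discriminator, no sigmoid

real = torch.rand(2, 3, 64, 64)
fake = torch.rand(2, 3, 64, 64)  # stand-in for generator output
# discriminator step: predictions on real images -> 1, on (detached) fake images -> 0
loss_d = gan_loss(netD(real), real=True) + gan_loss(netD(fake.detach()), real=False)
# generator step: make the discriminator score fake images as real
loss_g = gan_loss(netD(fake), real=True)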
  {
    "path": "tllib/translation/cyclegan/transform.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torchvision.transforms as T\n\nfrom tllib.vision.transforms import Denormalize\n\n\nclass Translation(nn.Module):\n    \"\"\"\n    Image Translation Transform Module\n\n    Args:\n        generator (torch.nn.Module): An image generator, e.g. :meth:`~tllib.translation.cyclegan.resnet_9_generator`\n        device (torch.device): device to put the generator. Default: 'cpu'\n        mean (tuple): the normalized mean for image\n        std (tuple): the normalized std for image\n    Input:\n        - image (PIL.Image): raw image in shape H x W x C\n\n    Output:\n        raw image in shape H x W x 3\n\n    \"\"\"\n    def __init__(self, generator, device=torch.device(\"cpu\"), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):\n        super(Translation, self).__init__()\n        self.generator = generator.to(device)\n        self.device = device\n        self.pre_process = T.Compose([\n            T.ToTensor(),\n            T.Normalize(mean, std)\n        ])\n        self.post_process = T.Compose([\n            Denormalize(mean, std),\n            T.ToPILImage()\n        ])\n\n    def forward(self, image):\n        image = self.pre_process(image.copy())  # C x H x W\n        image = image.to(self.device)\n        generated_image = self.generator(image.unsqueeze(dim=0)).squeeze(dim=0).cpu()\n        return self.post_process(generated_image)\n"
  },
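A usage sketch for Translation (not library code); the generator import path is assumed, and an untrained network stands in for trained CycleGAN weights:

import torch
from PIL import Image
from tllib.translation.cyclegan.transform import Translation
from tllib.translation.cyclegan.generator import resnet_9  # assumed module path

netG = resnet_9(ngf=64)  # in practice, load trained source-to-target weights here
translate = Translation(netG, device=torch.device('cpu'))

image = Image.new('RGB', (256, 256))  # stand-in for a real source-domain image
translated = translate(image)         # a PIL.Image in the target style, same size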
  {
    "path": "tllib/translation/cyclegan/util.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\nimport functools\nimport random\nimport torch\nfrom torch.nn import init\n\n\nclass Identity(nn.Module):\n    def forward(self, x):\n        return x\n\n\ndef get_norm_layer(norm_type='instance'):\n    \"\"\"Return a normalization layer\n\n    Parameters:\n        norm_type (str) -- the name of the normalization layer: batch | instance | none\n\n    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).\n    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.\n    \"\"\"\n    if norm_type == 'batch':\n        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)\n    elif norm_type == 'instance':\n        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)\n    elif norm_type == 'none':\n        def norm_layer(x): return Identity()\n    else:\n        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)\n    return norm_layer\n\n\ndef init_weights(net, init_type='normal', init_gain=0.02):\n    \"\"\"Initialize network weights.\n\n    Args:\n        net (torch.nn.Module): network to be initialized\n        init_type (str): the name of an initialization method. Choices includes: ``normal`` |\n            ``xavier`` | ``kaiming`` | ``orthogonal``\n        init_gain (float): scaling factor for normal, xavier and orthogonal.\n\n    'normal' is used in the original CycleGAN paper. But xavier and kaiming might\n    work better for some applications.\n    \"\"\"\n    def init_func(m):  # define the initialization function\n        classname = m.__class__.__name__\n        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n            if init_type == 'normal':\n                init.normal_(m.weight.data, 0.0, init_gain)\n            elif init_type == 'xavier':\n                init.xavier_normal_(m.weight.data, gain=init_gain)\n            elif init_type == 'kaiming':\n                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n            elif init_type == 'orthogonal':\n                init.orthogonal_(m.weight.data, gain=init_gain)\n            else:\n                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n            if hasattr(m, 'bias') and m.bias is not None:\n                init.constant_(m.bias.data, 0.0)\n        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.\n            init.normal_(m.weight.data, 1.0, init_gain)\n            init.constant_(m.bias.data, 0.0)\n\n    print('initialize network with %s' % init_type)\n    net.apply(init_func)  # apply the initialization function <init_func>\n\n\nclass ImagePool:\n    \"\"\"An image buffer that stores previously generated images.\n\n    This buffer enables us to update discriminators using a history of generated images\n    rather than the ones produced by the latest generators.\n\n    Args:\n        pool_size (int): the size of image buffer, if pool_size=0, no buffer will be created\n\n    \"\"\"\n\n    def __init__(self, pool_size):\n        self.pool_size = pool_size\n        if self.pool_size > 0:  # create an empty pool\n            self.num_imgs = 0\n            self.images = []\n\n    def query(self, images):\n        \"\"\"Return an image from the pool.\n\n      
  Args:\n            images (torch.Tensor): the latest generated images from the generator\n\n        Returns:\n            By 50/100, the buffer will return input images.\n            By 50/100, the buffer will return images previously stored in the buffer,\n            and insert the current images to the buffer.\n\n        \"\"\"\n        if self.pool_size == 0:  # if the buffer size is 0, do nothing\n            return images\n        return_images = []\n        for image in images:\n            image = torch.unsqueeze(image.data, 0)\n            if self.num_imgs < self.pool_size:   # if the buffer is not full; keep inserting current images to the buffer\n                self.num_imgs = self.num_imgs + 1\n                self.images.append(image)\n                return_images.append(image)\n            else:\n                p = random.uniform(0, 1)\n                if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer\n                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive\n                    tmp = self.images[random_id].clone()\n                    self.images[random_id] = image\n                    return_images.append(tmp)\n                else:       # by another 50% chance, the buffer will return the current image\n                    return_images.append(image)\n        return_images = torch.cat(return_images, 0)   # collect all the images and return\n        return return_images\n\n\ndef set_requires_grad(net, requires_grad=False):\n    \"\"\"\n    Set requies_grad=Fasle for all the networks to avoid unnecessary computations\n    \"\"\"\n    for param in net.parameters():\n        param.requires_grad = requires_grad\n"
  },
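A sketch (not library code) of the two helpers most often used together during CycleGAN training; shapes are illustrative:

import torch
from tllib.translation.cyclegan.util import ImagePool, get_norm_layer

pool = ImagePool(pool_size=50)
for step in range(3):
    fake = torch.rand(4, 3, 64, 64)  # pretend these are the latest generator outputs
    mixed = pool.query(fake)         # same shape; entries may be replayed from the buffer
    assert mixed.shape == fake.shape

norm_layer = get_norm_layer('instance')  # a functools.partial around nn.InstanceNorm2d
print(norm_layer(64))                    # instantiates InstanceNorm2d(64, ...)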
  {
    "path": "tllib/translation/fourier_transform.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport numpy as np\nimport os\nimport tqdm\nimport random\nfrom PIL import Image\nfrom typing import Optional, Sequence\nimport torch.nn as nn\n\n\ndef low_freq_mutate(amp_src: np.ndarray, amp_trg: np.ndarray, beta: Optional[int] = 1):\n    \"\"\"\n    Args:\n        amp_src (numpy.ndarray): amplitude component of the Fourier transform of source image\n        amp_trg (numpy.ndarray): amplitude component of the Fourier transform of target image\n        beta (int, optional): the size of the center region to be replace. Default: 1\n\n    Returns:\n        amplitude component of the Fourier transform of source image\n        whose low-frequency component is replaced by that of the target image.\n\n    \"\"\"\n    # Shift the zero-frequency component to the center of the spectrum.\n    a_src = np.fft.fftshift(amp_src, axes=(-2, -1))\n    a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))\n\n    # The low-frequency component includes\n    # the area where the horizontal and vertical distance from the center does not exceed beta\n    _, h, w = a_src.shape\n    c_h = np.floor(h / 2.0).astype(int)\n    c_w = np.floor(w / 2.0).astype(int)\n\n    h1 = c_h - beta\n    h2 = c_h + beta + 1\n    w1 = c_w - beta\n    w2 = c_w + beta + 1\n\n    # The low-frequency component of source amplitude is replaced by the target amplitude\n    a_src[:, h1:h2, w1:w2] = a_trg[:, h1:h2, w1:w2]\n    a_src = np.fft.ifftshift(a_src, axes=(-2, -1))\n    return a_src\n\n\nclass FourierTransform(nn.Module):\n    \"\"\"\n    Fourier Transform is introduced by `FDA: Fourier Domain Adaptation for Semantic Segmentation (CVPR 2020) <https://arxiv.org/abs/2004.05498>`_\n\n    Fourier Transform replace the low frequency component of the amplitude of the source image to that of the target image.\n    Denote with :math:`M_{β}` a mask, whose value is zero except for the center region:\n\n    .. math::\n        M_{β}(h,w) = \\mathbb{1}_{(h, w)\\in [-β,β, -β, β]}\n\n    Given images :math:`x^s` from source domain and :math:`x^t` from target domain, the source image in the target style is\n\n    .. math::\n        x^{s→t} = \\mathcal{F}^{-1}([ M_{β}\\circ\\mathcal{F}^A(x^t) + (1-M_{β})\\circ\\mathcal{F}^A(x^s), \\mathcal{F}^P(x^s) ])\n\n    where :math:`\\mathcal{F}^A`, :math:`\\mathcal{F}^P` are the amplitude and phase component of the Fourier\n    Transform :math:`\\mathcal{F}` of an RGB image.\n\n    Args:\n        image_list (sequence[str]): A sequence of image list from the target domain.\n        amplitude_dir (str): Specifies the directory to put the amplitude component of the target image.\n        beta (int, optional): :math:`β`. Default: 1.\n        rebuild (bool, optional): whether rebuild the amplitude component of the target image in the given directory.\n\n    Inputs:\n        - image (PIL Image): image from the source domain, :math:`x^t`.\n\n    Examples:\n\n        >>> from tllib.translation.fourier_transform import FourierTransform\n        >>> image_list = [\"target_image_path1\", \"target_image_path2\"]\n        >>> amplitude_dir = \"path/to/amplitude_dir\"\n        >>> fourier_transform = FourierTransform(image_list, amplitude_dir, beta=1, rebuild=False)\n        >>> source_image = np.array((256, 256, 3)) # image form source domain\n        >>> source_image_in_target_style = fourier_transform(source_image)\n\n    .. note::\n        The meaning of :math:`β` is different from that of the origin paper. 
Experimentally, we found that the size of\n        the center region in the frequency space should be constant when the image size increases. Thus we make the size\n        of the center region independent of the image size. A recommended value for :math:`β` is 1.\n\n    .. note::\n        The image structure of the source domain and target domain should be as similar as possible,\n        thus for segemntation tasks, FourierTransform should be used before RandomResizeCrop and other transformations.\n\n    .. note::\n        The image size of the source domain and the target domain need to be the same, thus before FourierTransform,\n        you should use Resize to convert the source image to the target image size.\n\n    Examples:\n\n        >>> from tllib.translation.fourier_transform import FourierTransform\n        >>> import tllibvision.datasets.segmentation.transforms as T\n        >>> from PIL import Image\n        >>> target_image_list = [\"target_image_path1\", \"target_image_path2\"]\n        >>> amplitude_dir = \"path/to/amplitude_dir\"\n        >>> # build a fourier transform that translate source images to the target style\n        >>> fourier_transform = T.wrapper(FourierTransform)(target_image_list, amplitude_dir)\n        >>> transforms=T.Compose([\n        ...     # convert source image to the size of the target image before fourier transform\n        ...     T.Resize((2048, 1024)),\n        ...     fourier_transform,\n        ...     T.RandomResizedCrop((1024, 512)),\n        ...     T.RandomHorizontalFlip(),\n        ... ])\n        >>> source_image = Image.open(\"path/to/source_image\") # image form source domain\n        >>> source_image_in_target_style = transforms(source_image)\n    \"\"\"\n    # TODO add image examples when beta is different\n    def __init__(self, image_list: Sequence[str], amplitude_dir: str,\n                 beta: Optional[int] = 1, rebuild: Optional[bool] = False):\n        super(FourierTransform, self).__init__()\n        self.amplitude_dir = amplitude_dir\n        if not os.path.exists(amplitude_dir) or rebuild:\n            os.makedirs(amplitude_dir, exist_ok=True)\n            self.build_amplitude(image_list, amplitude_dir)\n        self.beta = beta\n        self.length = len(image_list)\n\n    @staticmethod\n    def build_amplitude(image_list, amplitude_dir):\n        # extract amplitudes from target domain\n        for i, image_name in enumerate(tqdm.tqdm(image_list)):\n            image = Image.open(image_name).convert('RGB')\n            image = np.asarray(image, np.float32)\n            image = image.transpose((2, 0, 1))\n            fft = np.fft.fft2(image, axes=(-2, -1))\n            amp = np.abs(fft)\n            np.save(os.path.join(amplitude_dir, \"{}.npy\".format(i)), amp)\n\n    def forward(self, image):\n        # randomly sample a target image and load its amplitude component\n        amp_trg = np.load(os.path.join(self.amplitude_dir, \"{}.npy\".format(random.randint(0, self.length-1))))\n\n        image = np.asarray(image, np.float32)\n        image = image.transpose((2, 0, 1))\n\n        # get fft, amplitude on source domain\n        fft_src = np.fft.fft2(image, axes=(-2, -1))\n        amp_src, pha_src = np.abs(fft_src), np.angle(fft_src)\n        # mutate the amplitude part of source with target\n        amp_src_ = low_freq_mutate(amp_src, amp_trg, beta=self.beta)\n\n        # mutated fft of source\n        fft_src_ = amp_src_ * np.exp(1j * pha_src)\n\n        # get the mutated image\n        src_in_trg = 
np.fft.ifft2(fft_src_, axes=(-2, -1))\n        src_in_trg = np.real(src_in_trg)\n\n        src_in_trg = src_in_trg.transpose((1, 2, 0))\n        src_in_trg = Image.fromarray(src_in_trg.clip(min=0, max=255).astype('uint8')).convert('RGB')\n\n        return src_in_trg\n"
  },
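A small numeric check of low_freq_mutate (not library code) that makes the geometry concrete: with beta=1, exactly the centered (2*beta+1) x (2*beta+1) band of the shifted spectrum is swapped:

import numpy as np
from tllib.translation.fourier_transform import low_freq_mutate

amp_src = np.ones((3, 8, 8), dtype=np.float32)
amp_trg = np.full((3, 8, 8), 5.0, dtype=np.float32)

out = low_freq_mutate(amp_src, amp_trg, beta=1)
# undo the internal ifftshift to inspect the centered spectrum
shifted = np.fft.fftshift(out, axes=(-2, -1))
print(shifted[0, 3:6, 3:6])    # the 3x3 center now carries the target amplitude (5.0)
print((shifted == 5.0).sum())  # 3 channels * 9 entries = 27 swapped values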
  {
    "path": "tllib/translation/spgan/__init__.py",
    "content": "from . import siamese\nfrom . import loss\nfrom .siamese import *\nfrom .loss import *\n"
  },
  {
    "path": "tllib/translation/spgan/loss.py",
    "content": "\"\"\"\nModified from https://github.com/Simon4Yan/eSPGAN\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\n\nclass ContrastiveLoss(torch.nn.Module):\n    r\"\"\"Contrastive loss from `Dimensionality Reduction by Learning an Invariant Mapping (CVPR 2006)\n    <http://www.cs.toronto.edu/~hinton/csc2535/readings/hadsell-chopra-lecun-06-1.pdf>`_.\n\n    Given output features :math:`f_1, f_2`, we use :math:`D` to denote the pairwise euclidean distance between them,\n    :math:`Y` to denote the ground truth labels, :math:`m` to denote a pre-defined margin, then contrastive loss is\n    calculated as\n\n    .. math::\n        (1 - Y)\\frac{1}{2}D^2 + (Y)\\frac{1}{2}\\{\\text{max}(0, m-D)^2\\}\n\n    Args:\n        margin (float, optional): margin for contrastive loss. Default: 2.0\n\n    Inputs:\n        - output1 (tensor): feature representations of the first set of samples (:math:`f_1` here).\n        - output2 (tensor): feature representations of the second set of samples (:math:`f_2` here).\n        - label (tensor): labels (:math:`Y` here).\n\n    Shape:\n        - output1, output2: :math:`(minibatch, F)` where F means the dimension of input features.\n        - label: :math:`(minibatch, )`\n    \"\"\"\n    def __init__(self, margin=2.0):\n        super(ContrastiveLoss, self).__init__()\n        self.margin = margin\n\n    def forward(self, output1, output2, label):\n        euclidean_distance = F.pairwise_distance(output1, output2)\n        loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +\n                          label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))\n\n        return loss\n"
  },
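A usage sketch for ContrastiveLoss (not library code); following the formula above, label 0 marks a similar pair (distance pulled toward zero) and label 1 a dissimilar pair (distance pushed beyond the margin):

import torch
from tllib.translation.spgan.loss import ContrastiveLoss

criterion = ContrastiveLoss(margin=2.0)
f1 = torch.randn(8, 64)
f2 = torch.randn(8, 64)
label = torch.randint(0, 2, (8,)).float()  # 0: similar pair, 1: dissimilar pair
loss = criterion(f1, f2, label)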
  {
    "path": "tllib/translation/spgan/siamese.py",
    "content": "\"\"\"\nModified from https://github.com/Simon4Yan/eSPGAN\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ConvBlock(nn.Module):\n    \"\"\"Basic block with structure Conv-LeakyReLU->Pool\"\"\"\n    def __init__(self, in_dim, out_dim):\n        super(ConvBlock, self).__init__()\n        self.conv_block = nn.Sequential(\n            nn.Conv2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1),\n            nn.LeakyReLU(0.2),\n            nn.MaxPool2d(kernel_size=2, stride=2)\n        )\n\n    def forward(self, x):\n        return self.conv_block(x)\n\n\nclass SiameseNetwork(nn.Module):\n    \"\"\"Siamese network whose input is an image of shape :math:`(3,H,W)` and output is an one-dimensional feature vector.\n\n    Args:\n        nsf (int): dimension of output feature representation.\n    \"\"\"\n    def __init__(self, nsf=64):\n        super(SiameseNetwork, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(3, nsf, kernel_size=4, stride=2, padding=1),\n            nn.LeakyReLU(0.2),\n            nn.MaxPool2d(kernel_size=2, stride=2),\n            ConvBlock(nsf, nsf * 2),\n            ConvBlock(nsf * 2, nsf * 4),\n        )\n        self.flatten = nn.Flatten()\n        self.fc1 = nn.Linear(2048, nsf * 2, bias=False)\n        self.leaky_relu = nn.LeakyReLU(0.2)\n        self.dropout = nn.Dropout(0.5)\n        self.fc2 = nn.Linear(nsf * 2, nsf, bias=False)\n\n    def forward(self, x):\n        x = self.flatten(self.conv(x))\n        x = self.fc1(x)\n        x = self.leaky_relu(x)\n        x = self.dropout(x)\n        x = self.fc2(x)\n        x = F.normalize(x)\n        return x\n"
  },
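A shape walk-through for SiameseNetwork (not library code). The hard-coded nn.Linear(2048, ...) pins the input resolution: three stride-2 convolutions, each followed by a 2x2 max-pool, give a /64 downsampling, so a 256x128 crop (a common person re-id size, assumed here) flattens to 64*4 channels * 4 * 2 = 2048 features:

import torch
from tllib.translation.spgan.siamese import SiameseNetwork

net = SiameseNetwork(nsf=64)
x = torch.rand(2, 3, 256, 128)
f = net(x)
print(f.shape)        # torch.Size([2, 64])
print(f.norm(dim=1))  # ~1.0 per row, due to the final F.normalize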
  {
    "path": "tllib/utils/__init__.py",
    "content": "from .logger import CompleteLogger\nfrom .meter import *\nfrom .data import ForeverDataIterator\n\n__all__ = ['metric', 'analysis', 'meter', 'data', 'logger']"
  },
  {
    "path": "tllib/utils/analysis/__init__.py",
    "content": "import torch\nfrom torch.utils.data import DataLoader\nimport torch.nn as nn\nimport tqdm\n\n\ndef collect_feature(data_loader: DataLoader, feature_extractor: nn.Module,\n                    device: torch.device, max_num_features=None) -> torch.Tensor:\n    \"\"\"\n    Fetch data from `data_loader`, and then use `feature_extractor` to collect features\n\n    Args:\n        data_loader (torch.utils.data.DataLoader): Data loader.\n        feature_extractor (torch.nn.Module): A feature extractor.\n        device (torch.device)\n        max_num_features (int): The max number of features to return\n\n    Returns:\n        Features in shape (min(len(data_loader), max_num_features * mini-batch size), :math:`|\\mathcal{F}|`).\n    \"\"\"\n    feature_extractor.eval()\n    all_features = []\n    with torch.no_grad():\n        for i, data in enumerate(tqdm.tqdm(data_loader)):\n            if max_num_features is not None and i >= max_num_features:\n                break\n            inputs = data[0].to(device)\n            feature = feature_extractor(inputs).cpu()\n            all_features.append(feature)\n    return torch.cat(all_features, dim=0)\n"
  },
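A self-contained sketch of collect_feature (not library code) with a toy extractor and dataset; max_num_features caps the number of mini-batches, as documented above:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tllib.utils.analysis import collect_feature

extractor = nn.Sequential(nn.Flatten(), nn.Linear(32, 16))  # stand-in for a trained backbone
loader = DataLoader(TensorDataset(torch.rand(100, 32), torch.zeros(100)), batch_size=10)

features = collect_feature(loader, extractor, device=torch.device('cpu'), max_num_features=5)
print(features.shape)  # torch.Size([50, 16]): 5 mini-batches of 10 samples each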
  {
    "path": "tllib/utils/analysis/a_distance.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom torch.utils.data import TensorDataset\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader\nfrom torch.optim import SGD\nfrom ..meter import AverageMeter\nfrom ..metric import binary_accuracy\n\n\nclass ANet(nn.Module):\n    def __init__(self, in_feature):\n        super(ANet, self).__init__()\n        self.layer = nn.Linear(in_feature, 1)\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, x):\n        x = self.layer(x)\n        x = self.sigmoid(x)\n        return x\n\n\ndef calculate(source_feature: torch.Tensor, target_feature: torch.Tensor,\n              device, progress=True, training_epochs=10):\n    \"\"\"\n    Calculate the :math:`\\mathcal{A}`-distance, which is a measure for distribution discrepancy.\n\n    The definition is :math:`dist_\\mathcal{A} = 2 (1-2\\epsilon)`, where :math:`\\epsilon` is the\n    test error of a classifier trained to discriminate the source from the target.\n\n    Args:\n        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`\n        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`\n        device (torch.device)\n        progress (bool): if True, displays a the progress of training A-Net\n        training_epochs (int): the number of epochs when training the classifier\n\n    Returns:\n        :math:`\\mathcal{A}`-distance\n    \"\"\"\n    source_label = torch.ones((source_feature.shape[0], 1))\n    target_label = torch.zeros((target_feature.shape[0], 1))\n    feature = torch.cat([source_feature, target_feature], dim=0)\n    label = torch.cat([source_label, target_label], dim=0)\n\n    dataset = TensorDataset(feature, label)\n    length = len(dataset)\n    train_size = int(0.8 * length)\n    val_size = length - train_size\n    train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])\n    train_loader = DataLoader(train_set, batch_size=2, shuffle=True)\n    val_loader = DataLoader(val_set, batch_size=8, shuffle=False)\n\n    anet = ANet(feature.shape[1]).to(device)\n    optimizer = SGD(anet.parameters(), lr=0.01)\n    a_distance = 2.0\n    for epoch in range(training_epochs):\n        anet.train()\n        for (x, label) in train_loader:\n            x = x.to(device)\n            label = label.to(device)\n            anet.zero_grad()\n            y = anet(x)\n            loss = F.binary_cross_entropy(y, label)\n            loss.backward()\n            optimizer.step()\n\n        anet.eval()\n        meter = AverageMeter(\"accuracy\", \":4.2f\")\n        with torch.no_grad():\n            for (x, label) in val_loader:\n                x = x.to(device)\n                label = label.to(device)\n                y = anet(x)\n                acc = binary_accuracy(y, label)\n                meter.update(acc, x.shape[0])\n        error = 1 - meter.avg / 100\n        a_distance = 2 * (1 - 2 * error)\n        if progress:\n            print(\"epoch {} accuracy: {} A-dist: {}\".format(epoch, meter.avg, a_distance))\n\n    return a_distance\n\n"
  },
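A sketch of calculate on synthetic features (not library code). When source and target features come from the same distribution, the A-Net's validation error approaches 0.5, so the A-distance should come out near 0 (it can dip slightly below); fully separable features push it toward 2:

import torch
from tllib.utils.analysis.a_distance import calculate

source_feature = torch.randn(200, 16)
target_feature = torch.randn(200, 16)  # same distribution as the source
a_dist = calculate(source_feature, target_feature, device=torch.device('cpu'), progress=False)
print(a_dist)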
  {
    "path": "tllib/utils/analysis/tsne.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch\nimport matplotlib\n\nmatplotlib.use('Agg')\nfrom sklearn.manifold import TSNE\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.colors as col\n\n\ndef visualize(source_feature: torch.Tensor, target_feature: torch.Tensor,\n              filename: str, source_color='r', target_color='b'):\n    \"\"\"\n    Visualize features from different domains using t-SNE.\n\n    Args:\n        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`\n        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`\n        filename (str): the file name to save t-SNE\n        source_color (str): the color of the source features. Default: 'r'\n        target_color (str): the color of the target features. Default: 'b'\n\n    \"\"\"\n    source_feature = source_feature.numpy()\n    target_feature = target_feature.numpy()\n    features = np.concatenate([source_feature, target_feature], axis=0)\n\n    # map features to 2-d using TSNE\n    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)\n\n    # domain labels, 1 represents source while 0 represents target\n    domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature))))\n\n    # visualize using matplotlib\n    fig, ax = plt.subplots(figsize=(10, 10))\n    ax.spines['top'].set_visible(False)\n    ax.spines['right'].set_visible(False)\n    ax.spines['bottom'].set_visible(False)\n    ax.spines['left'].set_visible(False)\n    plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap([target_color, source_color]), s=20)\n    plt.xticks([])\n    plt.yticks([])\n    plt.savefig(filename)\n"
  },
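A usage sketch for visualize (not library code); the shifted target features are synthetic, only to make the two point clouds separate in the embedding:

import torch
from tllib.utils.analysis.tsne import visualize

source_feature = torch.randn(100, 32)
target_feature = torch.randn(100, 32) + 1.0  # shifted so the domains separate
visualize(source_feature, target_feature, 'tsne.png')  # red = source, blue = target (defaults)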
  {
    "path": "tllib/utils/data.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport itertools\nimport random\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import Sampler\nfrom torch.utils.data import DataLoader, Dataset\nfrom typing import TypeVar, Iterable, Dict, List\n\nT_co = TypeVar('T_co', covariant=True)\nT = TypeVar('T')\n\n\ndef send_to_device(tensor, device):\n    \"\"\"\n    Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.\n\n    Args:\n        tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):\n            The data to send to a given device.\n        device (:obj:`torch.device`):\n            The device to send the data to\n\n    Returns:\n        The same data structure as :obj:`tensor` with all tensors sent to the proper device.\n    \"\"\"\n    if isinstance(tensor, (list, tuple)):\n        return type(tensor)(send_to_device(t, device) for t in tensor)\n    elif isinstance(tensor, dict):\n        return type(tensor)({k: send_to_device(v, device) for k, v in tensor.items()})\n    elif not hasattr(tensor, \"to\"):\n        return tensor\n    return tensor.to(device)\n\n\nclass ForeverDataIterator:\n    r\"\"\"A data iterator that will never stop producing data\"\"\"\n\n    def __init__(self, data_loader: DataLoader, device=None):\n        self.data_loader = data_loader\n        self.iter = iter(self.data_loader)\n        self.device = device\n\n    def __next__(self):\n        try:\n            data = next(self.iter)\n            if self.device is not None:\n                data = send_to_device(data, self.device)\n        except StopIteration:\n            self.iter = iter(self.data_loader)\n            data = next(self.iter)\n            if self.device is not None:\n                data = send_to_device(data, self.device)\n        return data\n\n    def __len__(self):\n        return len(self.data_loader)\n\n\nclass RandomMultipleGallerySampler(Sampler):\n    r\"\"\"Sampler from `In defense of the Triplet Loss for Person Re-Identification\n    (ICCV 2017) <https://arxiv.org/pdf/1703.07737v2.pdf>`_. Assume there are :math:`N` identities in the dataset, this\n    implementation simply samples :math:`K` images for every identity to form an iter of size :math:`N\\times K`. 
During\n    training, we call the ``__iter__`` method of the pytorch dataloader once we reach a ``StopIteration``. This guarantees that\n    every image in the dataset will eventually be selected and no training data is wasted.\n\n    Args:\n        dataset (list): each element of this list is a tuple (image_path, person_id, camera_id)\n        num_instances (int, optional): number of images to sample for every identity (:math:`K` here)\n    \"\"\"\n\n    def __init__(self, dataset, num_instances=4):\n        super(RandomMultipleGallerySampler, self).__init__(dataset)\n        self.dataset = dataset\n        self.num_instances = num_instances\n\n        self.idx_to_pid = {}\n        self.cid_list_per_pid = {}\n        self.idx_list_per_pid = {}\n\n        for idx, (_, pid, cid) in enumerate(dataset):\n            if pid not in self.cid_list_per_pid:\n                self.cid_list_per_pid[pid] = []\n                self.idx_list_per_pid[pid] = []\n\n            self.idx_to_pid[idx] = pid\n            self.cid_list_per_pid[pid].append(cid)\n            self.idx_list_per_pid[pid].append(idx)\n\n        self.pid_list = list(self.idx_list_per_pid.keys())\n        self.num_samples = len(self.pid_list)\n\n    def __len__(self):\n        return self.num_samples * self.num_instances\n\n    def __iter__(self):\n        def select_idxes(element_list, target_element):\n            assert isinstance(element_list, list)\n            return [i for i, element in enumerate(element_list) if element != target_element]\n\n        pid_idxes = torch.randperm(len(self.pid_list)).tolist()\n        final_idxes = []\n\n        for perm_id in pid_idxes:\n            i = random.choice(self.idx_list_per_pid[self.pid_list[perm_id]])\n            _, _, cid = self.dataset[i]\n\n            final_idxes.append(i)\n\n            pid_i = self.idx_to_pid[i]\n            cid_list = self.cid_list_per_pid[pid_i]\n            idx_list = self.idx_list_per_pid[pid_i]\n            selected_cid_list = select_idxes(cid_list, cid)\n\n            if selected_cid_list:\n                if len(selected_cid_list) >= self.num_instances:\n                    cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=False)\n                else:\n                    cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=True)\n                for cid_idx in cid_idxes:\n                    final_idxes.append(idx_list[cid_idx])\n            else:\n                selected_idxes = select_idxes(idx_list, i)\n                if not selected_idxes:\n                    continue\n                # use a fresh name here so the outer pid_idxes permutation is not shadowed\n                if len(selected_idxes) >= self.num_instances:\n                    sampled_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=False)\n                else:\n                    sampled_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=True)\n\n                for pid_idx in sampled_idxes:\n                    final_idxes.append(idx_list[pid_idx])\n\n        return iter(final_idxes)\n\n\nclass CombineDataset(Dataset[T_co]):\n    r\"\"\"Dataset as a combination of multiple datasets.\n    The element of each dataset must be a list, and the i-th element of the combined dataset\n    is a list concatenation of the i-th element of each sub dataset.\n    The length of the combined dataset is the minimum of the lengths of all sub datasets.\n\n    Args:\n        datasets (sequence): List of datasets to be concatenated\n    \"\"\"\n\n    def __init__(self, datasets: 
Iterable[Dataset]) -> None:\n        super(CombineDataset, self).__init__()\n        # Cannot verify that datasets is Sized\n        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore\n        self.datasets = list(datasets)\n\n    def __len__(self):\n        return min([len(d) for d in self.datasets])\n\n    def __getitem__(self, idx):\n        return list(itertools.chain(*[d[idx] for d in self.datasets]))\n\n\ndef concatenate(tensors):\n    \"\"\"concatenate multiple batches into one batch.\n    ``tensors`` can be :class:`torch.Tensor`, List or Dict, but they must be the same data format.\n    \"\"\"\n    if isinstance(tensors[0], torch.Tensor):\n        return torch.cat(tensors, dim=0)\n    elif isinstance(tensors[0], List):\n        ret = []\n        for i in range(len(tensors[0])):\n            ret.append(concatenate([t[i] for t in tensors]))\n        return ret\n    elif isinstance(tensors[0], Dict):\n        ret = dict()\n        for k in tensors[0].keys():\n            ret[k] = concatenate([t[k] for t in tensors])\n        return ret\n"
  },
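A sketch of ForeverDataIterator (not library code), whose point is that next() never raises StopIteration; this is convenient when source and target loaders have different lengths:

import torch
from torch.utils.data import DataLoader, TensorDataset
from tllib.utils.data import ForeverDataIterator

loader = DataLoader(TensorDataset(torch.arange(10).float()), batch_size=4, shuffle=True)
train_iter = ForeverDataIterator(loader)
for step in range(7):          # more steps than len(loader) == 3
    batch, = next(train_iter)  # the underlying iterator is silently re-created on exhaustion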
  {
    "path": "tllib/utils/logger.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nimport sys\nimport time\n\nclass TextLogger(object):\n    \"\"\"Writes stream output to external text file.\n\n    Args:\n        filename (str): the file to write stream output\n        stream: the stream to read from. Default: sys.stdout\n    \"\"\"\n    def __init__(self, filename, stream=sys.stdout):\n        self.terminal = stream\n        self.log = open(filename, 'a')\n\n    def write(self, message):\n        self.terminal.write(message)\n        self.log.write(message)\n        self.flush()\n\n    def flush(self):\n        self.terminal.flush()\n        self.log.flush()\n\n    def close(self):\n        self.terminal.close()\n        self.log.close()\n\n\nclass CompleteLogger:\n    \"\"\"\n    A useful logger that\n\n    - writes outputs to files and displays them on the console at the same time.\n    - manages the directory of checkpoints and debugging images.\n\n    Args:\n        root (str): the root directory of logger\n        phase (str): the phase of training.\n\n    \"\"\"\n\n    def __init__(self, root, phase='train'):\n        self.root = root\n        self.phase = phase\n        self.visualize_directory = os.path.join(self.root, \"visualize\")\n        self.checkpoint_directory = os.path.join(self.root, \"checkpoints\")\n        self.epoch = 0\n\n        os.makedirs(self.root, exist_ok=True)\n        os.makedirs(self.visualize_directory, exist_ok=True)\n        os.makedirs(self.checkpoint_directory, exist_ok=True)\n\n        # redirect std out\n        now = time.strftime(\"%Y-%m-%d-%H_%M_%S\", time.localtime(time.time()))\n        log_filename = os.path.join(self.root, \"{}-{}.txt\".format(phase, now))\n        if os.path.exists(log_filename):\n            os.remove(log_filename)\n        self.logger = TextLogger(log_filename)\n        sys.stdout = self.logger\n        sys.stderr = self.logger\n        if phase != 'train':\n            self.set_epoch(phase)\n\n    def set_epoch(self, epoch):\n        \"\"\"Set the epoch number. Please use it during training.\"\"\"\n        os.makedirs(os.path.join(self.visualize_directory, str(epoch)), exist_ok=True)\n        self.epoch = epoch\n\n    def _get_phase_or_epoch(self):\n        if self.phase == 'train':\n            return str(self.epoch)\n        else:\n            return self.phase\n\n    def get_image_path(self, filename: str):\n        \"\"\"\n        Get the full image path for a specific filename\n        \"\"\"\n        return os.path.join(self.visualize_directory, self._get_phase_or_epoch(), filename)\n\n    def get_checkpoint_path(self, name=None):\n        \"\"\"\n        Get the full checkpoint path.\n\n        Args:\n            name (optional): the filename (without file extension) to save checkpoint.\n                If None, when the phase is ``train``, checkpoint will be saved to ``{epoch}.pth``.\n                Otherwise, will be saved to ``{phase}.pth``.\n\n        \"\"\"\n        if name is None:\n            name = self._get_phase_or_epoch()\n        name = str(name)\n        return os.path.join(self.checkpoint_directory, name + \".pth\")\n\n    def close(self):\n        self.logger.close()\n"
  },
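A usage sketch for CompleteLogger (not library code); the directory name is illustrative. After construction, stdout and stderr are tee'd to a timestamped text file under the root:

from tllib.utils.logger import CompleteLogger

logger = CompleteLogger('logs/demo', phase='train')  # creates logs/demo/{visualize,checkpoints}
logger.set_epoch(0)
print('written to the console and to logs/demo/train-<timestamp>.txt')
print(logger.get_checkpoint_path())         # logs/demo/checkpoints/0.pth
print(logger.get_image_path('sample.png'))  # logs/demo/visualize/0/sample.png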
  {
    "path": "tllib/utils/meter.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, List\n\n\nclass AverageMeter(object):\n    r\"\"\"Computes and stores the average and current value.\n\n    Examples::\n\n        >>> # Initialize a meter to record loss\n        >>> losses = AverageMeter()\n        >>> # Update meter after every minibatch update\n        >>> losses.update(loss_value, batch_size)\n    \"\"\"\n    def __init__(self, name: str, fmt: Optional[str] = ':f'):\n        self.name = name\n        self.fmt = fmt\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        if self.count > 0:\n            self.avg = self.sum / self.count\n\n    def __str__(self):\n        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'\n        return fmtstr.format(**self.__dict__)\n\n\nclass AverageMeterDict(object):\n    def __init__(self, names: List, fmt: Optional[str] = ':f'):\n        self.dict = {\n            name: AverageMeter(name, fmt) for name in names\n        }\n\n    def reset(self):\n        for meter in self.dict.values():\n            meter.reset()\n\n    def update(self, accuracies, n=1):\n        for name, acc in accuracies.items():\n            self.dict[name].update(acc, n)\n\n    def average(self):\n        return {\n            name: meter.avg for name, meter in self.dict.items()\n        }\n\n    def __getitem__(self, item):\n        return self.dict[item]\n\n\nclass Meter(object):\n    \"\"\"Computes and stores the current value.\"\"\"\n    def __init__(self, name: str, fmt: Optional[str] = ':f'):\n        self.name = name\n        self.fmt = fmt\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n\n    def update(self, val):\n        self.val = val\n\n    def __str__(self):\n        fmtstr = '{name} {val' + self.fmt + '}'\n        return fmtstr.format(**self.__dict__)\n\n\nclass ProgressMeter(object):\n    def __init__(self, num_batches, meters, prefix=\"\"):\n        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)\n        self.meters = meters\n        self.prefix = prefix\n\n    def display(self, batch):\n        entries = [self.prefix + self.batch_fmtstr.format(batch)]\n        entries += [str(meter) for meter in self.meters]\n        print('\\t'.join(entries))\n\n    def _get_batch_fmtstr(self, num_batches):\n        num_digits = len(str(num_batches // 1))\n        fmt = '{:' + str(num_digits) + 'd}'\n        return '[' + fmt + '/' + fmt.format(num_batches) + ']'\n\n\n"
  },
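A sketch of the meters as they are typically combined in a training loop (not library code); the loss and accuracy values are made up:

from tllib.utils.meter import AverageMeter, ProgressMeter

losses = AverageMeter('Loss', ':.4f')
top1 = AverageMeter('Acc@1', ':.2f')
progress = ProgressMeter(num_batches=100, meters=[losses, top1], prefix='Epoch [0]')

for i in range(100):
    losses.update(0.5, n=32)  # running average weighted by the batch size n
    top1.update(87.5, n=32)
    if i % 50 == 0:
        progress.display(i)   # e.g. Epoch [0][  0/100]  Loss 0.5000 (0.5000)  Acc@1 87.50 (87.50)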
  {
    "path": "tllib/utils/metric/__init__.py",
    "content": "import torch\nimport prettytable\n\n__all__ = ['keypoint_detection']\n\ndef binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:\n    \"\"\"Computes the accuracy for binary classification\"\"\"\n    with torch.no_grad():\n        batch_size = target.size(0)\n        pred = (output >= 0.5).float().t().view(-1)\n        correct = pred.eq(target.view(-1)).float().sum()\n        correct.mul_(100. / batch_size)\n        return correct\n\n\ndef accuracy(output, target, topk=(1,)):\n    r\"\"\"\n    Computes the accuracy over the k top predictions for the specified values of k\n\n    Args:\n        output (tensor): Classification outputs, :math:`(N, C)` where `C = number of classes`\n        target (tensor): :math:`(N)` where each value is :math:`0 \\leq \\text{targets}[i] \\leq C-1`\n        topk (sequence[int]): A list of top-N number.\n\n    Returns:\n        Top-N accuracies (N :math:`\\in` topK).\n    \"\"\"\n    with torch.no_grad():\n        maxk = max(topk)\n        batch_size = target.size(0)\n\n        _, pred = output.topk(maxk, 1, True, True)\n        pred = pred.t()\n        correct = pred.eq(target[None])\n\n        res = []\n        for k in topk:\n            correct_k = correct[:k].flatten().sum(dtype=torch.float32)\n            res.append(correct_k * (100.0 / batch_size))\n        return res\n\n\nclass ConfusionMatrix(object):\n    def __init__(self, num_classes):\n        self.num_classes = num_classes\n        self.mat = None\n\n    def update(self, target, output):\n        \"\"\"\n        Update confusion matrix.\n\n        Args:\n            target: ground truth\n            output: predictions of models\n\n        Shape:\n            - target: :math:`(minibatch, C)` where C means the number of classes.\n            - output: :math:`(minibatch, C)` where C means the number of classes.\n        \"\"\"\n        n = self.num_classes\n        if self.mat is None:\n            self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)\n        with torch.no_grad():\n            k = (target >= 0) & (target < n)\n            inds = n * target[k].to(torch.int64) + output[k]\n            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)\n\n    def reset(self):\n        self.mat.zero_()\n\n    def compute(self):\n        \"\"\"compute global accuracy, per-class accuracy and per-class IoU\"\"\"\n        h = self.mat.float()\n        acc_global = torch.diag(h).sum() / h.sum()\n        acc = torch.diag(h) / h.sum(1)\n        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))\n        return acc_global, acc, iu\n\n    # def reduce_from_all_processes(self):\n    #     if not torch.distributed.is_available():\n    #         return\n    #     if not torch.distributed.is_initialized():\n    #         return\n    #     torch.distributed.barrier()\n    #     torch.distributed.all_reduce(self.mat)\n\n    def __str__(self):\n        acc_global, acc, iu = self.compute()\n        return (\n            'global correct: {:.1f}\\n'\n            'average row correct: {}\\n'\n            'IoU: {}\\n'\n            'mean IoU: {:.1f}').format(\n                acc_global.item() * 100,\n                ['{:.1f}'.format(i) for i in (acc * 100).tolist()],\n                ['{:.1f}'.format(i) for i in (iu * 100).tolist()],\n                iu.mean().item() * 100)\n\n    def format(self, classes: list):\n        \"\"\"Get the accuracy and IoU for each class in the table format\"\"\"\n        acc_global, acc, iu = self.compute()\n\n        table 
= prettytable.PrettyTable([\"class\", \"acc\", \"iou\"])\n        for i, class_name, per_acc, per_iu in zip(range(len(classes)), classes, (acc * 100).tolist(), (iu * 100).tolist()):\n            table.add_row([class_name, per_acc, per_iu])\n\n        return 'global correct: {:.1f}\\nmean correct:{:.1f}\\nmean IoU: {:.1f}\\n{}'.format(\n            acc_global.item() * 100, acc.mean().item() * 100, iu.mean().item() * 100, table.get_string())\n\n"
  },
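A minimal usage sketch of the helpers above, on hand-built tensors; the import path `tllib.utils.metric` is assumed from the surrounding package layout:

```python
# Hedged sketch: exercise accuracy() and ConfusionMatrix on toy tensors.
import torch
from tllib.utils.metric import accuracy, ConfusionMatrix  # assumed import path

logits = torch.randn(8, 3)                       # (N, C) classification outputs
target = torch.tensor([0, 1, 2, 2, 1, 0, 0, 2])  # (N,) ground-truth class indices

top1, = accuracy(logits, target, topk=(1,))
print('top-1: {:.1f}%'.format(top1.item()))

cm = ConfusionMatrix(num_classes=3)
cm.update(target, logits.argmax(dim=1))  # both arguments are (N,) index tensors
print(cm)  # global accuracy, per-class accuracy, per-class IoU
```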
  {
    "path": "tllib/utils/metric/keypoint_detection.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\n# TODO: add documentation\nimport numpy as np\n\n\ndef get_max_preds(batch_heatmaps):\n    '''\n    get predictions from score maps\n    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])\n    '''\n    assert isinstance(batch_heatmaps, np.ndarray), \\\n        'batch_heatmaps should be numpy.ndarray'\n    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'\n\n    batch_size = batch_heatmaps.shape[0]\n    num_joints = batch_heatmaps.shape[1]\n    width = batch_heatmaps.shape[3]\n    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))\n    idx = np.argmax(heatmaps_reshaped, 2)\n    maxvals = np.amax(heatmaps_reshaped, 2)\n\n    maxvals = maxvals.reshape((batch_size, num_joints, 1))\n    idx = idx.reshape((batch_size, num_joints, 1))\n\n    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)\n\n    preds[:, :, 0] = (preds[:, :, 0]) % width\n    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)\n\n    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))\n    pred_mask = pred_mask.astype(np.float32)\n\n    preds *= pred_mask\n    return preds, maxvals\n\n\ndef calc_dists(preds, target, normalize):\n    preds = preds.astype(np.float32)\n    target = target.astype(np.float32)\n    dists = np.zeros((preds.shape[1], preds.shape[0]))\n    for n in range(preds.shape[0]):\n        for c in range(preds.shape[1]):\n            if target[n, c, 0] > 1 and target[n, c, 1] > 1:\n                normed_preds = preds[n, c, :] / normalize[n]\n                normed_targets = target[n, c, :] / normalize[n]\n                dists[c, n] = np.linalg.norm(normed_preds - normed_targets)\n            else:\n                dists[c, n] = -1\n    return dists\n\n\ndef dist_acc(dists, thr=0.5):\n    ''' Return percentage below threshold while ignoring values with a -1 '''\n    dist_cal = np.not_equal(dists, -1)\n    num_dist_cal = dist_cal.sum()\n    if num_dist_cal > 0:\n        return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal\n    else:\n        return -1\n\n\ndef accuracy(output, target, hm_type='gaussian', thr=0.5):\n    '''\n    Calculate accuracy according to PCK,\n    but uses ground truth heatmap rather than x,y locations\n    First value to be returned is average accuracy across 'idxs',\n    followed by individual accuracies\n    '''\n    idx = list(range(output.shape[1]))\n    norm = 1.0\n    if hm_type == 'gaussian':\n        pred, _ = get_max_preds(output)\n        target, _ = get_max_preds(target)\n        h = output.shape[2]\n        w = output.shape[3]\n        norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10\n    dists = calc_dists(pred, target, norm)\n\n    acc = np.zeros(len(idx))\n    avg_acc = 0\n    cnt = 0\n\n    for i in range(len(idx)):\n        acc[i] = dist_acc(dists[idx[i]], thr)\n        if acc[i] >= 0:\n            avg_acc = avg_acc + acc[i]\n            cnt += 1\n\n    avg_acc = avg_acc / cnt if cnt != 0 else 0\n\n    return acc, avg_acc, cnt, pred\n"
  },
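A quick sanity check of the PCK-style metric above; the shapes are illustrative only, and passing the same heatmaps as output and target should score (close to) perfectly:

```python
# Hedged sketch: PCK on synthetic heatmaps (batch of 2, 3 joints, 16x16 maps).
import numpy as np
from tllib.utils.metric.keypoint_detection import accuracy, get_max_preds

heatmaps = np.random.rand(2, 3, 16, 16).astype(np.float32)  # (B, J, H, W)
preds, maxvals = get_max_preds(heatmaps)  # (B, J, 2) x/y argmax locations
acc, avg_acc, cnt, pred = accuracy(heatmaps, heatmaps)
print(avg_acc)  # ~1.0: identical maps match; joints whose argmax has x or y <= 1 are skipped
```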
  {
    "path": "tllib/utils/metric/reid.py",
    "content": "# TODO: add documentation\n\"\"\"\nModified from https://github.com/yxgeee/MMT\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport os\nimport os.path as osp\nfrom collections import defaultdict\nimport time\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import average_precision_score\nfrom tllib.utils.meter import AverageMeter, ProgressMeter\n\n\ndef unique_sample(ids_dict, num):\n    \"\"\"Randomly choose one instance for each person id, these instances will not be selected again\"\"\"\n    mask = np.zeros(num, dtype=np.bool)\n    for _, indices in ids_dict.items():\n        i = np.random.choice(indices)\n        mask[i] = True\n    return mask\n\n\ndef cmc(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams, topk=100, separate_camera_set=False,\n        single_gallery_shot=False, first_match_break=False):\n    \"\"\"Compute Cumulative Matching Characteristics (CMC)\"\"\"\n    dist_mat = dist_mat.cpu().numpy()\n    m, n = dist_mat.shape\n    query_ids = np.asarray(query_ids)\n    gallery_ids = np.asarray(gallery_ids)\n    query_cams = np.asarray(query_cams)\n    gallery_cams = np.asarray(gallery_cams)\n    # Sort and find correct matches\n    indices = np.argsort(dist_mat, axis=1)\n    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])\n    # Compute CMC for each query\n    ret = np.zeros(topk)\n    num_valid_queries = 0\n    for i in range(m):\n        # Filter out the same id and same camera\n        valid = ((gallery_ids[indices[i]] != query_ids[i]) |\n                 (gallery_cams[indices[i]] != query_cams[i]))\n        if separate_camera_set:\n            # Filter out samples from same camera\n            valid &= (gallery_cams[indices[i]] != query_cams[i])\n        if not np.any(matches[i, valid]): continue\n        if single_gallery_shot:\n            repeat = 10\n            gids = gallery_ids[indices[i][valid]]\n            inds = np.where(valid)[0]\n            ids_dict = defaultdict(list)\n            for j, x in zip(inds, gids):\n                ids_dict[x].append(j)\n        else:\n            repeat = 1\n        for _ in range(repeat):\n            if single_gallery_shot:\n                # Randomly choose one instance for each id\n                sampled = (valid & unique_sample(ids_dict, len(valid)))\n                index = np.nonzero(matches[i, sampled])[0]\n            else:\n                index = np.nonzero(matches[i, valid])[0]\n            delta = 1. 
/ (len(index) * repeat)\n            for j, k in enumerate(index):\n                if k - j >= topk: break\n                if first_match_break:\n                    ret[k - j] += 1\n                    break\n                ret[k - j] += delta\n        num_valid_queries += 1\n    if num_valid_queries == 0:\n        raise RuntimeError(\"No valid query\")\n    return ret.cumsum() / num_valid_queries\n\n\ndef mean_ap(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams):\n    \"\"\"Compute mean average precision (mAP)\"\"\"\n    dist_mat = dist_mat.cpu().numpy()\n    m, n = dist_mat.shape\n    query_ids = np.asarray(query_ids)\n    gallery_ids = np.asarray(gallery_ids)\n    query_cams = np.asarray(query_cams)\n    gallery_cams = np.asarray(gallery_cams)\n    # Sort and find correct matches\n    indices = np.argsort(dist_mat, axis=1)\n    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])\n    # Compute AP for each query\n    aps = []\n    for i in range(m):\n        # Filter out the same id and same camera\n        valid = ((gallery_ids[indices[i]] != query_ids[i]) |\n                 (gallery_cams[indices[i]] != query_cams[i]))\n        y_true = matches[i, valid]\n        y_score = -dist_mat[i][indices[i]][valid]\n        if not np.any(y_true): continue\n        aps.append(average_precision_score(y_true, y_score))\n    if len(aps) == 0:\n        raise RuntimeError(\"No valid query\")\n    return np.mean(aps)\n\n\ndef re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):\n    \"\"\"Perform re-ranking with distance matrix between query and gallery images `q_g_dist`, distance matrix between\n    query and query images `q_q_dist` and distance matrix between gallery and gallery images `g_g_dist`.\n    \"\"\"\n    q_g_dist = q_g_dist.cpu().numpy()\n    q_q_dist = q_q_dist.cpu().numpy()\n    g_g_dist = g_g_dist.cpu().numpy()\n\n    original_dist = np.concatenate(\n        [np.concatenate([q_q_dist, q_g_dist], axis=1),\n         np.concatenate([q_g_dist.T, g_g_dist], axis=1)],\n        axis=0)\n    original_dist = np.power(original_dist, 2).astype(np.float32)\n    original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))\n    V = np.zeros_like(original_dist).astype(np.float32)\n    initial_rank = np.argsort(original_dist).astype(np.int32)\n\n    query_num = q_g_dist.shape[0]\n    gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]\n    all_num = gallery_num\n\n    for i in range(all_num):\n        # k-reciprocal neighbors\n        forward_k_neigh_index = initial_rank[i, :k1 + 1]\n        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]\n        fi = np.where(backward_k_neigh_index == i)[0]\n        k_reciprocal_index = forward_k_neigh_index[fi]\n        k_reciprocal_expansion_index = k_reciprocal_index\n        for j in range(len(k_reciprocal_index)):\n            candidate = k_reciprocal_index[j]\n            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2.)) + 1]\n            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,\n                                               :int(np.around(k1 / 2.)) + 1]\n            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]\n            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]\n            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. 
/ 3 * len(\n                    candidate_k_reciprocal_index):\n                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)\n\n        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)\n        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])\n        V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)\n    original_dist = original_dist[:query_num, ]\n    if k2 != 1:\n        V_qe = np.zeros_like(V, dtype=np.float32)\n        for i in range(all_num):\n            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)\n        V = V_qe\n        del V_qe\n    del initial_rank\n    invIndex = []\n    for i in range(gallery_num):\n        invIndex.append(np.where(V[:, i] != 0)[0])\n\n    jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)\n\n    for i in range(query_num):\n        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32)\n        indNonZero = np.where(V[i, :] != 0)[0]\n        indImages = [invIndex[ind] for ind in indNonZero]\n        for j in range(len(indNonZero)):\n            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],\n                                                                               V[indImages[j], indNonZero[j]])\n        jaccard_dist[i] = 1 - temp_min / (2. - temp_min)\n\n    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value\n    del original_dist\n    del V\n    del jaccard_dist\n    final_dist = final_dist[:query_num, query_num:]\n    return final_dist\n\n\ndef extract_reid_feature(data_loader, model, device, normalize, print_freq=200):\n    \"\"\"Extract feature for person ReID. If `normalize` is True, `cosine` distance will be employed as distance\n    metric, otherwise `euclidean` distance.\n    \"\"\"\n    batch_time = AverageMeter('Time', ':6.3f')\n    progress = ProgressMeter(\n        len(data_loader),\n        [batch_time],\n        prefix='Collect feature: ')\n\n    # switch to eval mode\n    model.eval()\n    feature_dict = dict()\n\n    with torch.no_grad():\n        end = time.time()\n        for i, (images_batch, filenames_batch, _, _) in enumerate(data_loader):\n\n            images_batch = images_batch.to(device)\n            features_batch = model(images_batch)\n            if normalize:\n                features_batch = F.normalize(features_batch)\n\n            for filename, feature in zip(filenames_batch, features_batch):\n                feature_dict[filename] = feature\n\n            # measure elapsed time\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if i % print_freq == 0:\n                progress.display(i)\n\n    return feature_dict\n\n\ndef pairwise_distance(feature_dict, query, gallery):\n    \"\"\"Compute pairwise distance between two sets of features\"\"\"\n\n    # concat features and convert to pytorch tensor\n    # we compute pairwise distance metric on cpu because it may require a large amount of GPU memory, if you are using\n    # gpu with a larger capacity, it's faster to calculate on gpu\n    x = torch.cat([feature_dict[f].unsqueeze(0) for f, _, _ in query], dim=0).cpu()\n    y = torch.cat([feature_dict[f].unsqueeze(0) for f, _, _ in gallery], dim=0).cpu()\n    m, n = x.size(0), y.size(0)\n    # flatten\n    x = x.view(m, -1)\n    y = y.view(n, -1)\n    # compute dist_mat\n    dist_mat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \\\n               torch.pow(y, 
2).sum(dim=1, keepdim=True).expand(n, m).t() - \\\n               2 * torch.matmul(x, y.t())\n    return dist_mat\n\n\ndef evaluate_all(dist_mat, query, gallery, cmc_topk=(1, 5, 10), cmc_flag=False):\n    \"\"\"Compute CMC score, mAP and return\"\"\"\n    query_ids = [pid for _, pid, _ in query]\n    gallery_ids = [pid for _, pid, _ in gallery]\n    query_cams = [cid for _, _, cid in query]\n    gallery_cams = [cid for _, _, cid in gallery]\n\n    # Compute mean AP\n    mAP = mean_ap(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams)\n    print('Mean AP: {:4.1%}'.format(mAP))\n\n    if not cmc_flag:\n        return mAP\n\n    cmc_configs = {\n        'config': dict(separate_camera_set=False, single_gallery_shot=False, first_match_break=True)\n    }\n    cmc_scores = {name: cmc(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams, **params) for name, params in\n                  cmc_configs.items()}\n\n    print('CMC Scores:')\n    for k in cmc_topk:\n        print('  top-{:<4}{:12.1%}'.format(k, cmc_scores['config'][k - 1]))\n    return cmc_scores['config'][0], mAP\n\n\ndef validate(val_loader, model, query, gallery, device, criterion='cosine', cmc_flag=False, rerank=False):\n    assert criterion in ['cosine', 'euclidean']\n    # when criterion == 'cosine', normalize feature of single image into unit norm\n    normalize = (criterion == 'cosine')\n\n    feature_dict = extract_reid_feature(val_loader, model, device, normalize)\n    dist_mat = pairwise_distance(feature_dict, query, gallery)\n    results = evaluate_all(dist_mat, query=query, gallery=gallery, cmc_flag=cmc_flag)\n    if not rerank:\n        return results\n    # apply person re-ranking\n    print('Applying person re-ranking')\n    dist_mat_query = pairwise_distance(feature_dict, query, query)\n    dist_mat_gallery = pairwise_distance(feature_dict, gallery, gallery)\n    dist_mat = re_ranking(dist_mat, dist_mat_query, dist_mat_gallery)\n    return evaluate_all(dist_mat, query=query, gallery=gallery, cmc_flag=cmc_flag)\n\n\n# location parameters for visualization\nGRID_SPACING = 10\nQUERY_EXTRA_SPACING = 90\n# border width\nBW = 5\nGREEN = (0, 255, 0)\nRED = (0, 0, 255)\n\n\ndef visualize_ranked_results(data_loader, model, query, gallery, device, visualize_dir, criterion='cosine',\n                             rerank=False, width=128, height=256, topk=10):\n    \"\"\"Visualize ranker results. We first compute pair-wise distance between query images and gallery images. Then for\n    every query image, `topk` gallery images with least distance between given query image are selected. We plot the\n    query image and selected gallery images together. 
A green border denotes a match, and a red one denotes a mis-match.\n    \"\"\"\n    assert criterion in ['cosine', 'euclidean']\n    normalize = (criterion == 'cosine')\n\n    # compute pairwise distance matrix\n    feature_dict = extract_reid_feature(data_loader, model, device, normalize)\n    dist_mat = pairwise_distance(feature_dict, query, gallery)\n\n    if rerank:\n        dist_mat_query = pairwise_distance(feature_dict, query, query)\n        dist_mat_gallery = pairwise_distance(feature_dict, gallery, gallery)\n        dist_mat = re_ranking(dist_mat, dist_mat_query, dist_mat_gallery)\n\n    # make dir if not exists\n    os.makedirs(visualize_dir, exist_ok=True)\n\n    dist_mat = dist_mat.numpy()\n    num_q, num_g = dist_mat.shape\n    print('query images: {}'.format(num_q))\n    print('gallery images: {}'.format(num_g))\n\n    assert num_q == len(query)\n    assert num_g == len(gallery)\n\n    # start visualizing\n    import cv2\n    sorted_idxes = np.argsort(dist_mat, axis=1)\n    for q_idx in range(num_q):\n        q_img_path, q_pid, q_cid = query[q_idx]\n\n        q_img = cv2.imread(q_img_path)\n        q_img = cv2.resize(q_img, (width, height))\n        # use black border to denote query image\n        q_img = cv2.copyMakeBorder(\n            q_img, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=(0, 0, 0)\n        )\n        q_img = cv2.resize(q_img, (width, height))\n        num_cols = topk + 1\n        grid_img = 255 * np.ones(\n            (height, num_cols * width + topk * GRID_SPACING + QUERY_EXTRA_SPACING, 3), dtype=np.uint8\n        )\n        grid_img[:, :width, :] = q_img\n\n        # collect top-k gallery images with smallest distance\n        rank_idx = 1\n        for g_idx in sorted_idxes[q_idx, :]:\n            g_img_path, g_pid, g_cid = gallery[g_idx]\n            invalid = (q_pid == g_pid) & (q_cid == g_cid)\n            if not invalid:\n                matched = (g_pid == q_pid)\n                border_color = GREEN if matched else RED\n                g_img = cv2.imread(g_img_path)\n                g_img = cv2.resize(g_img, (width, height))\n                g_img = cv2.copyMakeBorder(\n                    g_img, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=border_color\n                )\n                g_img = cv2.resize(g_img, (width, height))\n                start = rank_idx * width + rank_idx * GRID_SPACING + QUERY_EXTRA_SPACING\n                end = (rank_idx + 1) * width + rank_idx * GRID_SPACING + QUERY_EXTRA_SPACING\n                grid_img[:, start:end, :] = g_img\n\n                rank_idx += 1\n                if rank_idx > topk:\n                    break\n\n        save_path = osp.basename(osp.splitext(q_img_path)[0])\n        cv2.imwrite(osp.join(visualize_dir, save_path + '.jpg'), grid_img)\n\n        if (q_idx + 1) % 100 == 0:\n            print('Visualize {}/{}'.format(q_idx + 1, num_q))\n\n    print('Visualization process is done, ranked results are saved to {}'.format(visualize_dir))\n"
  },
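As a quick sanity check, `cmc` and `mean_ap` can be run on a hand-built distance matrix; the person IDs and camera IDs below are invented purely for illustration:

```python
# Hedged sketch: CMC / mAP on a toy 2-query x 3-gallery distance matrix.
import torch
from tllib.utils.metric.reid import cmc, mean_ap  # assumed import path

dist_mat = torch.tensor([[0.1, 0.9, 0.5],
                         [0.8, 0.2, 0.6]])
query_ids, gallery_ids = [0, 1], [0, 1, 2]
query_cams, gallery_cams = [0, 0], [1, 1, 1]  # different cameras, so all pairs stay valid

print(mean_ap(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams))  # 1.0
print(cmc(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams,
          topk=3, first_match_break=True))  # [1. 1. 1.]: both queries hit at rank 1
```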
  {
    "path": "tllib/utils/scheduler.py",
    "content": "\"\"\"\nModified from https://github.com/yxgeee/MMT\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch\nfrom bisect import bisect_right\n\n\nclass WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):\n    r\"\"\"Starts with a warm-up phase, then decays the learning rate of each parameter group by gamma once the\n    number of epoch reaches one of the milestones. When last_epoch=-1, sets initial lr as lr.\n\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        milestones (list): List of epoch indices. Must be increasing.\n        gamma (float): Multiplicative factor of learning rate decay.\n            Default: 0.1.\n        warmup_factor (float): a float number :math:`k` between 0 and 1, the start learning rate of warmup phase\n            will be set to :math:`k*initial\\_lr`\n        warmup_steps (int): number of warm-up steps.\n        warmup_method (str): \"constant\" denotes a constant learning rate during warm-up phase and \"linear\" denotes a\n            linear-increasing learning rate during warm-up phase.\n        last_epoch (int): The index of last epoch. Default: -1.\n    \"\"\"\n\n    def __init__(\n            self,\n            optimizer,\n            milestones,\n            gamma=0.1,\n            warmup_factor=1.0 / 3,\n            warmup_steps=500,\n            warmup_method=\"linear\",\n            last_epoch=-1,\n    ):\n        if not list(milestones) == sorted(milestones):\n            raise ValueError(\n                \"Milestones should be a list of\" \" increasing integers. Got {}\",\n                milestones,\n            )\n\n        if warmup_method not in (\"constant\", \"linear\"):\n            raise ValueError(\n                \"Only 'constant' or 'linear' warmup_method accepted\"\n                \"got {}\".format(warmup_method)\n            )\n        self.milestones = milestones\n        self.gamma = gamma\n        self.warmup_factor = warmup_factor\n        self.warmup_steps = warmup_steps\n        self.warmup_method = warmup_method\n        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        warmup_factor = 1\n        if self.last_epoch < self.warmup_steps:\n            if self.warmup_method == \"constant\":\n                warmup_factor = self.warmup_factor\n            elif self.warmup_method == \"linear\":\n                alpha = float(self.last_epoch) / float(self.warmup_steps)\n                warmup_factor = self.warmup_factor * (1 - alpha) + alpha\n        return [\n            base_lr\n            * warmup_factor\n            * self.gamma ** bisect_right(self.milestones, self.last_epoch)\n            for base_lr in self.base_lrs\n        ]\n"
  },
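A short sketch of the schedule this class produces; the single dummy parameter below exists only so an optimizer can be constructed:

```python
# Hedged sketch: 5-step linear warm-up to lr=0.1, then x0.1 decay at epochs 30 and 40.
import torch
from tllib.utils.scheduler import WarmupMultiStepLR

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter for illustration
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = WarmupMultiStepLR(optimizer, milestones=[30, 40], gamma=0.1,
                              warmup_factor=1. / 3, warmup_steps=5)
for epoch in range(50):
    optimizer.step()
    scheduler.step()
    if epoch in (0, 4, 30, 40):
        print(epoch, scheduler.get_last_lr())  # ramps from ~0.033 to 0.1, then 0.01, 0.001
```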
  {
    "path": "tllib/vision/__init__.py",
    "content": "__all__ = ['datasets', 'models', 'transforms']\n"
  },
  {
    "path": "tllib/vision/datasets/__init__.py",
    "content": "from .imagelist import ImageList\r\nfrom .office31 import Office31\r\nfrom .officehome import OfficeHome\r\nfrom .visda2017 import VisDA2017\r\nfrom .officecaltech import OfficeCaltech\r\nfrom .domainnet import DomainNet\r\nfrom .imagenet_r import ImageNetR\r\nfrom .imagenet_sketch import ImageNetSketch\r\nfrom .pacs import PACS\r\nfrom .digits import *\r\nfrom .aircrafts import Aircraft\r\nfrom .cub200 import CUB200\r\nfrom .stanford_cars import StanfordCars\r\nfrom .stanford_dogs import StanfordDogs\r\nfrom .coco70 import COCO70\r\nfrom .oxfordpets import OxfordIIITPets\r\nfrom .dtd import DTD\r\nfrom .oxfordflowers import OxfordFlowers102\r\nfrom .patchcamelyon import PatchCamelyon\r\nfrom .retinopathy import Retinopathy\r\nfrom .eurosat import EuroSAT\r\nfrom .resisc45 import Resisc45\r\nfrom .food101 import Food101\r\nfrom .sun397 import SUN397\r\nfrom .caltech101 import Caltech101\r\nfrom .cifar import CIFAR10, CIFAR100\r\n\r\n__all__ = ['ImageList', 'Office31', 'OfficeHome', \"VisDA2017\", \"OfficeCaltech\", \"DomainNet\", \"ImageNetR\",\r\n           \"ImageNetSketch\", \"Aircraft\", \"cub200\", \"StanfordCars\", \"StanfordDogs\", \"COCO70\", \"OxfordIIITPets\", \"PACS\",\r\n           \"DTD\", \"OxfordFlowers102\", \"PatchCamelyon\", \"Retinopathy\", \"EuroSAT\", \"Resisc45\", \"Food101\", \"SUN397\",\r\n           \"Caltech101\", \"CIFAR10\", \"CIFAR100\"]\r\n"
  },
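Since every benchmark is re-exported here, scripts can resolve a dataset class from its name; `Office31` below is just one registered entry used for illustration:

```python
# Hedged sketch: look a registered dataset class up by name.
import tllib.vision.datasets as datasets

print(sorted(datasets.__all__))              # every exported benchmark name
dataset_cls = getattr(datasets, 'Office31')  # -> the Office31 class
```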
  {
    "path": "tllib/vision/datasets/_util.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import List\nfrom torchvision.datasets.utils import download_and_extract_archive\n\n\ndef download(root: str, file_name: str, archive_name: str, url_link: str):\n    \"\"\"\n    Download file from internet url link.\n\n    Args:\n        root (str) The directory to put downloaded files.\n        file_name: (str) The name of the unzipped file.\n        archive_name: (str) The name of archive(zipped file) downloaded.\n        url_link: (str) The url link to download data.\n\n    .. note::\n        If `file_name` already exists under path `root`, then it is not downloaded again.\n        Else `archive_name` will be downloaded from `url_link` and extracted to `file_name`.\n    \"\"\"\n    if not os.path.exists(os.path.join(root, file_name)):\n        print(\"Downloading {}\".format(file_name))\n        # if os.path.exists(os.path.join(root, archive_name)):\n        #     os.remove(os.path.join(root, archive_name))\n        try:\n            download_and_extract_archive(url_link, download_root=root, filename=archive_name, remove_finished=False)\n        except Exception:\n            print(\"Fail to download {} from url link {}\".format(archive_name, url_link))\n            print('Please check you internet connection.'\n                  \"Simply trying again may be fine.\")\n            exit(0)\n\n\ndef check_exits(root: str, file_name: str):\n    \"\"\"Check whether `file_name` exists under directory `root`. \"\"\"\n    if not os.path.exists(os.path.join(root, file_name)):\n        print(\"Dataset directory {} not found under {}\".format(file_name, root))\n        exit(-1)\n\n\ndef read_list_from_file(file_name: str) -> List[str]:\n    \"\"\"Read data from file and convert each line into an element in the list\"\"\"\n    result = []\n    with open(file_name, \"r\") as f:\n        for line in f.readlines():\n            result.append(line.strip())\n    return result\n"
  },
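A small sketch of the download/check pattern the dataset classes below rely on; the URL and file names here are placeholders, not real mirrors:

```python
# Hedged sketch: fetch an archive once, then verify the extracted directory.
from tllib.vision.datasets._util import download, check_exits

root = '/tmp/my_dataset'  # placeholder root directory
download(root, file_name='image_list', archive_name='image_list.zip',
         url_link='https://example.com/image_list.zip')
check_exits(root, 'image_list')  # exits with code -1 if the directory is missing
```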
  {
    "path": "tllib/vision/datasets/aircrafts.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass Aircraft(ImageList):\n    \"\"\"`FVGC-Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ \\\n        is a benchmark for the fine-grained visual categorization of aircraft.  \\\n        The dataset contains 10,200 images of aircraft, with 100 images for each \\\n        of the 102 different aircraft variants.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        sample_rate (int): The sampling rates to sample random ``training`` images for each category.\n            Choices include 100, 50, 30, 15. Default: 100.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            train/\n            test/\n            image_list/\n                train_100.txt\n                train_50.txt\n                train_30.txt\n                train_15.txt\n                test.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/449157d27987463cbdb1/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/06804f17fdb947aa9401/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/164996d09cc749abbdeb/?dl=1\"),\n    ]\n    image_list = {\n        \"train\": \"image_list/train_100.txt\",\n        \"train100\": \"image_list/train_100.txt\",\n        \"train50\": \"image_list/train_50.txt\",\n        \"train30\": \"image_list/train_30.txt\",\n        \"train15\": \"image_list/train_15.txt\",\n        \"test\": \"image_list/test.txt\",\n        \"test100\": \"image_list/test.txt\",\n    }\n    CLASSES = ['707-320', '727-200', '737-200', '737-300', '737-400', '737-500', '737-600', '737-700', '737-800',\n               '737-900', '747-100', '747-200', '747-300', '747-400', '757-200', '757-300', '767-200', '767-300',\n               '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', 'A321', 'A330-200',\n               'A330-300', 'A340-200', 'A340-300', 'A340-500', 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12',\n               'BAE 146-200', 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47',\n               'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525', 'Cessna 560',\n               'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30', 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100',\n               'DHC-8-300', 'DR-400', 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145',\n               'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A-B', 'F-A-18', 'Falcon 2000', 'Falcon 900',\n               'Fokker 100', 'Fokker 50', 'Fokker 70', 'Global Express', 'Gulfstream IV', 'Gulfstream V',\n               'Hawk T1', 'Il-76', 'L-1011', 'MD-11', 
'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200', 'PA-28',\n               'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134', 'Tu-154', 'Yak-42']\n\n    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,\n                 **kwargs):\n\n        if split == 'train':\n            list_name = 'train' + str(sample_rate)\n            assert list_name in self.image_list\n            data_list_file = os.path.join(root, self.image_list[list_name])\n        else:\n            data_list_file = os.path.join(root, self.image_list['test'])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(Aircraft, self).__init__(root, Aircraft.CLASSES, data_list_file=data_list_file, **kwargs)\n"
  },
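A usage sketch for the sampled training splits; the root path is a placeholder:

```python
# Hedged sketch: load the 15%-sampled training split of FGVC-Aircraft.
from tllib.vision.datasets import Aircraft

dataset = Aircraft(root='/data/aircraft', split='train', sample_rate=15, download=True)
# 'train' + '15' resolves to image_list/train_15.txt via Aircraft.image_list
```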
  {
    "path": "tllib/vision/datasets/caltech101.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass Caltech101(ImageList):\n    \"\"\"`The Caltech101 Dataset <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ contains objects\n    belonging to 101 categories with about 40 to 800 images per category. Most categories have about 50 images.\n    The size of each image is roughly 300 x 200 pixels.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/d6d4b813a800403f835e/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/ed4d0de80da246f98171/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/db1c444200a848799683/?dl=1\")\n    ]\n\n    def __init__(self, root, split='train', download=True, **kwargs):\n        classes = ['accordion', 'airplanes', 'anchor', 'ant', 'background_google', 'barrel', 'bass', 'beaver',\n                   'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon',\n                   'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face',\n                   'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin',\n                   'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'faces', 'faces_easy',\n                   'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano',\n                   'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree',\n                   'kangaroo', 'ketch', 'lamp', 'laptop', 'leopards', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',\n                   'menorah', 'metronome', 'minaret', 'motorbikes', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda',\n                   'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner',\n                   'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus',\n                   'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly',\n                   'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(Caltech101, self).__init__(root, classes, os.path.join(root, 'image_list', '{}.txt'.format(split)),\n                                         **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/cifar.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nfrom torchvision.datasets.cifar import CIFAR10 as CIFAR10Base, CIFAR100 as CIFAR100Base\r\n\r\n\r\nclass CIFAR10(CIFAR10Base):\r\n    \"\"\"\r\n    `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.\r\n    \"\"\"\r\n\r\n    def __init__(self, root, split='train', transform=None, download=True):\r\n        super(CIFAR10, self).__init__(root, train=split == 'train', transform=transform, download=download)\r\n        self.num_classes = 10\r\n\r\n\r\nclass CIFAR100(CIFAR100Base):\r\n    \"\"\"\r\n    `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.\r\n    \"\"\"\r\n\r\n    def __init__(self, root, split='train', transform=None, download=True):\r\n        super(CIFAR100, self).__init__(root, train=split == 'train', transform=transform, download=download)\r\n        self.num_classes = 100\r\n"
  },
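A usage sketch showing that the thin CIFAR wrappers keep the package's `(root, split, transform, download)` interface; the root path is a placeholder:

```python
# Hedged sketch: instantiate the CIFAR10 wrapper like any other dataset here.
import torchvision.transforms as T
from tllib.vision.datasets import CIFAR10

train_set = CIFAR10('/data/cifar10', split='train', transform=T.ToTensor(), download=True)
print(len(train_set), train_set.num_classes)  # 50000 10
```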
  {
    "path": "tllib/vision/datasets/coco70.py",
    "content": "\"\"\"\r\n@author: Yifei Ji\r\n@contact: jiyf990330@163.com\r\n\"\"\"\r\nimport os\r\nfrom typing import Optional\r\nfrom .imagelist import ImageList\r\nfrom ._util import download as download_data, check_exits\r\n\r\n\r\nclass COCO70(ImageList):\r\n    \"\"\"COCO-70 dataset is a large-scale classification dataset (1000 images per class) created from\r\n    `COCO <https://cocodataset.org/>`_ Dataset.\r\n    It is used to explore the effect of fine-tuning with a large amount of data.\r\n\r\n    Args:\r\n        root (str): Root directory of dataset\r\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\r\n        sample_rate (int): The sampling rates to sample random ``training`` images for each category.\r\n            Choices include 100, 50, 30, 15. Default: 100.\r\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\r\n            in root directory. If dataset is already downloaded, it is not downloaded again.\r\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\r\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\r\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\r\n\r\n    .. note:: In `root`, there will exist following files after downloading.\r\n        ::\r\n            train/\r\n            test/\r\n            image_list/\r\n                train_100.txt\r\n                train_50.txt\r\n                train_30.txt\r\n                train_15.txt\r\n                test.txt\r\n    \"\"\"\r\n    download_list = [\r\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/b008c0d823ad488c8be1/?dl=1\"),\r\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/75a895576d5e4e59a88d/?dl=1\"),\r\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/ec6e45bc830d42f0924a/?dl=1\"),\r\n    ]\r\n    image_list = {\r\n        \"train\": \"image_list/train_100.txt\",\r\n        \"train100\": \"image_list/train_100.txt\",\r\n        \"train50\": \"image_list/train_50.txt\",\r\n        \"train30\": \"image_list/train_30.txt\",\r\n        \"train15\": \"image_list/train_15.txt\",\r\n        \"test\": \"image_list/test.txt\",\r\n        \"test100\": \"image_list/test.txt\",\r\n    }\r\n    CLASSES =['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',\r\n              'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', 'bench', 'bird', 'cat', 'dog',\r\n              'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',\r\n              'handbag', 'tie', 'suitcase', 'skis', 'kite', 'baseball_bat', 'skateboard', 'surfboard',\r\n              'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',\r\n              'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake',\r\n              'chair', 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',\r\n              'remote', 'keyboard', 'cell_phone', 'microwave', 'oven', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'teddy_bear']\r\n\r\n    def __init__(self, root: str, split: str, sample_rate: Optional[int] =100, download: Optional[bool] = False, **kwargs):\r\n\r\n        if split == 'train':\r\n            list_name = 'train' + str(sample_rate)\r\n            assert list_name in self.image_list\r\n     
       data_list_file = os.path.join(root, self.image_list[list_name])\r\n        else:\r\n            data_list_file = os.path.join(root, self.image_list['test'])\r\n\r\n        if download:\r\n            list(map(lambda args: download_data(root, *args), self.download_list))\r\n        else:\r\n            list(map(lambda args: check_exits(root, args[0]), self.download_list))\r\n\r\n        super(COCO70, self).__init__(root, COCO70.CLASSES, data_list_file=data_list_file, **kwargs)\r\n"
  },
  {
    "path": "tllib/vision/datasets/cub200.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass CUB200(ImageList):\n    \"\"\"`Caltech-UCSD Birds-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_  \\\n    is a dataset for fine-grained visual recognition with 11,788 images in 200 bird species. \\\n    It is an extended version of the CUB-200 dataset, roughly doubling the number of images.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        sample_rate (int): The sampling rates to sample random ``training`` images for each category.\n            Choices include 100, 50, 30, 15. Default: 100.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            train/\n            test/\n            image_list/\n                train_100.txt\n                train_50.txt\n                train_30.txt\n                train_15.txt\n                test.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/c2a5952eb18b466b9fb0/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/63db4c49b57b43198b95/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/72e95cccdcaf4b42b4eb/?dl=1\"),\n    ]\n    image_list = {\n        \"train\": \"image_list/train_100.txt\",\n        \"train100\": \"image_list/train_100.txt\",\n        \"train50\": \"image_list/train_50.txt\",\n        \"train30\": \"image_list/train_30.txt\",\n        \"train15\": \"image_list/train_15.txt\",\n        \"test\": \"image_list/test.txt\",\n        \"test100\": \"image_list/test.txt\",\n    }\n    CLASSES = ['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani',\n               '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet',\n               '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird',\n               '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal',\n               '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee',\n               '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant',\n               '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow',\n               '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo',\n               '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher',\n               '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher',\n               '041.Scissor_tailed_Flycatcher', 
'042.Vermilion_Flycatcher', '043.Yellow_bellied_Flycatcher',\n               '044.Frigatebird', '045.Northern_Fulmar', '046.Gadwall', '047.American_Goldfinch',\n               '048.European_Goldfinch', '049.Boat_tailed_Grackle', '050.Eared_Grebe',\n               '051.Horned_Grebe', '052.Pied_billed_Grebe', '053.Western_Grebe', '054.Blue_Grosbeak',\n               '055.Evening_Grosbeak', '056.Pine_Grosbeak', '057.Rose_breasted_Grosbeak', '058.Pigeon_Guillemot',\n               '059.California_Gull', '060.Glaucous_winged_Gull', '061.Heermann_Gull', '062.Herring_Gull',\n               '063.Ivory_Gull', '064.Ring_billed_Gull', '065.Slaty_backed_Gull', '066.Western_Gull',\n               '067.Anna_Hummingbird', '068.Ruby_throated_Hummingbird', '069.Rufous_Hummingbird', '070.Green_Violetear',\n               '071.Long_tailed_Jaeger', '072.Pomarine_Jaeger', '073.Blue_Jay', '074.Florida_Jay', '075.Green_Jay',\n               '076.Dark_eyed_Junco', '077.Tropical_Kingbird', '078.Gray_Kingbird', '079.Belted_Kingfisher',\n               '080.Green_Kingfisher', '081.Pied_Kingfisher', '082.Ringed_Kingfisher', '083.White_breasted_Kingfisher',\n               '084.Red_legged_Kittiwake', '085.Horned_Lark', '086.Pacific_Loon', '087.Mallard',\n               '088.Western_Meadowlark', '089.Hooded_Merganser', '090.Red_breasted_Merganser', '091.Mockingbird',\n               '092.Nighthawk', '093.Clark_Nutcracker', '094.White_breasted_Nuthatch', '095.Baltimore_Oriole',\n               '096.Hooded_Oriole', '097.Orchard_Oriole', '098.Scott_Oriole', '099.Ovenbird', '100.Brown_Pelican',\n               '101.White_Pelican', '102.Western_Wood_Pewee', '103.Sayornis', '104.American_Pipit',\n               '105.Whip_poor_Will', '106.Horned_Puffin', '107.Common_Raven', '108.White_necked_Raven',\n               '109.American_Redstart', '110.Geococcyx', '111.Loggerhead_Shrike', '112.Great_Grey_Shrike',\n               '113.Baird_Sparrow', '114.Black_throated_Sparrow', '115.Brewer_Sparrow', '116.Chipping_Sparrow',\n               '117.Clay_colored_Sparrow', '118.House_Sparrow', '119.Field_Sparrow', '120.Fox_Sparrow',\n               '121.Grasshopper_Sparrow', '122.Harris_Sparrow', '123.Henslow_Sparrow', '124.Le_Conte_Sparrow',\n               '125.Lincoln_Sparrow', '126.Nelson_Sharp_tailed_Sparrow', '127.Savannah_Sparrow', '128.Seaside_Sparrow',\n               '129.Song_Sparrow', '130.Tree_Sparrow', '131.Vesper_Sparrow', '132.White_crowned_Sparrow',\n               '133.White_throated_Sparrow', '134.Cape_Glossy_Starling', '135.Bank_Swallow', '136.Barn_Swallow',\n               '137.Cliff_Swallow', '138.Tree_Swallow', '139.Scarlet_Tanager', '140.Summer_Tanager', '141.Artic_Tern',\n               '142.Black_Tern', '143.Caspian_Tern', '144.Common_Tern', '145.Elegant_Tern', '146.Forsters_Tern',\n               '147.Least_Tern', '148.Green_tailed_Towhee', '149.Brown_Thrasher', '150.Sage_Thrasher',\n               '151.Black_capped_Vireo', '152.Blue_headed_Vireo', '153.Philadelphia_Vireo', '154.Red_eyed_Vireo',\n               '155.Warbling_Vireo', '156.White_eyed_Vireo', '157.Yellow_throated_Vireo', '158.Bay_breasted_Warbler',\n               '159.Black_and_white_Warbler', '160.Black_throated_Blue_Warbler', '161.Blue_winged_Warbler',\n               '162.Canada_Warbler', '163.Cape_May_Warbler', '164.Cerulean_Warbler', '165.Chestnut_sided_Warbler',\n               '166.Golden_winged_Warbler', '167.Hooded_Warbler', '168.Kentucky_Warbler', '169.Magnolia_Warbler',\n               '170.Mourning_Warbler', '171.Myrtle_Warbler', 
'172.Nashville_Warbler', '173.Orange_crowned_Warbler',\n               '174.Palm_Warbler', '175.Pine_Warbler', '176.Prairie_Warbler', '177.Prothonotary_Warbler',\n               '178.Swainson_Warbler', '179.Tennessee_Warbler', '180.Wilson_Warbler', '181.Worm_eating_Warbler',\n               '182.Yellow_Warbler', '183.Northern_Waterthrush', '184.Louisiana_Waterthrush', '185.Bohemian_Waxwing',\n               '186.Cedar_Waxwing', '187.American_Three_toed_Woodpecker', '188.Pileated_Woodpecker',\n               '189.Red_bellied_Woodpecker', '190.Red_cockaded_Woodpecker', '191.Red_headed_Woodpecker',\n               '192.Downy_Woodpecker', '193.Bewick_Wren', '194.Cactus_Wren', '195.Carolina_Wren', '196.House_Wren',\n               '197.Marsh_Wren', '198.Rock_Wren', '199.Winter_Wren', '200.Common_Yellowthroat']\n\n    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,\n                 **kwargs):\n\n        if split == 'train':\n            list_name = 'train' + str(sample_rate)\n            assert list_name in self.image_list\n            data_list_file = os.path.join(root, self.image_list[list_name])\n        else:\n            data_list_file = os.path.join(root, self.image_list['test'])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda args: check_exits(root, args[0]), self.download_list))\n\n        super(CUB200, self).__init__(root, CUB200.CLASSES, data_list_file=data_list_file, **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/digits.py",
    "content": "\"\"\"\n@author: Junguang Jiang, Baixu Chen\n@contact: JiangJunguang1123@outlook.com, cbx_99_hasta@outlook.com\n\"\"\"\nimport os\nfrom typing import Optional, Tuple, Any\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass MNIST(ImageList):\n    \"\"\"`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.\n\n    Args:\n        root (str): Root directory of dataset where ``MNIST/processed/training.pt``\n            and  ``MNIST/processed/test.pt`` exist.\n        mode (str): The channel mode for image. Choices includes ``\"L\"```, ``\"RGB\"``.\n            Default: ``\"L\"```\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and\n            puts it in root directory. If dataset is already downloaded, it is not\n            downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image\n            and returns a transformed version. E.g, ``transforms.RandomCrop``\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/16feadf7fb3641c2be9a/?dl=1\"),\n        (\"mnist_train_image\", \"mnist_image.tar.gz\", \"https://cloud.tsinghua.edu.cn/f/c93080af28e54559aeeb/?dl=1\"),\n        # (\"mnist_test_image\", \"mnist_image.tar.gz\", \"https://cloud.tsinghua.edu.cn/f/c93080af28e54559aeeb/?dl=1\")\n    ]\n    image_list = {\n        \"train\": \"image_list/mnist_train.txt\",\n        \"test\": \"image_list/mnist_test.txt\"\n    }\n    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',\n               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']\n\n    def __init__(self, root, mode=\"L\", split='train', download: Optional[bool] = True, **kwargs):\n        assert split in ['train', 'test']\n        data_list_file = os.path.join(root, self.image_list[split])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        assert mode in ['L', 'RGB']\n        self.mode = mode\n        super(MNIST, self).__init__(root, MNIST.CLASSES, data_list_file=data_list_file, **kwargs)\n\n    def __getitem__(self, index: int) -> Tuple[Any, int]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        return (tuple): (image, target) where target is index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        img = self.loader(path).convert(self.mode)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.target_transform is not None and target is not None:\n            target = self.target_transform(target)\n        return img, target\n\n    @classmethod\n    def get_classes(self):\n        return MNIST.CLASSES\n\n\nclass USPS(ImageList):\n    \"\"\"`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.\n        The data-format is : [label [index:value ]*256 \\\\n] * num_lines, where ``label`` lies in ``[1, 10]``.\n        The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``\n        and make pixel values in ``[0, 255]``.\n\n    Args:\n        root (str): Root directory of dataset to store``USPS`` data files.\n        mode (str): The channel mode for image. 
Choices include ``\"L\"``, ``\"RGB\"``.\n            Default: ``\"L\"``\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        transform (callable, optional): A function/transform that takes in a PIL image\n            and returns a transformed version. E.g, ``transforms.RandomCrop``\n        download (bool, optional): If true, downloads the dataset from the internet and\n            puts it in root directory. If dataset is already downloaded, it is not\n            downloaded again.\n\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/721ceaf3c031413cb62f/?dl=1\"),\n        (\"usps_train_image\", \"usps_image.tar.gz\", \"https://cloud.tsinghua.edu.cn/f/c5bd329a00fb4dc79608/?dl=1\"),\n        # (\"usps_test_image\", \"usps_image.tar.gz\", \"https://cloud.tsinghua.edu.cn/f/c5bd329a00fb4dc79608/?dl=1\")\n    ]\n    image_list = {\n        \"train\": \"image_list/usps_train.txt\",\n        \"test\": \"image_list/usps_test.txt\"\n    }\n    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',\n               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']\n\n    def __init__(self, root, mode=\"L\", split='train', download: Optional[bool] = True, **kwargs):\n        assert split in ['train', 'test']\n        data_list_file = os.path.join(root, self.image_list[split])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda args: check_exits(root, args[0]), self.download_list))\n\n        assert mode in ['L', 'RGB']\n        self.mode = mode\n        super(USPS, self).__init__(root, USPS.CLASSES, data_list_file=data_list_file, **kwargs)\n\n    def __getitem__(self, index: int) -> Tuple[Any, int]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        return (tuple): (image, target) where target is index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        img = self.loader(path).convert(self.mode)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.target_transform is not None and target is not None:\n            target = self.target_transform(target)\n        return img, target\n\n\nclass SVHN(ImageList):\n    \"\"\"`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.\n    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,\n    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which\n    expect the class labels to be in the range `[0, C-1]`\n\n    .. warning::\n\n        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load data from `.mat` format.\n\n    Args:\n        root (str): Root directory of dataset where directory\n            ``SVHN`` exists.\n        mode (str): The channel mode for image. Choices include ``\"L\"``, ``\"RGB\"``.\n            Default: ``\"L\"``\n        transform (callable, optional): A function/transform that takes in a PIL image\n            and returns a transformed version. E.g, ``transforms.RandomCrop``\n        download (bool, optional): If true, downloads the dataset from the internet and\n            puts it in root directory. 
If dataset is already downloaded, it is not\n            downloaded again.\n\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/12b35fb08f8049f98362/?dl=1\"),\n        (\"svhn_image\", \"svhn_image.tar.gz\", \"https://cloud.tsinghua.edu.cn/f/cc02de6cf81543378cce/?dl=1\")\n    ]\n    image_list = \"image_list/svhn_balanced.txt\"\n    # image_list = \"image_list/svhn.txt\"\n    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',\n               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']\n\n    def __init__(self, root, mode=\"L\", download: Optional[bool] = True, **kwargs):\n        data_list_file = os.path.join(root, self.image_list)\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda args: check_exits(root, args[0]), self.download_list))\n\n        assert mode in ['L', 'RGB']\n        self.mode = mode\n        super(SVHN, self).__init__(root, SVHN.CLASSES, data_list_file=data_list_file, **kwargs)\n\n    def __getitem__(self, index: int) -> Tuple[Any, int]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        return (tuple): (image, target) where target is index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        img = self.loader(path).convert(self.mode)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.target_transform is not None and target is not None:\n            target = self.target_transform(target)\n        return img, target\n\n\nclass MNISTRGB(MNIST):\n    def __init__(self, root, **kwargs):\n        super(MNISTRGB, self).__init__(root, mode='RGB', **kwargs)\n\n\nclass USPSRGB(USPS):\n    def __init__(self, root, **kwargs):\n        super(USPSRGB, self).__init__(root, mode='RGB', **kwargs)\n\n\nclass SVHNRGB(SVHN):\n    def __init__(self, root, **kwargs):\n        super(SVHNRGB, self).__init__(root, mode='RGB', **kwargs)\n"
  },
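A usage sketch showing the two ways to lift grayscale digits to 3-channel inputs, either via `mode='RGB'` or the `*RGB` convenience subclasses; the root paths are placeholders:

```python
# Hedged sketch: load MNIST as grayscale and as 3-channel images.
from tllib.vision.datasets.digits import MNIST, MNISTRGB

mnist_gray = MNIST('/data/mnist', mode='L', split='train', download=True)
mnist_rgb = MNISTRGB('/data/mnist', split='train', download=True)
img, label = mnist_rgb[0]  # PIL image converted with Image.convert('RGB')
```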
  {
    "path": "tllib/vision/datasets/domainnet.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass DomainNet(ImageList):\n    \"\"\"`DomainNet <http://ai.bu.edu/M3SDA/#dataset>`_ (cleaned version, recommended)\n\n    See `Moment Matching for Multi-Source Domain Adaptation <https://arxiv.org/abs/1812.01754>`_ for details.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'c'``:clipart, \\\n            ``'i'``: infograph, ``'p'``: painting, ``'q'``: quickdraw, ``'r'``: real, ``'s'``: sketch\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            clipart/\n            infograph/\n            painting/\n            quickdraw/\n            real/\n            sketch/\n            image_list/\n                clipart.txt\n                ...\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/bf0fe327e4b046eb89ba/?dl=1\"),\n        (\"clipart\", \"clipart.tgz\", \"https://cloud.tsinghua.edu.cn/f/f0515164a4864220b98b/?dl=1\"),\n        (\"infograph\", \"infograph.tgz\", \"https://cloud.tsinghua.edu.cn/f/98b19d5fc9884109a9cb/?dl=1\"),\n        (\"painting\", \"painting.tgz\", \"https://cloud.tsinghua.edu.cn/f/11285ce9fbd34bb7b28c/?dl=1\"),\n        (\"quickdraw\", \"quickdraw.tgz\", \"https://cloud.tsinghua.edu.cn/f/6faa9efb498b494abf66/?dl=1\"),\n        (\"real\", \"real.tgz\", \"https://cloud.tsinghua.edu.cn/f/17a101842c564959b525/?dl=1\"),\n        (\"sketch\", \"sketch.tgz\", \"https://cloud.tsinghua.edu.cn/f/b305add26e9d47349495/?dl=1\"),\n    ]\n    image_list = {\n        \"c\": \"clipart\",\n        \"i\": \"infograph\",\n        \"p\": \"painting\",\n        \"q\": \"quickdraw\",\n        \"r\": \"real\",\n        \"s\": \"sketch\",\n    }\n    CLASSES = ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil',\n               'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat',\n               'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench',\n               'bicycle', 'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang',\n               'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket',\n               'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera',\n               'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan',\n               'cello', 'cell_phone', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud',\n               'coffee_cup', 
'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile',\n               'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher', 'diving_board', 'dog', 'dolphin', 'donut',\n               'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant',\n               'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire_hydrant',\n               'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops', 'floor_lamp', 'flower',\n               'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe', 'goatee',\n               'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones',\n               'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital',\n               'hot_air_balloon', 'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream',\n               'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf',\n               'leg', 'light_bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster',\n               'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave',\n               'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom',\n               'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paintbrush',\n               'paint_can', 'palm_tree', 'panda', 'pants', 'paper_clip', 'parachute', 'parrot', 'passport', 'peanut',\n               'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup_truck', 'picture_frame', 'pig', 'pillow',\n               'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool', 'popsicle', 'postcard', 'potato',\n               'power_outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote_control',\n               'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw',\n               'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark',\n               'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag',\n               'smiley_face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat',\n               'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo',\n               'stethoscope', 'stitches', 'stop_sign', 'stove', 'strawberry', 'streetlight', 'string_bean', 'submarine',\n               'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword', 'syringe', 'table', 'teapot', 'teddy-bear',\n               'telephone', 'television', 'tennis_racquet', 'tent', 'The_Eiffel_Tower', 'The_Great_Wall_of_China',\n               'The_Mona_Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado',\n               'tractor', 'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 't-shirt',\n               'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide',\n               'whale', 'wheel', 'windmill', 'wine_bottle', 'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']\n\n    def __init__(self, root: str, task: str, split: Optional[str] = 'train', 
download: Optional[float] = False, **kwargs):\n        assert task in self.image_list\n        assert split in ['train', 'test']\n        data_list_file = os.path.join(root, \"image_list\", \"{}_{}.txt\".format(self.image_list[task], split))\n        print(\"loading {}\".format(data_list_file))\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda args: check_exits(root, args[0]), self.download_list))\n\n        super(DomainNet, self).__init__(root, DomainNet.CLASSES, data_list_file=data_list_file, **kwargs)\n\n    @classmethod\n    def domains(cls):\n        return list(cls.image_list.keys())\n"
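\n# A minimal usage sketch (hypothetical root path; assumes the package is importable as `tllib`):\n#\n#     from tllib.vision.datasets.domainnet import DomainNet\n#     dataset = DomainNet(root='data/domainnet', task='c', split='train', download=True)\n#     print(DomainNet.domains())  # ['c', 'i', 'p', 'q', 'r', 's']\n#     print(len(dataset), dataset.num_classes)\n"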
  },
  {
    "path": "tllib/vision/datasets/dtd.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass DTD(ImageList):\n    \"\"\"\n    `The Describable Textures Dataset (DTD) <https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html>`_ is an \\\n        evolving collection of textural images in the wild, annotated with a series of human-centric attributes, \\\n         inspired by the perceptual properties of textures. \\\n         The task consists in classifying images of textural patterns (47 classes, with 120 training images each). \\\n         Some of the textures are banded, bubbly, meshed, lined, or porous. \\\n         The image size ranges between 300x300 and 640x640 pixels.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/2218bfa61bac46539dd7/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/08fd47d35fc94f36a508/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/15873fe162c343cca8ed/?dl=1\"),\n        (\"validation\", \"validation.tgz\", \"https://cloud.tsinghua.edu.cn/f/75c9ab22ebea4c3b87e7/?dl=1\"),\n    ]\n    CLASSES = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked',\n               'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy',\n               'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled',\n               'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous',\n               'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped',\n               'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']\n\n    def __init__(self, root, split, download=False, **kwargs):\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(DTD, self).__init__(root, DTD.CLASSES, os.path.join(root, \"image_list\", \"{}.txt\".format(split)), **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/eurosat.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass EuroSAT(ImageList):\n    \"\"\"\n    `EuroSAT <https://github.com/phelber/eurosat>`_ dataset consists in classifying \\\n        Sentinel-2 satellite images into 10 different types of land use (Residential, \\\n        Industrial, River, Highway, etc). \\\n        The spatial resolution corresponds to 10 meters per pixel, and the image size \\\n        is 64x64 pixels.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n    \"\"\"\n    CLASSES =['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture',\n                   'PermanentCrop', 'Residential', 'River', 'SeaLake']\n\n    def __init__(self, root, split='train', download=False, **kwargs):\n        if download:\n            download_data(root, \"eurosat\", \"eurosat.tgz\", \"https://cloud.tsinghua.edu.cn/f/9983d7ab86184d74bb17/?dl=1\")\n        else:\n            check_exits(root, \"eurosat\")\n        split = 'train[:21600]' if split == 'train' else 'train[21600:]'\n\n        root = os.path.join(root, \"eurosat\")\n        super(EuroSAT, self).__init__(root, EuroSAT.CLASSES, os.path.join(root, \"imagelist\", \"{}.txt\".format(split)), **kwargs)\n\n\n\n"
  },
  {
    "path": "tllib/vision/datasets/food101.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom torchvision.datasets.folder import ImageFolder\nimport os.path as osp\nfrom ._util import download as download_data, check_exits\n\n\nclass Food101(ImageFolder):\n    \"\"\"`Food-101 <https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/>`_ is a dataset\n    for fine-grained visual recognition with 101,000 images in 101 food categories.\n\n    Args:\n        root (str): Root directory of dataset.\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            train/\n            test/\n    \"\"\"\n    download_list = [\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/1d7bd727cc1e4ce2bef5/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/7e11992d7495417db32b/?dl=1\")\n    ]\n\n    def __init__(self, root, split='train', transform=None, download=True):\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n        super(Food101, self).__init__(osp.join(root, split), transform=transform)\n        self.num_classes = 101\n"
  },
  {
    "path": "tllib/vision/datasets/imagelist.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nimport warnings\nfrom typing import Optional, Callable, Tuple, Any, List, Iterable\nimport bisect\n\nfrom torch.utils.data.dataset import Dataset, T_co, IterableDataset\nimport torchvision.datasets as datasets\nfrom torchvision.datasets.folder import default_loader\n\n\nclass ImageList(datasets.VisionDataset):\n    \"\"\"A generic Dataset class for image classification\n\n    Args:\n        root (str): Root directory of dataset\n        classes (list[str]): The names of all the classes\n        data_list_file (str): File to read the image list from.\n        transform (callable, optional): A function/transform that  takes in an PIL image \\\n            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `data_list_file`, each line has 2 values in the following format.\n        ::\n            source_dir/dog_xxx.png 0\n            source_dir/cat_123.png 1\n            target_dir/dog_xxy.png 0\n            target_dir/cat_nsdf3.png 1\n\n        The first value is the relative path of an image, and the second value is the label of the corresponding image.\n        If your data_list_file has different formats, please over-ride :meth:`~ImageList.parse_data_file`.\n    \"\"\"\n\n    def __init__(self, root: str, classes: List[str], data_list_file: str,\n                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):\n        super().__init__(root, transform=transform, target_transform=target_transform)\n        self.samples = self.parse_data_file(data_list_file)\n        self.targets = [s[1] for s in self.samples]\n        self.classes = classes\n        self.class_to_idx = {cls: idx\n                             for idx, cls in enumerate(self.classes)}\n        self.loader = default_loader\n        self.data_list_file = data_list_file\n\n    def __getitem__(self, index: int) -> Tuple[Any, int]:\n        \"\"\"\n        Args:\n            index (int): Index\n            return (tuple): (image, target) where target is index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.target_transform is not None and target is not None:\n            target = self.target_transform(target)\n        return img, target\n\n    def __len__(self) -> int:\n        return len(self.samples)\n\n    def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:\n        \"\"\"Parse file to data list\n\n        Args:\n            file_name (str): The path of data file\n            return (list): List of (image path, class_index) tuples\n        \"\"\"\n        with open(file_name, \"r\") as f:\n            data_list = []\n            for line in f.readlines():\n                split_line = line.split()\n                target = split_line[-1]\n                path = ' '.join(split_line[:-1])\n                if not os.path.isabs(path):\n                    path = os.path.join(self.root, path)\n                target = int(target)\n                data_list.append((path, target))\n        return data_list\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"Number of classes\"\"\"\n        return len(self.classes)\n\n    @classmethod\n    
def domains(cls):\n        \"\"\"All possible domains in this dataset\"\"\"\n        raise NotImplementedError\n\n\nclass MultipleDomainsDataset(Dataset[T_co]):\n    r\"\"\"Dataset as a concatenation of multiple datasets.\n\n    This class is useful to assemble different existing datasets.\n\n    Args:\n        datasets (sequence): List of datasets to be concatenated\n    \"\"\"\n    datasets: List[Dataset[T_co]]\n    cumulative_sizes: List[int]\n\n    @staticmethod\n    def cumsum(sequence):\n        r, s = [], 0\n        for e in sequence:\n            l = len(e)\n            r.append(l + s)\n            s += l\n        return r\n\n    def __init__(self, domains: Iterable[Dataset], domain_names: Iterable[str], domain_ids) -> None:\n        super(MultipleDomainsDataset, self).__init__()\n        # Cannot verify that datasets is Sized\n        assert len(domains) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]\n        self.datasets = self.domains = list(domains)\n        for d in self.domains:\n            assert not isinstance(d, IterableDataset), \"MultipleDomainsDataset does not support IterableDataset\"\n        self.cumulative_sizes = self.cumsum(self.domains)\n        self.domain_names = domain_names\n        self.domain_ids = domain_ids\n\n    def __len__(self):\n        return self.cumulative_sizes[-1]\n\n    def __getitem__(self, idx):\n        if idx < 0:\n            if -idx > len(self):\n                raise ValueError(\"absolute value of index should not exceed dataset length\")\n            idx = len(self) + idx\n        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)\n        if dataset_idx == 0:\n            sample_idx = idx\n        else:\n            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]\n        return self.domains[dataset_idx][sample_idx] + (self.domain_ids[dataset_idx],)\n\n    @property\n    def cummulative_sizes(self):\n        warnings.warn(\"cummulative_sizes attribute is renamed to \"\n                      \"cumulative_sizes\", DeprecationWarning, stacklevel=2)\n        return self.cumulative_sizes
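\n\n# If your data_list_file has a different format, the ImageList docstring suggests over-riding\n# ``parse_data_file``. A minimal sketch for a hypothetical comma-separated format\n# (\"relative/path.jpg,label\" per line); the class name is illustrative, not part of the library:\n#\n#     class CSVImageList(ImageList):\n#         def parse_data_file(self, file_name):\n#             data_list = []\n#             with open(file_name, \"r\") as f:\n#                 for line in f:\n#                     path, target = line.strip().rsplit(',', 1)\n#                     if not os.path.isabs(path):\n#                         path = os.path.join(self.root, path)\n#                     data_list.append((path, int(target)))\n#             return data_list\n"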
  },
  {
    "path": "tllib/vision/datasets/imagenet_r.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nfrom typing import Optional\r\nimport os\r\nfrom .imagelist import ImageList\r\nfrom ._util import download as download_data, check_exits\r\n\r\n\r\nclass ImageNetR(ImageList):\r\n    \"\"\"ImageNet-R Dataset.\r\n\r\n    Args:\r\n        root (str): Root directory of dataset\r\n        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \\\r\n            ``'D'``: dslr and ``'W'``: webcam.\r\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\r\n            in root directory. If dataset is already downloaded, it is not downloaded again.\r\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\r\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\r\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\r\n\r\n    .. note:: You need to put ``train`` directory of ImageNet-1K and ``imagenet_r`` directory of ImageNet-R\r\n        manually in `root` directory.\r\n\r\n        DALIB will only download ImageList automatically.\r\n        In `root`, there will exist following files after preparing.\r\n        ::\r\n            train/\r\n                n02128385/\r\n                ...\r\n            val/\r\n            imagenet-r/\r\n                n02128385/\r\n            image_list/\r\n                imagenet-train.txt\r\n                imagenet-r.txt\r\n                art.txt\r\n                ...\r\n    \"\"\"\r\n    download_list = [\r\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/7786eabd3565409c8c33/?dl=1\"),\r\n    ]\r\n    image_list = {\r\n        \"IN\": \"image_list/imagenet-train.txt\",\r\n        \"IN-val\": \"image_list/imagenet-val.txt\",\r\n        \"INR\": \"image_list/imagenet-r.txt\",\r\n        \"art\": \"art.txt\",\r\n        \"embroidery\": \"embroidery.txt\",\r\n        \"misc\": \"misc.txt\",\r\n        \"sculpture\": \"sculpture.txt\",\r\n        \"tattoo\": \"tattoo.txt\",\r\n        \"cartoon\": \"cartoon.txt\",\r\n        \"graffiti\": \"graffiti.txt\",\r\n        \"origami\": \"origami.txt\",\r\n        \"sketch\": \"sketch.txt\",\r\n        \"toy\": \"toy.txt\",\r\n        \"deviantart\": \"deviantart.txt\",\r\n        \"graphic\": \"graphic.txt\",\r\n        \"painting\": \"painting.txt\",\r\n        \"sticker\": \"sticker.txt\",\r\n        \"videogame\": \"videogame.txt\"\r\n    }\r\n    CLASSES = ['n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214'\r\n, 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 
'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677']\r\n\r\n    def __init__(self, root: str, task: str, split: Optional[str] = 'all', download: Optional[bool] = True, **kwargs):\r\n        assert task in self.image_list\r\n        assert split in [\"train\", \"val\", \"all\"]\r\n        if task == \"IN\" and split == \"val\":\r\n            task = \"IN-val\"\r\n\r\n        data_list_file = os.path.join(root, self.image_list[task])\r\n\r\n        if download:\r\n            list(map(lambda args: download_data(root, *args), self.download_list))\r\n        else:\r\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\r\n\r\n        super(ImageNetR, self).__init__(root, ImageNetR.CLASSES, data_list_file=data_list_file, **kwargs)\r\n\r\n    @classmethod\r\n    def domains(cls):\r\n        return list(cls.image_list.keys())"
  },
  {
    "path": "tllib/vision/datasets/imagenet_sketch.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nfrom typing import Optional\r\nimport os\r\nfrom torchvision.datasets.imagenet import ImageNet\r\nfrom .imagelist import ImageList\r\nfrom ._util import download as download_data, check_exits\r\n\r\n\r\nclass ImageNetSketch(ImageList):\r\n    \"\"\"ImageNet-Sketch Dataset.\r\n\r\n    Args:\r\n        root (str): Root directory of dataset\r\n        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \\\r\n            ``'D'``: dslr and ``'W'``: webcam.\r\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\r\n            in root directory. If dataset is already downloaded, it is not downloaded again.\r\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\r\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\r\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\r\n\r\n    .. note:: You need to put ``train`` directory, ``metabin`` of ImageNet-1K and ``sketch`` directory of ImageNet-Sketch\r\n        manually in `root` directory.\r\n\r\n        DALIB will only download ImageList automatically.\r\n        In `root`, there will exist following files after preparing.\r\n        ::\r\n            metabin (from ImageNet)\r\n            train/\r\n                n02128385/\r\n                ...\r\n            val/\r\n            sketch/\r\n                n02128385/\r\n            image_list/\r\n                imagenet-train.txt\r\n                sketch.txt\r\n                ...\r\n    \"\"\"\r\n    download_list = [\r\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/7786eabd3565409c8c33/?dl=1\"),\r\n    ]\r\n    image_list = {\r\n        \"IN\": \"image_list/imagenet-train.txt\",\r\n        \"IN-val\": \"image_list/imagenet-val.txt\",\r\n        \"sketch\": \"image_list/sketch.txt\",\r\n    }\r\n\r\n    def __init__(self, root: str, task: str, split: Optional[str] = 'all', download: Optional[bool] = True, **kwargs):\r\n        assert task in self.image_list\r\n        assert split in [\"train\", \"val\", \"all\"]\r\n        if task == \"IN\" and split == \"val\":\r\n            task = \"IN-val\"\r\n\r\n        data_list_file = os.path.join(root, self.image_list[task])\r\n\r\n        if download:\r\n            list(map(lambda args: download_data(root, *args), self.download_list))\r\n        else:\r\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\r\n\r\n        super(ImageNetSketch, self).__init__(root, ImageNet(root).classes, data_list_file=data_list_file, **kwargs)\r\n\r\n    @classmethod\r\n    def domains(cls):\r\n        return list(cls.image_list.keys())"
  },
  {
    "path": "tllib/vision/datasets/keypoint_detection/__init__.py",
    "content": "from .rendered_hand_pose import RenderedHandPose\nfrom .hand_3d_studio import Hand3DStudio, Hand3DStudioAll\nfrom .freihand import FreiHand\n\nfrom .surreal import SURREAL\nfrom .lsp import LSP\nfrom .human36m import Human36M\n\n"
  },
  {
    "path": "tllib/vision/datasets/keypoint_detection/freihand.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport json\nimport time\nimport torch\nimport os\nimport os.path as osp\nfrom torchvision.datasets.utils import download_and_extract_archive\n\nfrom ...transforms.keypoint_detection import *\nfrom .keypoint_dataset import Hand21KeypointDataset\nfrom .util import *\n\n\n\"\"\" General util functions. \"\"\"\ndef _assert_exist(p):\n    msg = 'File does not exists: %s' % p\n    assert os.path.exists(p), msg\n\n\ndef json_load(p):\n    _assert_exist(p)\n    with open(p, 'r') as fi:\n        d = json.load(fi)\n    return d\n\n\ndef load_db_annotation(base_path, set_name=None):\n    if set_name is None:\n        # only training set annotations are released so this is a valid default choice\n        set_name = 'training'\n\n    print('Loading FreiHAND dataset index ...')\n    t = time.time()\n\n    # assumed paths to data containers\n    k_path = os.path.join(base_path, '%s_K.json' % set_name)\n    mano_path = os.path.join(base_path, '%s_mano.json' % set_name)\n    xyz_path = os.path.join(base_path, '%s_xyz.json' % set_name)\n\n    # load if exist\n    K_list = json_load(k_path)\n    mano_list = json_load(mano_path)\n    xyz_list = json_load(xyz_path)\n\n    # should have all the same length\n    assert len(K_list) == len(mano_list), 'Size mismatch.'\n    assert len(K_list) == len(xyz_list), 'Size mismatch.'\n\n    print('Loading of %d samples done in %.2f seconds' % (len(K_list), time.time()-t))\n    return list(zip(K_list, mano_list, xyz_list))\n\n\ndef projectPoints(xyz, K):\n    \"\"\" Project 3D coordinates into image space. \"\"\"\n    xyz = np.array(xyz)\n    K = np.array(K)\n    uv = np.matmul(K, xyz.T).T\n    return uv[:, :2] / uv[:, -1:]\n\n\n\"\"\" Dataset related functions. \"\"\"\ndef db_size(set_name):\n    \"\"\" Hardcoded size of the datasets. \"\"\"\n    if set_name == 'training':\n        return 32560  # number of unique samples (they exists in multiple 'versions')\n    elif set_name == 'evaluation':\n        return 3960\n    else:\n        assert 0, 'Invalid choice.'\n\n\nclass sample_version:\n    gs = 'gs'  # green screen\n    hom = 'hom'  # homogenized\n    sample = 'sample'  # auto colorization with sample points\n    auto = 'auto'  # auto colorization without sample points: automatic color hallucination\n\n    db_size = db_size('training')\n\n    @classmethod\n    def valid_options(cls):\n        return [cls.gs, cls.hom, cls.sample, cls.auto]\n\n\n    @classmethod\n    def check_valid(cls, version):\n        msg = 'Invalid choice: \"%s\" (must be in %s)' % (version, cls.valid_options())\n        assert version in cls.valid_options(), msg\n\n    @classmethod\n    def map_id(cls, id, version):\n        cls.check_valid(version)\n        return id + cls.db_size*cls.valid_options().index(version)\n\n\nclass FreiHand(Hand21KeypointDataset):\n    \"\"\"`FreiHand Dataset <https://lmb.informatik.uni-freiburg.de/projects/freihand/>`_\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.\n        task (str, optional): The post-processing option to create dataset. Choices include ``'gs'``: green screen \\\n            recording, ``'auto'``: auto colorization without sample points: automatic color hallucination, \\\n            ``'sample'``: auto colorization with sample points, ``'hom'``: homogenized, \\\n            and ``'all'``: all hands. 
Default: 'all'.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and\n            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.\n        image_size (tuple): (width, height) of the image. Default: (256, 256)\n        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            *.json\n            training/\n            evaluation/\n    \"\"\"\n    def __init__(self, root, split='train', task='all', download=True, **kwargs):\n        if download:\n            if not osp.exists(osp.join(root, \"training\")) or not osp.exists(osp.join(root, \"evaluation\")):\n                download_and_extract_archive(\"https://lmb.informatik.uni-freiburg.de/data/freihand/FreiHAND_pub_v2.zip\",\n                                             download_root=root, filename=\"FreiHAND_pub_v2.zip\", remove_finished=False,\n                                             extract_root=root)\n\n        assert split in ['train', 'test', 'all']\n        self.split = split\n\n        assert task in ['all', 'gs', 'auto', 'sample', 'hom']\n        self.task = task\n        if task == 'all':\n            samples = self.get_samples(root, 'gs') + self.get_samples(root, 'auto') + self.get_samples(root, 'sample') + self.get_samples(root, 'hom')\n        else:\n            samples = self.get_samples(root, task)\n        random.seed(42)\n        random.shuffle(samples)\n        samples_len = len(samples)\n        samples_split = min(int(samples_len * 0.2), 3200)\n        if self.split == 'train':\n            samples = samples[samples_split:]\n        elif self.split == 'test':\n            samples = samples[:samples_split]\n\n        super(FreiHand, self).__init__(root, samples, **kwargs)\n\n    def __getitem__(self, index):\n        sample = self.samples[index]\n        image_name = sample['name']\n        image_path = os.path.join(self.root, image_name)\n        image = Image.open(image_path)\n        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3\n        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2\n        intrinsic_matrix = np.array(sample['intrinsic_matrix'])\n        Zc = keypoint3d_camera[:, 2]\n\n        # Crop the images such that the hand is at the center of the image\n        # The images will be 1.5 times larger than the hand\n        # The crop process will change Xc and Yc, leaving Zc with no changes\n        bounding_box = get_bounding_box(keypoint2d)\n        w, h = image.size\n        left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)\n        image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d)\n\n        # Change all hands to right hands\n        if sample['left'] is False:\n            image, keypoint2d = hflip(image, keypoint2d)\n\n        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\n        keypoint2d = data['keypoint2d']\n        intrinsic_matrix = data['intrinsic_matrix']\n        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)\n\n        # 
noramlize 2D pose:\n        visible = np.ones((self.num_keypoints, ), dtype=np.float32)\n        visible = visible[:, np.newaxis]\n        # 2D heatmap\n        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)\n        target = torch.from_numpy(target)\n        target_weight = torch.from_numpy(target_weight)\n\n        # normalize 3D pose:\n        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system\n        # and make distance between wrist and middle finger MCP joint to be of length 1\n        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]\n        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))\n        z = keypoint3d_n[:, 2]\n\n        meta = {\n            'image': image_name,\n            'keypoint2d': keypoint2d,  # （NUM_KEYPOINTS x 2）\n            'keypoint3d': keypoint3d_n,  # （NUM_KEYPOINTS x 3）\n            'z': z,\n        }\n\n        return image, target, target_weight, meta\n\n    def get_samples(self, root, version='gs'):\n        set = 'training'\n        # load annotations of this set\n        db_data_anno = load_db_annotation(root, set)\n\n        version_map = {\n            'gs': sample_version.gs,\n            'hom': sample_version.hom,\n            'sample': sample_version.sample,\n            'auto': sample_version.auto\n        }\n        samples = []\n        for idx in range(db_size(set)):\n            image_name = os.path.join(set, 'rgb',\n                                '%08d.jpg' % sample_version.map_id(idx, version_map[version]))\n            mask_name = os.path.join(set, 'mask', '%08d.jpg' % idx)\n            intrinsic_matrix, mano, keypoint3d = db_data_anno[idx]\n            keypoint2d = projectPoints(keypoint3d, intrinsic_matrix)\n\n            sample = {\n                'name': image_name,\n                'mask_name': mask_name,\n                'keypoint2d': keypoint2d,\n                'keypoint3d': keypoint3d,\n                'intrinsic_matrix': intrinsic_matrix,\n                'left': False\n            }\n            samples.append(sample)\n\n        return samples\n"
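\n# How the four post-processing versions share one index space: each version contains\n# db_size('training') == 32560 images, offset by its position in\n# sample_version.valid_options() == ['gs', 'hom', 'sample', 'auto'].\n# E.g. sample_version.map_id(5, 'hom') == 5 + 32560 * 1 == 32565,\n# i.e. the file 'training/rgb/00032565.jpg'.\n"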
  },
  {
    "path": "tllib/vision/datasets/keypoint_detection/hand_3d_studio.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nimport json\nimport random\nfrom PIL import ImageFile, Image\nimport torch\nimport os.path as osp\n\nfrom .._util import download as download_data, check_exits\nfrom .keypoint_dataset import Hand21KeypointDataset\nfrom .util import *\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\nclass Hand3DStudio(Hand21KeypointDataset):\n    \"\"\"`Hand-3d-Studio Dataset <https://www.yangangwang.com/papers/ZHAO-H3S-2020-02.html>`_\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.\n        task (str, optional): The task to create dataset. Choices include ``'noobject'``: only hands without objects, \\\n            ``'object'``: only hands interacting with hands, and ``'all'``: all hands. Default: 'noobject'.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and\n            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.\n        image_size (tuple): (width, height) of the image. Default: (256, 256)\n        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\n\n    .. note::\n        We found that the original H3D image is in high resolution while most part in an image is background,\n        thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands) to speed up training.\n\n    .. 
note:: In `root`, there will exist following files after downloading.\n        ::\n            H3D_crop/\n                annotation.json\n                part1/\n                part2/\n                part3/\n                part4/\n                part5/\n    \"\"\"\n    def __init__(self, root, split='train', task='noobject', download=True, **kwargs):\n        assert split in ['train', 'test', 'all']\n        self.split = split\n        assert task in ['noobject', 'object', 'all']\n        self.task = task\n\n        if download:\n            download_data(root, \"H3D_crop\", \"H3D_crop.tar\", \"https://cloud.tsinghua.edu.cn/f/d4e612e44dc04d8eb01f/?dl=1\")\n        else:\n            check_exits(root, \"H3D_crop\")\n\n        root = osp.join(root, \"H3D_crop\")\n        # load labels\n        annotation_file = os.path.join(root, 'annotation.json')\n        print(\"loading from {}\".format(annotation_file))\n        with open(annotation_file) as f:\n            samples = list(json.load(f))\n        if task == 'noobject':\n            samples = [sample for sample in samples if int(sample['without_object']) == 1]\n        elif task == 'object':\n            samples = [sample for sample in samples if int(sample['without_object']) == 0]\n\n        random.seed(42)\n        random.shuffle(samples)\n        samples_len = len(samples)\n        samples_split = min(int(samples_len * 0.2), 3200)\n        if split == 'train':\n            samples = samples[samples_split:]\n        elif split == 'test':\n            samples = samples[:samples_split]\n\n        super(Hand3DStudio, self).__init__(root, samples, **kwargs)\n\n    def __getitem__(self, index):\n        sample = self.samples[index]\n        image_name = sample['name']\n        image_path = os.path.join(self.root, image_name)\n        image = Image.open(image_path)\n        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3\n        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2\n        intrinsic_matrix = np.array(sample['intrinsic_matrix'])\n        Zc = keypoint3d_camera[:, 2]\n\n        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\n        keypoint2d = data['keypoint2d']\n        intrinsic_matrix = data['intrinsic_matrix']\n        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)\n\n        # noramlize 2D pose:\n        visible = np.ones((self.num_keypoints, ), dtype=np.float32)\n        visible = visible[:, np.newaxis]\n        # 2D heatmap\n        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)\n        target = torch.from_numpy(target)\n        target_weight = torch.from_numpy(target_weight)\n\n        # normalize 3D pose:\n        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system\n        # and make distance between wrist and middle finger MCP joint to be of length 1\n        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]\n        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))\n\n        meta = {\n            'image': image_name,\n            'keypoint2d': keypoint2d,  # （NUM_KEYPOINTS x 2）\n            'keypoint3d': keypoint3d_n,  # （NUM_KEYPOINTS x 3）\n        }\n        return image, target, target_weight, meta\n\n\nclass Hand3DStudioAll(Hand3DStudio):\n    \"\"\"\n    `Hand-3d-Studio Dataset <https://www.yangangwang.com/papers/ZHAO-H3S-2020-02.html>`_\n\n    \"\"\"\n    def 
__init__(self,  root, task='all', **kwargs):\n        super(Hand3DStudioAll, self).__init__(root, task=task, **kwargs)"
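\n\n# Note on the deterministic train/test split used above (and by FreiHand): samples are\n# shuffled with a fixed seed (42) and the test split takes min(20% of the samples, 3200),\n# so every run assigns the same images to the same split.\n"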
  },
  {
    "path": "tllib/vision/datasets/keypoint_detection/human36m.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport os\r\nimport json\r\nimport tqdm\r\nfrom PIL import ImageFile\r\nimport torch\r\nfrom .keypoint_dataset import Body16KeypointDataset\r\nfrom ...transforms.keypoint_detection import *\r\nfrom .util import *\r\n\r\nImageFile.LOAD_TRUNCATED_IMAGES = True\r\n\r\n\r\nclass Human36M(Body16KeypointDataset):\r\n    \"\"\"`Human3.6M Dataset <http://vision.imar.ro/human3.6m/description.php>`_\r\n\r\n    Args:\r\n        root (str): Root directory of dataset\r\n        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.\r\n            Default: ``train``.\r\n        task (str, optional): Placeholder.\r\n        download (bool, optional): Placeholder.\r\n        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and\r\n            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.\r\n        image_size (tuple): (width, height) of the image. Default: (256, 256)\r\n        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)\r\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\r\n\r\n    .. note:: You need to download Human36M manually.\r\n        Ensure that there exist following files in the `root` directory before you using this class.\r\n        ::\r\n            annotations/\r\n                Human36M_subject11_joint_3d.json\r\n                ...\r\n            images/\r\n\r\n    .. note::\r\n        We found that the original Human3.6M image is in high resolution while most part in an image is background,\r\n        thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands) to speed up training.\r\n        In `root`, there will exist following files after crop.\r\n        ::\r\n            Human36M_crop/\r\n            annotations/\r\n                keypoints2d_11.json\r\n                ...\r\n    \"\"\"\r\n    def __init__(self, root, split='train', task='all', download=True, **kwargs):\r\n        assert split in ['train', 'test', 'all']\r\n        self.split = split\r\n\r\n        samples = []\r\n        if self.split == 'train':\r\n            parts = [1, 5, 6, 7, 8]\r\n        elif self.split == 'test':\r\n            parts = [9, 11]\r\n        else:\r\n            parts = [1, 5, 6, 7, 8, 9, 11]\r\n\r\n        for part in parts:\r\n            annotation_file = os.path.join(root, 'annotations/keypoints2d_{}.json'.format(part))\r\n            if not os.path.exists(annotation_file):\r\n                self.preprocess(part, root)\r\n            print(\"loading\", annotation_file)\r\n            with open(annotation_file) as f:\r\n                samples.extend(json.load(f))\r\n        # decrease the number of test samples to decrease the time spent on test\r\n        random.seed(42)\r\n        if self.split == 'test':\r\n            samples = random.choices(samples, k=3200)\r\n        super(Human36M, self).__init__(root, samples, **kwargs)\r\n\r\n    def __getitem__(self, index):\r\n        sample = self.samples[index]\r\n        image_name = sample['name']\r\n        image_path = os.path.join(self.root, \"crop_images\", image_name)\r\n        image = Image.open(image_path)\r\n        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3\r\n        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2\r\n        intrinsic_matrix = 
np.array(sample['intrinsic_matrix'])\r\n        Zc = keypoint3d_camera[:, 2]\r\n\r\n        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\r\n        keypoint2d = data['keypoint2d']\r\n        intrinsic_matrix = data['intrinsic_matrix']\r\n        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)\r\n\r\n        # noramlize 2D pose:\r\n        visible = np.ones((self.num_keypoints, ), dtype=np.float32)\r\n        visible = visible[:, np.newaxis]\r\n        # 2D heatmap\r\n        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)\r\n        target = torch.from_numpy(target)\r\n        target_weight = torch.from_numpy(target_weight)\r\n\r\n        # normalize 3D pose:\r\n        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system\r\n        # and make distance between wrist and middle finger MCP joint to be of length 1\r\n        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]\r\n        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))\r\n\r\n        meta = {\r\n            'image': image_name,\r\n            'keypoint2d': keypoint2d,  # （NUM_KEYPOINTS x 2）\r\n            'keypoint3d': keypoint3d_n,  # （NUM_KEYPOINTS x 3）\r\n        }\r\n        return image, target, target_weight, meta\r\n\r\n    def preprocess(self, part, root):\r\n        body_index = [3, 2, 1, 4, 5, 6, 0, 11, 8, 10, 16, 15, 14, 11, 12, 13]\r\n        image_size = 512\r\n        print(\"preprocessing part\", part)\r\n        camera_json = os.path.join(root, \"annotations\", \"Human36M_subject{}_camera.json\".format(part))\r\n        data_json = os.path.join(root, \"annotations\", \"Human36M_subject{}_data.json\".format(part))\r\n        joint_3d_json = os.path.join(root, \"annotations\", \"Human36M_subject{}_joint_3d.json\".format(part))\r\n        with open(camera_json, \"r\") as f:\r\n            cameras = json.load(f)\r\n        with open(data_json, \"r\") as f:\r\n            data = json.load(f)\r\n            images = data['images']\r\n\r\n        with open(joint_3d_json, \"r\") as f:\r\n            joints_3d = json.load(f)\r\n\r\n        data = []\r\n\r\n        for i, image_data in enumerate(tqdm.tqdm(images)):\r\n            # downsample\r\n            if i % 5 == 0:\r\n                keypoint3d = np.array(joints_3d[str(image_data[\"action_idx\"])][str(image_data[\"subaction_idx\"])][\r\n                                          str(image_data[\"frame_idx\"])])\r\n                keypoint3d = keypoint3d[body_index, :]\r\n                keypoint3d[7, :] = 0.5 * (keypoint3d[12, :] + keypoint3d[13, :])\r\n                camera = cameras[str(image_data[\"cam_idx\"])]\r\n                R, T = np.array(camera[\"R\"]), np.array(camera['t'])[:, np.newaxis]\r\n                extrinsic_matrix = np.concatenate([R, T], axis=1)\r\n                keypoint3d_camera = np.matmul(extrinsic_matrix, np.hstack(\r\n                    (keypoint3d, np.ones((keypoint3d.shape[0], 1)))).T)  # (3 x NUM_KEYPOINTS)\r\n                Z_c = keypoint3d_camera[2:3, :]  # 1 x NUM_KEYPOINTS\r\n\r\n                f, c = np.array(camera[\"f\"]), np.array(camera['c'])\r\n                intrinsic_matrix = np.zeros((3, 3))\r\n                intrinsic_matrix[0, 0] = f[0]\r\n                intrinsic_matrix[1, 1] = f[1]\r\n                intrinsic_matrix[0, 2] = c[0]\r\n                intrinsic_matrix[1, 2] = c[1]\r\n                
intrinsic_matrix[2, 2] = 1\r\n                keypoint2d = np.matmul(intrinsic_matrix, keypoint3d_camera)  # (3 x NUM_KEYPOINTS)\r\n                keypoint2d = keypoint2d[0: 2, :] / Z_c\r\n                keypoint2d = keypoint2d.T\r\n                src_image_path = os.path.join(root, \"images\", image_data['file_name'])\r\n                tgt_image_path = os.path.join(root, \"crop_images\", image_data['file_name'])\r\n                os.makedirs(os.path.dirname(tgt_image_path), exist_ok=True)\r\n                image = Image.open(src_image_path)\r\n\r\n                bounding_box = get_bounding_box(keypoint2d)\r\n                w, h = image.size\r\n                left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)\r\n                image, keypoint2d = crop(image, upper, left, lower-upper+1, right-left+1, keypoint2d)\r\n                Z_c = Z_c.T\r\n\r\n                # Calculate XYZ from uvz\r\n                uv1 = np.concatenate([np.copy(keypoint2d), np.ones((16, 1))],\r\n                                     axis=1)  # NUM_KEYPOINTS x 3\r\n                uv1 = uv1 * Z_c  # NUM_KEYPOINTS x 3\r\n                keypoint3d_camera = np.matmul(np.linalg.inv(intrinsic_matrix), uv1.T).T\r\n\r\n                # resize image will change camera intrinsic matrix\r\n                w, h = image.size\r\n                image = image.resize((image_size, image_size))\r\n                image.save(tgt_image_path)\r\n\r\n                zoom_factor = float(w) / float(image_size)\r\n                keypoint2d /= zoom_factor\r\n                intrinsic_matrix[0, 0] /= zoom_factor\r\n                intrinsic_matrix[1, 1] /= zoom_factor\r\n                intrinsic_matrix[0, 2] /= zoom_factor\r\n                intrinsic_matrix[1, 2] /= zoom_factor\r\n\r\n                data.append({\r\n                    \"name\": image_data['file_name'],\r\n                    'keypoint2d': keypoint2d.tolist(),\r\n                    'keypoint3d': keypoint3d_camera.tolist(),\r\n                    'intrinsic_matrix': intrinsic_matrix.tolist(),\r\n                })\r\n\r\n        with open(os.path.join(root, \"annotations\", \"keypoints2d_{}.json\".format(part)), \"w\") as f:\r\n            json.dump(data, f)"
  },
  {
    "path": "tllib/vision/datasets/keypoint_detection/keypoint_dataset.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nfrom abc import ABC\r\nimport numpy as np\r\nfrom torch.utils.data.dataset import Dataset\r\nfrom webcolors import name_to_rgb\r\nimport cv2\r\n\r\n\r\nclass KeypointDataset(Dataset, ABC):\r\n    \"\"\"A generic dataset class for image keypoint detection\r\n\r\n    Args:\r\n        root (str): Root directory of dataset\r\n        num_keypoints (int): Number of keypoints\r\n        samples (list): list of data\r\n        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and\r\n            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.\r\n        image_size (tuple): (width, height) of the image. Default: (256, 256)\r\n        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)\r\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\r\n        keypoints_group (dict): a dict that stores the index of different types of keypoints\r\n        colored_skeleton (dict): a dict that stores the index and color of different skeleton\r\n    \"\"\"\r\n    def __init__(self, root, num_keypoints, samples, transforms=None, image_size=(256, 256), heatmap_size=(64, 64),\r\n                 sigma=2, keypoints_group=None, colored_skeleton=None):\r\n        self.root = root\r\n        self.num_keypoints = num_keypoints\r\n        self.samples = samples\r\n        self.transforms = transforms\r\n        self.image_size = image_size\r\n        self.heatmap_size = heatmap_size\r\n        self.sigma = sigma\r\n        self.keypoints_group = keypoints_group\r\n        self.colored_skeleton = colored_skeleton\r\n\r\n    def __len__(self):\r\n        return len(self.samples)\r\n\r\n    def visualize(self, image, keypoints, filename):\r\n        \"\"\"Visualize an image with its keypoints, and store the result into a file\r\n\r\n        Args:\r\n            image (PIL.Image):\r\n            keypoints (torch.Tensor): keypoints in shape K x 2\r\n            filename (str): the name of file to store\r\n        \"\"\"\r\n        assert self.colored_skeleton is not None\r\n\r\n        image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR).copy()\r\n        for (_, (line, color)) in self.colored_skeleton.items():\r\n            for i in range(len(line) - 1):\r\n                start, end = keypoints[line[i]], keypoints[line[i + 1]]\r\n                cv2.line(image, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), color=name_to_rgb(color),\r\n                         thickness=3)\r\n        for keypoint in keypoints:\r\n            cv2.circle(image, (int(keypoint[0]), int(keypoint[1])), 3, name_to_rgb('black'), 1)\r\n        cv2.imwrite(filename, image)\r\n\r\n    def group_accuracy(self, accuracies):\r\n        \"\"\" Group the accuracy of K keypoints into different kinds.\r\n\r\n        Args:\r\n            accuracies (list): accuracy of the K keypoints\r\n\r\n        Returns:\r\n            accuracy of ``N=len(keypoints_group)`` kinds of keypoints\r\n\r\n        \"\"\"\r\n        grouped_accuracies = dict()\r\n        for name, keypoints in self.keypoints_group.items():\r\n            grouped_accuracies[name] = sum([accuracies[idx] for idx in keypoints]) / len(keypoints)\r\n        return grouped_accuracies\r\n\r\n\r\nclass Body16KeypointDataset(KeypointDataset, ABC):\r\n    \"\"\"\r\n    Dataset with 16 body keypoints.\r\n    \"\"\"\r\n    # TODO: 
add image\r\n    head = (9,)\r\n    shoulder = (12, 13)\r\n    elbow = (11, 14)\r\n    wrist = (10, 15)\r\n    hip = (2, 3)\r\n    knee = (1, 4)\r\n    ankle = (0, 5)\r\n    all = (12, 13, 11, 14, 10, 15, 2, 3, 1, 4, 0, 5)\r\n    right_leg = (0, 1, 2, 8)\r\n    left_leg = (5, 4, 3, 8)\r\n    backbone = (8, 9)\r\n    right_arm = (10, 11, 12, 8)\r\n    left_arm = (15, 14, 13, 8)\r\n\r\n    def __init__(self, root, samples, **kwargs):\r\n        colored_skeleton = {\r\n            \"right_leg\": (self.right_leg, 'yellow'),\r\n            \"left_leg\": (self.left_leg, 'green'),\r\n            \"backbone\": (self.backbone, 'blue'),\r\n            \"right_arm\": (self.right_arm, 'purple'),\r\n            \"left_arm\": (self.left_arm, 'red'),\r\n        }\r\n        keypoints_group = {\r\n            \"head\": self.head,\r\n            \"shoulder\": self.shoulder,\r\n            \"elbow\": self.elbow,\r\n            \"wrist\": self.wrist,\r\n            \"hip\": self.hip,\r\n            \"knee\": self.knee,\r\n            \"ankle\": self.ankle,\r\n            \"all\": self.all\r\n        }\r\n        super(Body16KeypointDataset, self).__init__(root, 16, samples, keypoints_group=keypoints_group,\r\n                                                    colored_skeleton=colored_skeleton, **kwargs)\r\n\r\n\r\nclass Hand21KeypointDataset(KeypointDataset, ABC):\r\n    \"\"\"\r\n    Dataset with 21 hand keypoints.\r\n    \"\"\"\r\n    # TODO: add image\r\n    MCP = (1, 5, 9, 13, 17)\r\n    PIP = (2, 6, 10, 14, 18)\r\n    DIP = (3, 7, 11, 15, 19)\r\n    fingertip = (4, 8, 12, 16, 20)\r\n    all = tuple(range(21))\r\n    thumb = (0, 1, 2, 3, 4)\r\n    index_finger = (0, 5, 6, 7, 8)\r\n    middle_finger = (0, 9, 10, 11, 12)\r\n    ring_finger = (0, 13, 14, 15, 16)\r\n    little_finger = (0, 17, 18, 19, 20)\r\n\r\n    def __init__(self, root, samples, **kwargs):\r\n        colored_skeleton = {\r\n            \"thumb\": (self.thumb, 'yellow'),\r\n            \"index_finger\": (self.index_finger, 'green'),\r\n            \"middle_finger\": (self.middle_finger, 'blue'),\r\n            \"ring_finger\": (self.ring_finger, 'purple'),\r\n            \"little_finger\": (self.little_finger, 'red'),\r\n        }\r\n        keypoints_group = {\r\n            \"MCP\": self.MCP,\r\n            \"PIP\": self.PIP,\r\n            \"DIP\": self.DIP,\r\n            \"fingertip\": self.fingertip,\r\n            \"all\": self.all\r\n        }\r\n        super(Hand21KeypointDataset, self).__init__(root, 21, samples, keypoints_group=keypoints_group,\r\n                                                    colored_skeleton=colored_skeleton, **kwargs)\r\n"
  },
  {
    "path": "tllib/vision/datasets/keypoint_detection/lsp.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport scipy.io as scio\nimport os\n\nfrom PIL import ImageFile\nimport torch\nfrom .keypoint_dataset import Body16KeypointDataset\nfrom ...transforms.keypoint_detection import *\nfrom .util import *\nfrom .._util import download as download_data, check_exits\n\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\nclass LSP(Body16KeypointDataset):\n    \"\"\"`Leeds Sports Pose Dataset <http://sam.johnson.io/research/lsp.html>`_\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): PlaceHolder.\n        task (str, optional): Placeholder.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transforms (callable, optional): PlaceHolder.\n        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            lsp/\n                images/\n                joints.mat\n\n    .. note::\n        LSP is only used for target domain. Due to the small dataset size, the whole dataset is used\n        no matter what ``split`` is. Also, the transform is fixed.\n    \"\"\"\n    def __init__(self, root, split='train', task='all', download=True, image_size=(256, 256), transforms=None, **kwargs):\n        if download:\n            download_data(root, \"images\", \"lsp_dataset.zip\",\n                          \"https://cloud.tsinghua.edu.cn/f/46ea73c89abc46bfb125/?dl=1\")\n        else:\n            check_exits(root, \"lsp\")\n\n        assert split in ['train', 'test', 'all']\n        self.split = split\n\n        samples = []\n        annotations = scio.loadmat(os.path.join(root, \"joints.mat\"))['joints'].transpose((2, 1, 0))\n        for i in range(0, 2000):\n            image = \"im{0:04d}.jpg\".format(i+1)\n            annotation = annotations[i]\n            samples.append((image, annotation))\n\n        self.joints_index = (0, 1, 2, 3, 4, 5, 13, 13, 12, 13, 6, 7, 8, 9, 10, 11)\n        self.visible = np.array([1.] * 6 + [0, 0] + [1.] 
* 8, dtype=np.float32)\n        normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n        transforms = Compose([\n            ResizePad(image_size[0]),\n            ToTensor(),\n            normalize\n        ])\n        super(LSP, self).__init__(root, samples, transforms=transforms, image_size=image_size, **kwargs)\n\n    def __getitem__(self, index):\n        sample = self.samples[index]\n        image_name = sample[0]\n        image = Image.open(os.path.join(self.root, \"images\", image_name))\n        keypoint2d = sample[1][self.joints_index, :2]\n        image, data = self.transforms(image, keypoint2d=keypoint2d)\n        keypoint2d = data['keypoint2d']\n        visible = self.visible * (1-sample[1][self.joints_index, 2])\n        visible = visible[:, np.newaxis]\n\n        # 2D heatmap\n        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)\n        target = torch.from_numpy(target)\n        target_weight = torch.from_numpy(target_weight)\n\n        meta = {\n            'image': image_name,\n            'keypoint2d': keypoint2d,  # （NUM_KEYPOINTS x 2）\n            'keypoint3d': np.zeros((self.num_keypoints, 3)).astype(keypoint2d.dtype),  # （NUM_KEYPOINTS x 3）\n        }\n        return image, target, target_weight, meta\n"
  },
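The visibility line above is easy to misread: the third column of each LSP annotation is treated as an occlusion flag, so it is inverted before being multiplied with the fixed per-joint mask (slots 6 and 7, which pad the 14 LSP joints out to the 16-joint layout, are always zeroed). A self-contained check of that arithmetic, with made-up flag values::

    import numpy as np

    fixed_mask = np.array([1.] * 6 + [0, 0] + [1.] * 8, dtype=np.float32)
    occluded = np.zeros(16, dtype=np.float32)
    occluded[0] = 1.                   # pretend joint 0 is occluded in this sample
    visible = fixed_mask * (1 - occluded)
    assert visible[0] == 0. and visible[6] == 0. and visible[5] == 1.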
  {
    "path": "tllib/vision/datasets/keypoint_detection/rendered_hand_pose.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport torch\r\nimport os\r\nimport pickle\r\n\r\nfrom .._util import download as download_data, check_exits\r\nfrom ...transforms.keypoint_detection import *\r\nfrom .keypoint_dataset import Hand21KeypointDataset\r\nfrom .util import *\r\n\r\n\r\nclass RenderedHandPose(Hand21KeypointDataset):\r\n    \"\"\"`Rendered Handpose Dataset <https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html>`_\r\n\r\n    Args:\r\n        root (str): Root directory of dataset\r\n        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.\r\n        task (str, optional): Placeholder.\r\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\r\n            in root directory. If dataset is already downloaded, it is not downloaded again.\r\n        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and\r\n            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.\r\n        image_size (tuple): (width, height) of the image. Default: (256, 256)\r\n        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)\r\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\r\n\r\n    .. note:: In `root`, there will exist following files after downloading.\r\n        ::\r\n            RHD_published_v2/\r\n                training/\r\n                evaluation/\r\n    \"\"\"\r\n    def __init__(self, root, split='train', task='all', download=True, **kwargs):\r\n        if download:\r\n            download_data(root, \"RHD_published_v2\", \"RHD_v1-1.zip\", \"https://lmb.informatik.uni-freiburg.de/data/RenderedHandpose/RHD_v1-1.zip\")\r\n        else:\r\n            check_exits(root, \"RHD_published_v2\")\r\n\r\n        root = os.path.join(root, \"RHD_published_v2\")\r\n\r\n        assert split in ['train', 'test', 'all']\r\n        self.split = split\r\n        if split == 'all':\r\n            samples = self.get_samples(root, 'train') + self.get_samples(root, 'test')\r\n        else:\r\n            samples = self.get_samples(root, split)\r\n\r\n        super(RenderedHandPose, self).__init__(\r\n            root, samples, **kwargs)\r\n\r\n    def __getitem__(self, index):\r\n        sample = self.samples[index]\r\n        image_name = sample['name']\r\n        image_path = os.path.join(self.root, image_name)\r\n        image = Image.open(image_path)\r\n\r\n        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3\r\n        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2\r\n        intrinsic_matrix = np.array(sample['intrinsic_matrix'])\r\n        Zc = keypoint3d_camera[:, 2]\r\n\r\n        # Crop the images such that the hand is at the center of the image\r\n        # The images will be 1.5 times larger than the hand\r\n        # The crop process will change Xc and Yc, leaving Zc with no changes\r\n        bounding_box = get_bounding_box(keypoint2d)\r\n        w, h = image.size\r\n        left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)\r\n        image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d)\r\n\r\n        # Change all hands to right hands\r\n        if sample['left'] is False:\r\n            image, keypoint2d = hflip(image, keypoint2d)\r\n\r\n        image, data = 
self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\r\n        keypoint2d = data['keypoint2d']\r\n        intrinsic_matrix = data['intrinsic_matrix']\r\n        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)\r\n\r\n        # noramlize 2D pose:\r\n        visible = np.array(sample['visible'], dtype=np.float32)\r\n        visible = visible[:, np.newaxis]\r\n        # 2D heatmap\r\n        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)\r\n        target = torch.from_numpy(target)\r\n        target_weight = torch.from_numpy(target_weight)\r\n\r\n        # normalize 3D pose:\r\n        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system\r\n        # and make distance between wrist and middle finger MCP joint to be of length 1\r\n        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]\r\n        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))\r\n        z = keypoint3d_n[:, 2]\r\n\r\n        meta = {\r\n            'image': image_name,\r\n            'keypoint2d': keypoint2d,  # （NUM_KEYPOINTS x 2）\r\n            'keypoint3d': keypoint3d_n,  # （NUM_KEYPOINTS x 3）\r\n            'z': z,\r\n        }\r\n\r\n        return image, target, target_weight, meta\r\n\r\n    def get_samples(self, root, task, min_size=64):\r\n        if task == 'train':\r\n            set = 'training'\r\n        else:\r\n            set = 'evaluation'\r\n        # load annotations of this set\r\n        with open(os.path.join(root, set, 'anno_%s.pickle' % set), 'rb') as fi:\r\n            anno_all = pickle.load(fi)\r\n\r\n        samples = []\r\n        left_hand_index = [0, 4, 3, 2, 1, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13, 20, 19, 18, 17]\r\n        right_hand_index = [i+21 for i in left_hand_index]\r\n        for sample_id, anno in anno_all.items():\r\n            image_name = os.path.join(set, 'color', '%.5d.png' % sample_id)\r\n            mask_name = os.path.join(set, 'mask', '%.5d.png' % sample_id)\r\n            keypoint2d = anno['uv_vis'][:, :2]\r\n            keypoint3d = anno['xyz']\r\n            intrinsic_matrix = anno['K']\r\n            visible = anno['uv_vis'][:, 2]\r\n\r\n            left_hand_keypoint2d = keypoint2d[left_hand_index] # NUM_KEYPOINTS x 2\r\n            left_box = get_bounding_box(left_hand_keypoint2d)\r\n            right_hand_keypoint2d = keypoint2d[right_hand_index]  # NUM_KEYPOINTS x 2\r\n            right_box = get_bounding_box(right_hand_keypoint2d)\r\n\r\n            w, h = 320, 320\r\n            scaled_left_box = scale_box(left_box, w, h, 1.5)\r\n            left, upper, right, lower = scaled_left_box\r\n            size = max(right - left, lower - upper)\r\n            if size > min_size and np.sum(visible[left_hand_index]) > 16 and area(*intersection(scaled_left_box, right_box)) / area(*scaled_left_box) < 0.3:\r\n                sample = {\r\n                    'name': image_name,\r\n                    'mask_name': mask_name,\r\n                    'keypoint2d': left_hand_keypoint2d,\r\n                    'visible': visible[left_hand_index],\r\n                    'keypoint3d': keypoint3d[left_hand_index],\r\n                    'intrinsic_matrix': intrinsic_matrix,\r\n                    'left': True\r\n                }\r\n                samples.append(sample)\r\n\r\n            scaled_right_box = scale_box(right_box, w, h, 1.5)\r\n            left, upper, right, lower = 
scaled_right_box\r\n            size = max(right - left, lower - upper)\r\n            if size > min_size and np.sum(visible[right_hand_index]) > 16 and area(*intersection(scaled_right_box, left_box)) / area(*scaled_right_box) < 0.3:\r\n                sample = {\r\n                    'name': image_name,\r\n                    'mask_name': mask_name,\r\n                    'keypoint2d': right_hand_keypoint2d,\r\n                    'visible': visible[right_hand_index],\r\n                    'keypoint3d': keypoint3d[right_hand_index],\r\n                    'intrinsic_matrix': intrinsic_matrix,\r\n                    'left': False\r\n                }\r\n                samples.append(sample)\r\n\r\n        return samples"
  },
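``get_samples`` keeps a hand crop only when it is large enough, has more than 16 of 21 keypoints visible, and overlaps the other hand's box by less than 30%. The geometric part of that test can be reproduced with the helpers from ``util.py`` further down; the box coordinates here are made up::

    from tllib.vision.datasets.keypoint_detection.util import area, intersection, scale_box

    left_box = (40, 40, 120, 140)           # (left, upper, right, lower)
    right_box = (200, 50, 280, 150)
    scaled = scale_box(left_box, 320, 320, 1.5)
    overlap = area(*intersection(scaled, right_box)) / area(*scaled)
    print(overlap < 0.3)                    # True: the two boxes barely overlap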
  {
    "path": "tllib/vision/datasets/keypoint_detection/surreal.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport os\r\nimport json\r\nfrom PIL import ImageFile\r\nimport torch\r\nfrom ...transforms.keypoint_detection import *\r\nfrom .util import *\r\nfrom .._util import download as download_data, check_exits\r\nfrom .keypoint_dataset import Body16KeypointDataset\r\n\r\nImageFile.LOAD_TRUNCATED_IMAGES = True\r\n\r\n\r\nclass SURREAL(Body16KeypointDataset):\r\n    \"\"\"`Surreal Dataset <https://www.di.ens.fr/willow/research/surreal/data/>`_\r\n\r\n    Args:\r\n        root (str): Root directory of dataset\r\n        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.\r\n            Default: ``train``.\r\n        task (str, optional): Placeholder.\r\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\r\n            in root directory. If dataset is already downloaded, it is not downloaded again.\r\n        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and\r\n            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.\r\n        image_size (tuple): (width, height) of the image. Default: (256, 256)\r\n        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)\r\n        sigma (int): sigma parameter when generate the heatmap. Default: 2\r\n\r\n    .. note::\r\n        We found that the original Surreal image is in high resolution while most part in an image is background,\r\n        thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands) to speed up training.\r\n\r\n    .. note:: In `root`, there will exist following files after downloading.\r\n        ::\r\n            train/\r\n            test/\r\n            val/\r\n    \"\"\"\r\n    def __init__(self, root, split='train', task='all', download=True, **kwargs):\r\n        assert split in ['train', 'test', 'val']\r\n        self.split = split\r\n\r\n        if download:\r\n            download_data(root, \"train/run0\", \"train0.tgz\", \"https://cloud.tsinghua.edu.cn/f/b13604f06ff1445c830a/?dl=1\")\r\n            download_data(root, \"train/run1\", \"train1.tgz\", \"https://cloud.tsinghua.edu.cn/f/919aefe2de3541c3b940/?dl=1\")\r\n            download_data(root, \"train/run1\", \"train2.tgz\", \"https://cloud.tsinghua.edu.cn/f/34864760ad4945b9bcd6/?dl=1\")\r\n            download_data(root, \"val\", \"val.tgz\", \"https://cloud.tsinghua.edu.cn/f/16b20f2e76684f848dc1/?dl=1\")\r\n            download_data(root, \"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/36c72d86e43540e0a913/?dl=1\")\r\n        else:\r\n            check_exits(root, \"train/run0\")\r\n            check_exits(root, \"train/run1\")\r\n            check_exits(root, \"train/run2\")\r\n            check_exits(root, \"val\")\r\n            check_exits(root, \"test\")\r\n\r\n        all_samples = []\r\n        for part in [0, 1, 2]:\r\n            annotation_file = os.path.join(root, split, 'run{}.json'.format(part))\r\n            print(\"loading\", annotation_file)\r\n            with open(annotation_file) as f:\r\n                samples = json.load(f)\r\n                for sample in samples:\r\n                    sample[\"image_path\"] = os.path.join(root, self.split, 'run{}'.format(part), sample['name'])\r\n                all_samples.extend(samples)\r\n\r\n        random.seed(42)\r\n        
random.shuffle(all_samples)\r\n        samples_len = len(all_samples)\r\n        samples_split = min(int(samples_len * 0.2), 3200)\r\n        if self.split == 'train':\r\n            all_samples = all_samples[samples_split:]\r\n        elif self.split == 'test':\r\n            all_samples = all_samples[:samples_split]\r\n        self.joints_index = (7, 4, 1, 2, 5, 8, 0, 9, 12, 15, 20, 18, 13, 14, 19, 21)\r\n\r\n        super(SURREAL, self).__init__(root, all_samples, **kwargs)\r\n\r\n    def __getitem__(self, index):\r\n        sample = self.samples[index]\r\n        image_name = sample['name']\r\n\r\n        image_path = sample['image_path']\r\n        image = Image.open(image_path)\r\n        keypoint3d_camera = np.array(sample['keypoint3d'])[self.joints_index, :]  # NUM_KEYPOINTS x 3\r\n        keypoint2d = np.array(sample['keypoint2d'])[self.joints_index, :]  # NUM_KEYPOINTS x 2\r\n        intrinsic_matrix = np.array(sample['intrinsic_matrix'])\r\n        Zc = keypoint3d_camera[:, 2]\r\n\r\n        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\r\n        keypoint2d = data['keypoint2d']\r\n        intrinsic_matrix = data['intrinsic_matrix']\r\n        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)\r\n\r\n        # noramlize 2D pose:\r\n        visible = np.array([1.] * 16, dtype=np.float32)\r\n        visible = visible[:, np.newaxis]\r\n\r\n        # 2D heatmap\r\n        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)\r\n        target = torch.from_numpy(target)\r\n        target_weight = torch.from_numpy(target_weight)\r\n\r\n        # normalize 3D pose:\r\n        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system\r\n        # and make distance between wrist and middle finger MCP joint to be of length 1\r\n        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]\r\n        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))\r\n\r\n        meta = {\r\n            'image': image_name,\r\n            'keypoint2d': keypoint2d,  # （NUM_KEYPOINTS x 2）\r\n            'keypoint3d': keypoint3d_n,  # （NUM_KEYPOINTS x 3）\r\n        }\r\n        return image, target, target_weight, meta\r\n\r\n    def __len__(self):\r\n        return len(self.samples)\r\n"
  },
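Since the annotation files ship without a held-out test partition here, the constructor derives one deterministically: shuffle with a fixed seed, then take 20% of the samples, capped at 3200, as the test split. The arithmetic in isolation, with a made-up sample count::

    import random

    samples = list(range(100000))
    random.seed(42)                         # same seed as the dataset -> reproducible split
    random.shuffle(samples)
    split = min(int(len(samples) * 0.2), 3200)
    test, train = samples[:split], samples[split:]
    print(len(test))                        # 3200: the cap applies above 16000 samples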
  {
    "path": "tllib/vision/datasets/keypoint_detection/util.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\nimport numpy as np\r\nimport cv2\r\n\r\n\r\ndef generate_target(joints, joints_vis, heatmap_size, sigma, image_size):\r\n    \"\"\"Generate heatamap for joints.\r\n\r\n    Args:\r\n        joints: (K, 2)\r\n        joints_vis: (K, 1)\r\n        heatmap_size: W, H\r\n        sigma:\r\n        image_size:\r\n\r\n    Returns:\r\n\r\n    \"\"\"\r\n    num_joints = joints.shape[0]\r\n    target_weight = np.ones((num_joints, 1), dtype=np.float32)\r\n    target_weight[:, 0] = joints_vis[:, 0]\r\n\r\n    target = np.zeros((num_joints,\r\n                       heatmap_size[1],\r\n                       heatmap_size[0]),\r\n                      dtype=np.float32)\r\n\r\n    tmp_size = sigma * 3\r\n    image_size = np.array(image_size)\r\n    heatmap_size = np.array(heatmap_size)\r\n\r\n    for joint_id in range(num_joints):\r\n        feat_stride = image_size / heatmap_size\r\n        mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)\r\n        mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)\r\n        # Check that any part of the gaussian is in-bounds\r\n        ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]\r\n        br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]\r\n        if mu_x >= heatmap_size[0] or mu_y >= heatmap_size[1] \\\r\n                or mu_x < 0 or mu_y < 0:\r\n            # If not, just return the image as is\r\n            target_weight[joint_id] = 0\r\n            continue\r\n\r\n        # Generate gaussian\r\n        size = 2 * tmp_size + 1\r\n        x = np.arange(0, size, 1, np.float32)\r\n        y = x[:, np.newaxis]\r\n        x0 = y0 = size // 2\r\n        # The gaussian is not normalized, we want the center value to equal 1\r\n        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))\r\n\r\n        # Usable gaussian range\r\n        g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]\r\n        g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]\r\n        # Image range\r\n        img_x = max(0, ul[0]), min(br[0], heatmap_size[0])\r\n        img_y = max(0, ul[1]), min(br[1], heatmap_size[1])\r\n\r\n        v = target_weight[joint_id]\r\n        if v > 0.5:\r\n            target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \\\r\n                g[g_y[0]:g_y[1], g_x[0]:g_x[1]]\r\n\r\n    return target, target_weight\r\n\r\n\r\ndef keypoint2d_to_3d(keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, Zc: np.ndarray):\r\n    \"\"\"Convert 2D keypoints to 3D keypoints\"\"\"\r\n    uv1 = np.concatenate([np.copy(keypoint2d), np.ones((keypoint2d.shape[0], 1))], axis=1).T * Zc  # 3 x NUM_KEYPOINTS\r\n    xyz = np.matmul(np.linalg.inv(intrinsic_matrix), uv1).T  # NUM_KEYPOINTS x 3\r\n    return xyz\r\n\r\n\r\ndef keypoint3d_to_2d(keypoint3d: np.ndarray, intrinsic_matrix: np.ndarray):\r\n    \"\"\"Convert 3D keypoints to 2D keypoints\"\"\"\r\n    keypoint2d = np.matmul(intrinsic_matrix, keypoint3d.T).T  # NUM_KEYPOINTS x 3\r\n    keypoint2d = keypoint2d[:, :2] / keypoint2d[:, 2:3]  # NUM_KEYPOINTS x 2\r\n    return keypoint2d\r\n\r\n\r\ndef scale_box(box, image_width, image_height, scale):\r\n    \"\"\"\r\n    Change `box` to a square box.\r\n    The side with of the square box will be `scale` * max(w, h)\r\n    where w and h is the width and height of the origin box\r\n    \"\"\"\r\n    left, upper, right, lower = box\r\n    center_x, center_y = (left + right) / 2, (upper + lower) / 2\r\n    w, h = right - left, lower 
- upper\r\n    side_with = min(round(scale * max(w, h)), min(image_width, image_height))\r\n    left = round(center_x - side_with / 2)\r\n    right = left + side_with - 1\r\n    upper = round(center_y - side_with / 2)\r\n    lower = upper + side_with - 1\r\n    if left < 0:\r\n        left = 0\r\n        right = side_with - 1\r\n    if right >= image_width:\r\n        right = image_width - 1\r\n        left = image_width - side_with\r\n    if upper < 0:\r\n        upper = 0\r\n        lower = side_with -1\r\n    if lower >= image_height:\r\n        lower = image_height - 1\r\n        upper = image_height - side_with\r\n    return left, upper, right, lower\r\n\r\n\r\ndef get_bounding_box(keypoint2d: np.array):\r\n    \"\"\"Get the bounding box for keypoints\"\"\"\r\n    left = np.min(keypoint2d[:, 0])\r\n    right = np.max(keypoint2d[:, 0])\r\n    upper = np.min(keypoint2d[:, 1])\r\n    lower = np.max(keypoint2d[:, 1])\r\n    return left, upper, right, lower\r\n\r\n\r\ndef visualize_heatmap(image, heatmaps, filename):\r\n    image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR).copy()\r\n    H, W = heatmaps.shape[1], heatmaps.shape[2]\r\n    resized_image = cv2.resize(image, (int(W), int(H)))\r\n    heatmaps = heatmaps.mul(255).clamp(0, 255).byte().cpu().numpy()\r\n    for k in range(heatmaps.shape[0]):\r\n        heatmap = heatmaps[k]\r\n        colored_heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)\r\n        masked_image = colored_heatmap * 0.7 + resized_image * 0.3\r\n        cv2.imwrite(filename.format(k), masked_image)\r\n        \r\n\r\ndef area(left, upper, right, lower):\r\n    return max(right - left + 1, 0) * max(lower - upper + 1, 0)\r\n\r\n\r\ndef intersection(box_a, box_b):\r\n    left_a, upper_a, right_a, lower_a = box_a\r\n    left_b, upper_b, right_b, lower_b = box_b\r\n    return max(left_a, left_b), max(upper_a, upper_b), min(right_a, right_b), min(lower_a, lower_b)\r\n"
  },
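``keypoint2d_to_3d`` and ``keypoint3d_to_2d`` are pinhole-camera inverses of each other once the depths ``Zc`` are known, which is why the datasets above can recover camera-space joints after 2D-only image transforms. A round-trip check with a made-up intrinsic matrix::

    import numpy as np
    from tllib.vision.datasets.keypoint_detection.util import keypoint2d_to_3d, keypoint3d_to_2d

    K = np.array([[320., 0., 160.],
                  [0., 320., 160.],
                  [0., 0., 1.]])
    xyz = np.array([[0.1, -0.2, 0.8],
                    [0.0, 0.05, 1.2]])      # camera coordinates
    uv = keypoint3d_to_2d(xyz, K)
    assert np.allclose(keypoint2d_to_3d(uv, K, xyz[:, 2]), xyz)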
  {
    "path": "tllib/vision/datasets/object_detection/__init__.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport numpy as np\nimport os\nimport xml.etree.ElementTree as ET\n\nfrom detectron2.data import (\n    MetadataCatalog,\n    DatasetCatalog,\n)\nfrom detectron2.utils.file_io import PathManager\nfrom detectron2.structures import BoxMode\nfrom tllib.vision.datasets._util import download as download_dataset\n\n\ndef parse_root_and_file_name(path):\n    path_list = path.split('/')\n    dataset_root = '/'.join(path_list[:-1])\n    file_name = path_list[-1]\n    if dataset_root == '':\n        dataset_root = '.'\n    return dataset_root, file_name\n\n\nclass VOCBase:\n    class_names = (\n        \"aeroplane\", \"bicycle\", \"bird\", \"boat\", \"bottle\", \"bus\", \"car\", \"cat\",\n        \"chair\", \"cow\", \"diningtable\", \"dog\", \"horse\", \"motorbike\", \"person\",\n        \"pottedplant\", \"sheep\", \"sofa\", \"train\", \"tvmonitor\"\n    )\n\n    def __init__(self, root, split=\"trainval\", year=2007, ext='.jpg', download=True):\n        self.name = \"{}_{}\".format(root, split)\n        self.name = self.name.replace(os.path.sep, \"_\")\n        if self.name not in MetadataCatalog.keys():\n            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext)\n            MetadataCatalog.get(self.name).evaluator_type = \"pascal_voc\"\n        if download:\n            dataset_root, file_name = parse_root_and_file_name(root)\n            download_dataset(dataset_root, file_name, self.archive_name, self.dataset_url)\n\n\nclass VOC2007(VOCBase):\n    archive_name = 'VOC2007.tgz'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'\n\n    def __init__(self, root):\n        super(VOC2007, self).__init__(root)\n\n\nclass VOC2012(VOCBase):\n    archive_name = 'VOC2012.tgz'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/a7e7ab88f727408eaf32/?dl=1'\n\n    def __init__(self, root):\n        super(VOC2012, self).__init__(root, year=2012)\n\n\nclass VOC2007Test(VOCBase):\n    archive_name = 'VOC2007.tgz'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'\n\n    def __init__(self, root):\n        super(VOC2007Test, self).__init__(root, year=2007, split='test')\n\n\nclass Clipart(VOCBase):\n    archive_name = 'clipart.zip'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/c853a66786e2416a8f18/?dl=1'\n\n\nclass VOCPartialBase:\n    class_names = (\n        \"bicycle\", \"bird\", \"car\", \"cat\", \"dog\", \"person\",\n    )\n\n    def __init__(self, root, split=\"trainval\", year=2007, ext='.jpg', download=True):\n        self.name = \"{}_{}\".format(root, split)\n        self.name = self.name.replace(os.path.sep, \"_\")\n        if self.name not in MetadataCatalog.keys():\n            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext)\n            MetadataCatalog.get(self.name).evaluator_type = \"pascal_voc\"\n        if download:\n            dataset_root, file_name = parse_root_and_file_name(root)\n            download_dataset(dataset_root, file_name, self.archive_name, self.dataset_url)\n\n\nclass VOC2007Partial(VOCPartialBase):\n    archive_name = 'VOC2007.tgz'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'\n\n    def __init__(self, root):\n        super(VOC2007Partial, self).__init__(root)\n\n\nclass VOC2012Partial(VOCPartialBase):\n    archive_name = 'VOC2012.tgz'\n    dataset_url = 
'https://cloud.tsinghua.edu.cn/f/a7e7ab88f727408eaf32/?dl=1'\n\n    def __init__(self, root):\n        super(VOC2012Partial, self).__init__(root, year=2012)\n\n\nclass VOC2007PartialTest(VOCPartialBase):\n    archive_name = 'VOC2007.tgz'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'\n\n    def __init__(self, root):\n        super(VOC2007PartialTest, self).__init__(root, year=2007, split='test')\n\n\nclass WaterColor(VOCPartialBase):\n    archive_name = 'watercolor.zip'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/9f322fd8496f4766ad93/?dl=1'\n\n    def __init__(self, root):\n        super(WaterColor, self).__init__(root, split='train')\n\n\nclass WaterColorTest(VOCPartialBase):\n    archive_name = 'watercolor.zip'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/9f322fd8496f4766ad93/?dl=1'\n\n    def __init__(self, root):\n        super(WaterColorTest, self).__init__(root, split='test')\n\n\nclass Comic(VOCPartialBase):\n    archive_name = 'comic.tar'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/030d7b4b649f46589b2d/?dl=1'\n\n    def __init__(self, root):\n        super(Comic, self).__init__(root, split='train')\n\n\nclass ComicTest(VOCPartialBase):\n    archive_name = 'comic.tar'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/030d7b4b649f46589b2d/?dl=1'\n\n    def __init__(self, root):\n        super(ComicTest, self).__init__(root, split='test')\n\n\nclass CityscapesBase:\n    class_names = (\n        \"bicycle\", \"bus\", \"car\", \"motorcycle\", \"person\", \"rider\", \"train\", \"truck\",\n    )\n\n    def __init__(self, root, split=\"trainval\", year=2007, ext='.png'):\n        self.name = \"{}_{}\".format(root, split)\n        self.name = self.name.replace(os.path.sep, \"_\")\n        if self.name not in MetadataCatalog.keys():\n            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext,\n                                bbox_zero_based=True)\n            MetadataCatalog.get(self.name).evaluator_type = \"pascal_voc\"\n\n\nclass Cityscapes(CityscapesBase):\n    def __init__(self, root):\n        super(Cityscapes, self).__init__(root, split=\"trainval\")\n\n\nclass CityscapesTest(CityscapesBase):\n    def __init__(self, root):\n        super(CityscapesTest, self).__init__(root, split='test')\n\n\nclass FoggyCityscapes(Cityscapes):\n    pass\n\n\nclass FoggyCityscapesTest(CityscapesTest):\n    pass\n\n\nclass CityscapesCarBase:\n    class_names = (\n        \"car\",\n    )\n\n    def __init__(self, root, split=\"trainval\", year=2007, ext='.png', bbox_zero_based=True):\n        self.name = \"{}_{}\".format(root, split)\n        self.name = self.name.replace(os.path.sep, \"_\")\n        if self.name not in MetadataCatalog.keys():\n            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext,\n                                bbox_zero_based=bbox_zero_based)\n            MetadataCatalog.get(self.name).evaluator_type = \"pascal_voc\"\n\n\nclass CityscapesCar(CityscapesCarBase):\n    pass\n\n\nclass CityscapesCarTest(CityscapesCarBase):\n    def __init__(self, root):\n        super(CityscapesCarTest, self).__init__(root, split='test')\n\n\nclass Sim10kCar(CityscapesCarBase):\n    def __init__(self, root):\n        super(Sim10kCar, self).__init__(root, split='trainval10k', ext='.jpg', bbox_zero_based=False)\n\n\nclass KITTICar(CityscapesCarBase):\n    def __init__(self, root):\n        super(KITTICar, self).__init__(root, split='trainval', ext='.jpg', 
bbox_zero_based=False)\n\n\nclass GTA5(CityscapesBase):\n    def __init__(self, root):\n        super(GTA5, self).__init__(root, split=\"trainval\", ext='.jpg')\n\n\ndef load_voc_instances(dirname: str, split: str, class_names, ext='.jpg', bbox_zero_based=False):\n    \"\"\"\n    Load Pascal VOC detection annotations to Detectron2 format.\n\n    Args:\n        dirname: Contain \"Annotations\", \"ImageSets\", \"JPEGImages\"\n        split (str): one of \"train\", \"test\", \"val\", \"trainval\"\n        class_names: list or tuple of class names\n    \"\"\"\n    with PathManager.open(os.path.join(dirname, \"ImageSets\", \"Main\", split + \".txt\")) as f:\n        fileids = np.loadtxt(f, dtype=str)\n\n    # Needs to read many small annotation files. Makes sense at local\n    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, \"Annotations/\"))\n    dicts = []\n    skip_classes = set()\n    for fileid in fileids:\n        anno_file = os.path.join(annotation_dirname, fileid + \".xml\")\n        jpeg_file = os.path.join(dirname, \"JPEGImages\", fileid + ext)\n\n        with PathManager.open(anno_file) as f:\n            tree = ET.parse(f)\n\n        r = {\n            \"file_name\": jpeg_file,\n            \"image_id\": fileid,\n            \"height\": int(tree.findall(\"./size/height\")[0].text),\n            \"width\": int(tree.findall(\"./size/width\")[0].text),\n        }\n        instances = []\n\n        for obj in tree.findall(\"object\"):\n            cls = obj.find(\"name\").text\n            if cls not in class_names:\n                skip_classes.add(cls)\n                continue\n            # We include \"difficult\" samples in training.\n            # Based on limited experiments, they don't hurt accuracy.\n            # difficult = int(obj.find(\"difficult\").text)\n            # if difficult == 1:\n            # continue\n            bbox = obj.find(\"bndbox\")\n            bbox = [float(bbox.find(x).text) for x in [\"xmin\", \"ymin\", \"xmax\", \"ymax\"]]\n            # Original annotations are integers in the range [1, W or H]\n            # Assuming they mean 1-based pixel indices (inclusive),\n            # a box with annotation (xmin=1, xmax=W) covers the whole image.\n            # In coordinate space this is represented by (xmin=0, xmax=W)\n            if bbox_zero_based is False:\n                bbox[0] -= 1.0\n                bbox[1] -= 1.0\n            instances.append(\n                {\"category_id\": class_names.index(cls), \"bbox\": bbox, \"bbox_mode\": BoxMode.XYXY_ABS}\n            )\n        r[\"annotations\"] = instances\n        dicts.append(r)\n    print(\"Skip classes:\", list(skip_classes))\n    return dicts\n\n\ndef register_pascal_voc(name, dirname, split, year, class_names, **kwargs):\n    DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names, **kwargs))\n    MetadataCatalog.get(name).set(\n        thing_classes=list(class_names), dirname=dirname, year=year, split=split\n    )\n"
  },
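Each wrapper class registers its split in detectron2's ``DatasetCatalog``/``MetadataCatalog`` under a name derived from the dataset path, after which the usual detectron2 data APIs apply. A sketch, assuming detectron2 is installed and using an illustrative path (instantiating the class also triggers the archive download by default)::

    from detectron2.data import DatasetCatalog, MetadataCatalog
    from tllib.vision.datasets.object_detection import VOC2007

    dataset = VOC2007('datasets/VOC2007')   # registers 'datasets_VOC2007_trainval' on Linux
    dicts = DatasetCatalog.get(dataset.name)
    print(len(dicts), MetadataCatalog.get(dataset.name).thing_classes[:3])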
  {
    "path": "tllib/vision/datasets/office31.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass Office31(ImageList):\n    \"\"\"Office31 Dataset.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \\\n            ``'D'``: dslr and ``'W'``: webcam.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            amazon/\n                images/\n                    backpack/\n                        *.jpg\n                        ...\n            dslr/\n            webcam/\n            image_list/\n                amazon.txt\n                dslr.txt\n                webcam.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/2c1dd9fbcaa9455aa4ad/?dl=1\"),\n        (\"amazon\", \"amazon.tgz\", \"https://cloud.tsinghua.edu.cn/f/ec12dfcddade43ab8101/?dl=1\"),\n        (\"dslr\", \"dslr.tgz\", \"https://cloud.tsinghua.edu.cn/f/a41d818ae2f34da7bb32/?dl=1\"),\n        (\"webcam\", \"webcam.tgz\", \"https://cloud.tsinghua.edu.cn/f/8a41009a166e4131adcd/?dl=1\"),\n    ]\n    image_list = {\n        \"A\": \"image_list/amazon.txt\",\n        \"D\": \"image_list/dslr.txt\",\n        \"W\": \"image_list/webcam.txt\"\n    }\n    CLASSES = ['back_pack', 'bike', 'bike_helmet', 'bookcase', 'bottle', 'calculator', 'desk_chair', 'desk_lamp',\n               'desktop_computer', 'file_cabinet', 'headphones', 'keyboard', 'laptop_computer', 'letter_tray',\n               'mobile_phone', 'monitor', 'mouse', 'mug', 'paper_notebook', 'pen', 'phone', 'printer', 'projector',\n               'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):\n        assert task in self.image_list\n        data_list_file = os.path.join(root, self.image_list[task])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(Office31, self).__init__(root, Office31.CLASSES, data_list_file=data_list_file, **kwargs)\n\n    @classmethod\n    def domains(cls):\n        return list(cls.image_list.keys())"
  },
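The classification datasets below all follow this same pattern: choose a domain with ``task``, optionally download, and pass torchvision-style ``transform``/``target_transform`` through the keyword arguments, as the docstring above describes. A typical construction (the root path is illustrative)::

    import torchvision.transforms as T
    from tllib.vision.datasets.office31 import Office31

    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
    amazon = Office31(root='data/office31', task='A', download=True, transform=transform)
    print(len(amazon), Office31.domains())   # domains() -> ['A', 'D', 'W']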
  {
    "path": "tllib/vision/datasets/officecaltech.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS, default_loader\nfrom torchvision.datasets.utils import download_and_extract_archive\nfrom ._util import check_exits\n\n\nclass OfficeCaltech(DatasetFolder):\n    \"\"\"Office+Caltech Dataset.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \\\n            ``'D'``: dslr, ``'W'``:webcam and ``'C'``: caltech.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            amazon/\n                images/\n                    backpack/\n                        *.jpg\n                        ...\n            dslr/\n            webcam/\n            caltech/\n            image_list/\n                amazon.txt\n                dslr.txt\n                webcam.txt\n                caltech.txt\n    \"\"\"\n    directories = {\n        \"A\": \"amazon\",\n        \"D\": \"dslr\",\n        \"W\": \"webcam\",\n        \"C\": \"caltech\"\n    }\n    CLASSES = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard',\n               'laptop_computer', 'monitor', 'mouse', 'mug', 'projector']\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):\n        if download:\n            for dir in self.directories.values():\n                if not os.path.exists(os.path.join(root, dir)):\n                    download_and_extract_archive(url=\"https://cloud.tsinghua.edu.cn/f/eea518fa781a41d1b20e/?dl=1\",\n                                                 download_root=os.path.join(root, 'download'),\n                                                 filename=\"office-caltech.tgz\", remove_finished=False,\n                                                 extract_root=root)\n                    break\n        else:\n            list(map(lambda dir, _: check_exits(root, dir), self.directories.values()))\n\n        super(OfficeCaltech, self).__init__(\n            os.path.join(root, self.directories[task]), default_loader, extensions=IMG_EXTENSIONS, **kwargs)\n        self.classes = OfficeCaltech.CLASSES\n        self.class_to_idx = {cls: idx\n                             for idx, clss in enumerate(self.classes)\n                             for cls in clss}\n\n    @property\n    def num_classes(self):\n        \"\"\"Number of classes\"\"\"\n        return len(self.classes)\n\n    @classmethod\n    def domains(cls):\n        return list(cls.directories.keys())\n"
  },
  {
    "path": "tllib/vision/datasets/officehome.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass OfficeHome(ImageList):\n    \"\"\"`OfficeHome <http://hemanthdv.org/OfficeHome-Dataset/>`_ Dataset.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'Ar'``: Art, \\\n            ``'Cl'``: Clipart, ``'Pr'``: Product and ``'Rw'``: Real_World.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            Art/\n                Alarm_Clock/*.jpg\n                ...\n            Clipart/\n            Product/\n            Real_World/\n            image_list/\n                Art.txt\n                Clipart.txt\n                Product.txt\n                Real_World.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/1b0171a188944313b1f5/?dl=1\"),\n        (\"Art\", \"Art.tgz\", \"https://cloud.tsinghua.edu.cn/f/6a006656b9a14567ade2/?dl=1\"),\n        (\"Clipart\", \"Clipart.tgz\", \"https://cloud.tsinghua.edu.cn/f/ae88aa31d2d7411dad79/?dl=1\"),\n        (\"Product\", \"Product.tgz\", \"https://cloud.tsinghua.edu.cn/f/f219b0ff35e142b3ab48/?dl=1\"),\n        (\"Real_World\", \"Real_World.tgz\", \"https://cloud.tsinghua.edu.cn/f/6c19f3f15bb24ed3951a/?dl=1\")\n    ]\n    image_list = {\n        \"Ar\": \"image_list/Art.txt\",\n        \"Cl\": \"image_list/Clipart.txt\",\n        \"Pr\": \"image_list/Product.txt\",\n        \"Rw\": \"image_list/Real_World.txt\",\n    }\n    CLASSES = ['Drill', 'Exit_Sign', 'Bottle', 'Glasses', 'Computer', 'File_Cabinet', 'Shelf', 'Toys', 'Sink',\n               'Laptop', 'Kettle', 'Folder', 'Keyboard', 'Flipflops', 'Pencil', 'Bed', 'Hammer', 'ToothBrush', 'Couch',\n               'Bike', 'Postit_Notes', 'Mug', 'Webcam', 'Desk_Lamp', 'Telephone', 'Helmet', 'Mouse', 'Pen', 'Monitor',\n               'Mop', 'Sneakers', 'Notebook', 'Backpack', 'Alarm_Clock', 'Push_Pin', 'Paper_Clip', 'Batteries', 'Radio',\n               'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker', 'Eraser', 'Bucket', 'Chair',\n               'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade', 'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV',\n               'Curtains', 'Fork', 'Soda', 'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker']\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):\n        assert task in self.image_list\n        data_list_file = os.path.join(root, self.image_list[task])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(OfficeHome, self).__init__(root, OfficeHome.CLASSES, data_list_file=data_list_file, 
**kwargs)\n\n    @classmethod\n    def domains(cls):\n        return list(cls.image_list.keys())"
  },
  {
    "path": "tllib/vision/datasets/openset/__init__.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom ..imagelist import ImageList\nfrom ..office31 import Office31\nfrom ..officehome import OfficeHome\nfrom ..visda2017 import VisDA2017\n\nfrom typing import Optional, ClassVar, Sequence\nfrom copy import deepcopy\n\n\n__all__ = ['Office31', 'OfficeHome', \"VisDA2017\"]\n\n\ndef open_set(dataset_class: ClassVar, public_classes: Sequence[str],\n            private_classes: Optional[Sequence[str]] = ()) -> ClassVar:\n    \"\"\"\n    Convert a dataset into its open-set version.\n\n    In other words, those samples which doesn't belong to `private_classes` will be marked as \"unknown\".\n\n    Be aware that `open_set` will change the label number of each category.\n\n    Args:\n        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be open-set.\n        public_classes (sequence[str]): A sequence of which categories need to be kept in the open-set dataset.\\\n            Each element of `public_classes` must belong to the `classes` list of `dataset_class`.\n        private_classes (sequence[str], optional): A sequence of which categories need to be marked as \"unknown\" \\\n            in the open-set dataset. Each element of `private_classes` must belong to the `classes` list of \\\n            `dataset_class`. Default: ().\n\n    Examples::\n\n        >>> public_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']\n        >>> private_classes = ['laptop_computer', 'monitor', 'mouse', 'mug', 'projector']\n        >>> # create a open-set dataset class which has classes\n        >>> # 'back_pack', 'bike', 'calculator', 'headphones', 'keyboard' and 'unknown'.\n        >>> OpenSetOffice31 = open_set(Office31, public_classes, private_classes)\n        >>> # create an instance of the open-set dataset\n        >>> dataset = OpenSetDataset(root=\"data/office31\", task=\"A\")\n\n    \"\"\"\n    if not (issubclass(dataset_class, ImageList)):\n        raise Exception(\"Only subclass of ImageList can be openset\")\n\n    class OpenSetDataset(dataset_class):\n        def __init__(self, **kwargs):\n            super(OpenSetDataset, self).__init__(**kwargs)\n            samples = []\n            all_classes = list(deepcopy(public_classes)) + [\"unknown\"]\n            for (path, label) in self.samples:\n                class_name = self.classes[label]\n                if class_name in public_classes:\n                    samples.append((path, all_classes.index(class_name)))\n                elif class_name in private_classes:\n                    samples.append((path, all_classes.index(\"unknown\")))\n            self.samples = samples\n            self.classes = all_classes\n            self.class_to_idx = {cls: idx\n                                 for idx, cls in enumerate(self.classes)}\n\n    return OpenSetDataset\n\n\ndef default_open_set(dataset_class: ClassVar, source: bool) -> ClassVar:\n    \"\"\"\n    Default open-set used in some paper.\n\n    Args:\n        dataset_class (class): Dataset class. 
Currently, dataset_class must be one of\n            :class:`~tllib.vision.datasets.office31.Office31`, :class:`~tllib.vision.datasets.officehome.OfficeHome`,\n            :class:`~tllib.vision.datasets.visda2017.VisDA2017`,\n        source (bool): Whether the dataset is used for source domain or not.\n    \"\"\"\n    if dataset_class == Office31:\n        public_classes = Office31.CLASSES[:20]\n        if source:\n            private_classes = ()\n        else:\n            private_classes = Office31.CLASSES[20:]\n    elif dataset_class == OfficeHome:\n        public_classes = sorted(OfficeHome.CLASSES)[:25]\n        if source:\n            private_classes = ()\n        else:\n            private_classes = sorted(OfficeHome.CLASSES)[25:]\n    elif dataset_class == VisDA2017:\n        public_classes = ('bicycle', 'bus', 'car', 'motorcycle', 'train', 'truck')\n        if source:\n            private_classes = ()\n        else:\n            private_classes = ('aeroplane', 'horse', 'knife', 'person', 'plant', 'skateboard')\n    else:\n        raise NotImplementedError(\"Unknown openset domain adaptation dataset: {}\".format(dataset_class.__name__))\n    return open_set(dataset_class, public_classes, private_classes)\n\n"
  },
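``default_open_set`` has no usage example of its own, so here is a short sketch of the protocol it encodes for Office-31: the first 20 classes stay public on both domains, and the remaining 11 become "unknown" on the target side only (paths are illustrative)::

    from tllib.vision.datasets.openset import Office31, default_open_set

    SourceOffice31 = default_open_set(Office31, source=True)
    TargetOffice31 = default_open_set(Office31, source=False)
    source = SourceOffice31(root='data/office31', task='A', download=True)
    target = TargetOffice31(root='data/office31', task='W', download=True)
    print(len(source.classes), source.classes[-1])   # 21, 'unknown'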
  {
    "path": "tllib/vision/datasets/oxfordflowers.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass OxfordFlowers102(ImageList):\n    \"\"\"\n    `The Oxford Flowers 102 <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ is a \\\n         consistent of 102 flower categories commonly occurring in the United Kingdom. \\\n         Each class consists of between 40 and 258 images. The images have large scale, \\\n         pose and light variations. In addition, there are categories that have large \\\n         variations within the category and several very similar categories. \\\n         The dataset is divided into a training set, a validation set and a test set. \\\n         The training set and validation set each consist of 10 images per class \\\n         (totalling 1020 images each). \\\n         The test set consists of the remaining 6149 images (minimum 20 per class).\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/161c7b222d6745408201/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/59b6a3fa3dac4404aa3b/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/ec77da479dfb471982fb/?dl=1\")\n    ]\n    CLASSES = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold',\n               'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon',\n               \"colt's foot\", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower',\n               'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower',\n               'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers',\n               'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist',\n               'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort',\n               'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia',\n               'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion',\n               'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura',\n               'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone',\n               'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',\n               'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',\n               'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis',\n           
    'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen', 'watercress',\n               'canna lily', 'hippeastrum', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia',\n               'mallow', 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily']\n\n    def __init__(self, root, split='train', download=False, **kwargs):\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n        super(OxfordFlowers102, self).__init__(root, OxfordFlowers102.CLASSES,\n                                               os.path.join(root, 'image_list', '{}.txt'.format(split)), **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/oxfordpets.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass OxfordIIITPets(ImageList):\n    \"\"\"`The Oxford-IIIT Pets <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ \\\n    is a 37-category pet dataset with roughly 200 images for each class.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        sample_rate (int): The sampling rates to sample random ``training`` images for each category.\n            Choices include 100, 50, 30, 15. Default: 100.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            train/\n            test/\n            image_list/\n                train_100.txt\n                train_50.txt\n                train_30.txt\n                train_15.txt\n                test.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/8295cfba35b148529bc3/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/89e422c95cb54fb7b0cc/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/dbf7ac10e25b4262b8e5/?dl=1\"),\n    ]\n    image_list = {\n        \"train\": \"image_list/train_100.txt\",\n        \"train100\": \"image_list/train_100.txt\",\n        \"train50\": \"image_list/train_50.txt\",\n        \"train30\": \"image_list/train_30.txt\",\n        \"train15\": \"image_list/train_15.txt\",\n        \"test\": \"image_list/test.txt\",\n        \"test100\": \"image_list/test.txt\",\n    }\n    CLASSES = ['Abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'Bengal',\n               'Birman', 'Bombay', 'boxer', 'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel',\n               'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger',\n               'Maine_Coon', 'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug', 'Ragdoll',\n               'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'Siamese', 'Sphynx',\n               'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']\n\n    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,\n                 **kwargs):\n\n        if split == 'train':\n            list_name = 'train' + str(sample_rate)\n            assert list_name in self.image_list\n            data_list_file = os.path.join(root, self.image_list[list_name])\n        else:\n            data_list_file = os.path.join(root, self.image_list['test'])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, 
file_name), self.download_list))\n\n        super(OxfordIIITPets, self).__init__(root, OxfordIIITPets.CLASSES, data_list_file=data_list_file, **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/pacs.py",
    "content": "from typing import Optional\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass PACS(ImageList):\n    \"\"\"`PACS Dataset <https://domaingeneralization.github.io/#data>`_.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \\\n            ``'D'``: dslr and ``'W'``: webcam.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            art_painting/\n                dog/\n                    *.jpg\n                    ...\n            cartoon/\n            photo/\n            sketch\n            image_list/\n                art_painting.txt\n                cartoon.txt\n                photo.txt\n                sketch.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/603a1fea81f2415ab7e0/?dl=1\"),\n        (\"art_painting\", \"art_painting.tgz\", \"https://cloud.tsinghua.edu.cn/f/46684292e979402b8d87/?dl=1\"),\n        (\"cartoon\", \"cartoon.tgz\", \"https://cloud.tsinghua.edu.cn/f/7bfa413b34ec4f4fa384/?dl=1\"),\n        (\"photo\", \"photo.tgz\", \"https://cloud.tsinghua.edu.cn/f/45f71386a668475d8b42/?dl=1\"),\n        (\"sketch\", \"sketch.tgz\", \"https://cloud.tsinghua.edu.cn/f/4ba559535e4b4b6981e5/?dl=1\"),\n    ]\n    image_list = {\n        \"A\": \"image_list/art_painting_{}.txt\",\n        \"C\": \"image_list/cartoon_{}.txt\",\n        \"P\": \"image_list/photo_{}.txt\",\n        \"S\": \"image_list/sketch_{}.txt\"\n    }\n    CLASSES = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']\n\n    def __init__(self, root: str, task: str, split='all', download: Optional[bool] = True, **kwargs):\n        assert task in self.image_list\n        assert split in [\"train\", \"val\", \"all\", \"test\"]\n        if split == \"test\":\n            split = \"all\"\n        data_list_file = os.path.join(root, self.image_list[task].format(split))\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(PACS, self).__init__(root, PACS.CLASSES, data_list_file=data_list_file, target_transform=lambda x: x - 1,\n                                   **kwargs)\n\n    @classmethod\n    def domains(cls):\n        return list(cls.image_list.keys())\n"
  },
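# Usage sketch for PACS above (illustrative, not part of the repository).
# Assumption: "data/pacs" is a hypothetical local path. Note that split="test" is
# mapped to "all", and the built-in target_transform shifts labels to be 0-based.
from tllib.vision.datasets.pacs import PACS

print(PACS.domains())  # ['A', 'C', 'P', 'S']
art = PACS(root="data/pacs", task="A", split="train", download=True)
print(art.classes)  # ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']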
  {
    "path": "tllib/vision/datasets/partial/__init__.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom ..imagelist import ImageList\nfrom ..office31 import Office31\nfrom ..officehome import OfficeHome\nfrom ..visda2017 import VisDA2017\nfrom ..officecaltech import OfficeCaltech\nfrom .imagenet_caltech import ImageNetCaltech\nfrom .caltech_imagenet import CaltechImageNet\nfrom tllib.vision.datasets.partial.imagenet_caltech import ImageNetCaltech\nfrom typing import Sequence, ClassVar\n\n\n__all__ = ['Office31', 'OfficeHome', \"VisDA2017\", \"CaltechImageNet\", \"ImageNetCaltech\"]\n\n\ndef partial(dataset_class: ClassVar, partial_classes: Sequence[str]) -> ClassVar:\n    \"\"\"\n    Convert a dataset into its partial version.\n\n    In other words, those samples which doesn't belong to `partial_classes` will be discarded.\n    Yet `partial` will not change the label space of `dataset_class`.\n\n    Args:\n        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be partial.\n        partial_classes (sequence[str]): A sequence of which categories need to be kept in the partial dataset.\\\n            Each element of `partial_classes` must belong to the `classes` list of `dataset_class`.\n\n    Examples::\n\n    >>> partial_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']\n    >>> # create a partial dataset class\n    >>> PartialOffice31 = partial(Office31, partial_classes)\n    >>> # create an instance of the partial dataset\n    >>> dataset = PartialDataset(root=\"data/office31\", task=\"A\")\n\n    \"\"\"\n    if not (issubclass(dataset_class, ImageList)):\n        raise Exception(\"Only subclass of ImageList can be partial\")\n\n    class PartialDataset(dataset_class):\n        def __init__(self, **kwargs):\n            super(PartialDataset, self).__init__(**kwargs)\n            assert all([c in self.classes for c in partial_classes])\n            samples = []\n            for (path, label) in self.samples:\n                class_name = self.classes[label]\n                if class_name in partial_classes:\n                    samples.append((path, label))\n            self.samples = samples\n            self.partial_classes = partial_classes\n            self.partial_classes_idx = [self.class_to_idx[c] for c in partial_classes]\n\n    return PartialDataset\n\n\ndef default_partial(dataset_class: ClassVar) -> ClassVar:\n    \"\"\"\n    Default partial used in some paper.\n\n    Args:\n        dataset_class (class): Dataset class. Currently, dataset_class must be one of\n            :class:`~tllib.vision.datasets.office31.Office31`, :class:`~tllib.vision.datasets.officehome.OfficeHome`,\n            :class:`~tllib.vision.datasets.visda2017.VisDA2017`,\n            :class:`~tllib.vision.datasets.partial.imagenet_caltech.ImageNetCaltech`\n            and :class:`~tllib.vision.datasets.partial.caltech_imagenet.CaltechImageNet`.\n    \"\"\"\n    if dataset_class == Office31:\n        kept_classes = OfficeCaltech.CLASSES\n    elif dataset_class == OfficeHome:\n        kept_classes = sorted(OfficeHome.CLASSES)[:25]\n    elif dataset_class == VisDA2017:\n        kept_classes = sorted(VisDA2017.CLASSES)[:6]\n    elif dataset_class in [ImageNetCaltech, CaltechImageNet]:\n        kept_classes = dataset_class.CLASSES\n    else:\n        raise NotImplementedError(\"Unknown partial domain adaptation dataset: {}\".format(dataset_class.__name__))\n    return partial(dataset_class, kept_classes)"
  },
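# Usage sketch for default_partial above, complementing the docstring example of
# partial() (illustrative, not part of the repository). Assumptions: the OfficeHome
# task name "Ar" and the root path come from the OfficeHome dataset, not this file;
# keyword arguments are required because PartialDataset.__init__ only accepts **kwargs.
from tllib.vision.datasets.officehome import OfficeHome
from tllib.vision.datasets.partial import default_partial

# for OfficeHome, the first 25 classes in alphabetical order are kept
PartialOfficeHome = default_partial(OfficeHome)
dataset = PartialOfficeHome(root="data/office-home", task="Ar", download=True)
print(len(dataset.partial_classes))     # 25
print(dataset.partial_classes_idx[:3])  # label indices of the kept classes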
  {
    "path": "tllib/vision/datasets/partial/caltech_imagenet.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional\nimport os\nfrom ..imagelist import ImageList\nfrom .._util import download as download_data, check_exits\n\n_CLASSES = ['ak47', 'american flag', 'backpack', 'baseball bat', 'baseball glove', 'basketball hoop', 'bat',\n           'bathtub', 'bear', 'beer mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai 101',\n           'boom box', 'bowling ball', 'bowling pin', 'boxing glove', 'brain 101', 'breadmaker', 'buddha 101',\n           'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car tire',\n           'cartman', 'cd', 'centipede', 'cereal box', 'chandelier 101', 'chess board', 'chimp', 'chopsticks',\n           'cockroach', 'coffee mug', 'coffin', 'coin', 'comet', 'computer keyboard', 'computer monitor',\n           'computer mouse', 'conch', 'cormorant', 'covered wagon', 'cowboy hat', 'crab 101', 'desk globe',\n           'diamond ring', 'dice', 'dog', 'dolphin 101', 'doorknob', 'drinking straw', 'duck', 'dumb bell',\n           'eiffel tower', 'electric guitar 101', 'elephant 101', 'elk', 'ewer 101', 'eyeglasses', 'fern',\n           'fighter jet', 'fire extinguisher', 'fire hydrant', 'fire truck', 'fireworks', 'flashlight',\n           'floppy disk', 'football helmet', 'french horn', 'fried egg', 'frisbee', 'frog', 'frying pan',\n           'galaxy', 'gas pump', 'giraffe', 'goat', 'golden gate bridge', 'goldfish', 'golf ball', 'goose',\n           'gorilla', 'grand piano 101', 'grapes', 'grasshopper', 'guitar pick', 'hamburger', 'hammock',\n           'harmonica', 'harp', 'harpsichord', 'hawksbill 101', 'head phones', 'helicopter 101', 'hibiscus',\n           'homer simpson', 'horse', 'horseshoe crab', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass',\n           'house fly', 'human skeleton', 'hummingbird', 'ibis 101', 'ice cream cone', 'iguana', 'ipod',\n           'iris', 'jesus christ', 'joy stick', 'kangaroo 101', 'kayak', 'ketch 101', 'killer whale', 'knife',\n           'ladder', 'laptop 101', 'lathe', 'leopards 101', 'license plate', 'lightbulb', 'light house',\n           'lightning', 'llama 101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah 101',\n           'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes 101', 'mountain bike', 'mushroom',\n           'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm pilot', 'palm tree', 'paperclip',\n           'paper shredder', 'pci card', 'penguin', 'people', 'pez dispenser', 'photocopier', 'picnic table',\n           'playing card', 'porcupine', 'pram', 'praying mantis', 'pyramid', 'raccoon', 'radio telescope',\n           'rainbow', 'refrigerator', 'revolver 101', 'rifle', 'rotary phone', 'roulette wheel', 'saddle',\n           'saturn', 'school bus', 'scorpion 101', 'screwdriver', 'segway', 'self propelled lawn mower',\n           'sextant', 'sheet music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake',\n           'sneaker', 'snowmobile', 'soccer ball', 'socks', 'soda can', 'spaghetti', 'speed boat', 'spider',\n           'spoon', 'stained glass', 'starfish 101', 'steering wheel', 'stirrups', 'sunflower 101', 'superman',\n           'sushi', 'swan', 'swiss army knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy bear',\n           'teepee', 'telephone box', 'tennis ball', 'tennis court', 'tennis racket', 'theodolite', 'toaster',\n           'tomato', 'tombstone', 'top hat', 'touring bike', 
'tower pisa', 'traffic light', 'treadmill',\n           'triceratops', 'tricycle', 'trilobite 101', 'tripod', 't shirt', 'tuning fork', 'tweezer',\n           'umbrella 101', 'unicorn', 'vcr', 'video projector', 'washing machine', 'watch 101', 'waterfall',\n           'watermelon', 'welding mask', 'wheelbarrow', 'windmill', 'wine bottle', 'xylophone', 'yarmulke',\n           'yo yo', 'zebra', 'airplanes 101', 'car side 101', 'faces easy 101', 'greyhound', 'tennis shoes',\n           'toad']\n\n# archives shared by both dataset classes below (defined before use);\n# ImageNet-1K itself must be placed under `root` manually\ndownload_list = [\n    (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/a0d7ea37026946f98965/?dl=1\"),\n    (\"256_ObjectCategories\", \"256_ObjectCategories.tar\",\n     \"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar\"),\n]\n\n\nclass CaltechImageNet(ImageList):\n    \"\"\"Caltech-ImageNet is constructed from `Caltech-256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ and\n    `ImageNet-1K <http://image-net.org/>`_.\n\n    They share 84 common classes. Caltech-ImageNet keeps all classes of Caltech-256.\n    Labels follow Caltech-256 (classes 0-255). The private classes of ImageNet-1K are discarded.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'C'``: Caltech-256 and \\\n            ``'I'``: the ImageNet-1K validation set.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that takes in a PIL image and returns a \\\n            transformed version. E.g., :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: You need to put the ``train`` and ``val`` directories of ImageNet-1K manually in the `root` directory\n        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and the image lists\n        automatically. In `root`, the following files will exist after downloading.\n        ::\n            train/\n                n01440764/\n                ...\n            val/\n            256_ObjectCategories/\n                001.ak47/\n                ...\n            image_list/\n                caltech_256_list.txt\n                ...\n    \"\"\"\n    image_list = {\n        \"C\": \"image_list/caltech_256_list.txt\",\n        \"I\": \"image_list/imagenet_val_84_list.txt\",\n    }\n    CLASSES = _CLASSES\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):\n        assert task in self.image_list\n        data_list_file = os.path.join(root, self.image_list[task])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), download_list))\n        else:\n            # each item of download_list is (directory, archive, url); only the directory is checked\n            list(map(lambda args: check_exits(root, args[0]), download_list))\n\n        if not os.path.exists(os.path.join(root, 'val')):\n            print(\"Please put the train and val directories of ImageNet-1K manually under {} \"\n                  \"since ImageNet-1K is no longer publicly accessible.\".format(root))\n            exit(-1)\n\n        super(CaltechImageNet, self).__init__(root, CaltechImageNet.CLASSES, data_list_file=data_list_file, **kwargs)\n\n\nclass CaltechImageNetUniversal(ImageList):\n    \"\"\"Caltech-ImageNet-Universal is constructed from `Caltech-256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_\n        and `ImageNet-1K <http://image-net.org/>`_.\n\n        They share 84 common classes. Caltech-ImageNet-Universal keeps all classes of Caltech-256.\n        Labels follow Caltech-256 (classes 0-255). The private classes of ImageNet-1K are grouped into class 256\n        (\"unknown\"). Thus, CaltechImageNetUniversal has 257 classes in total.\n\n        Args:\n            root (str): Root directory of dataset\n            task (str): The task (domain) to create dataset. Choices include ``'C'``: Caltech-256 and \\\n                ``'I'``: the ImageNet-1K validation set.\n            download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n                in root directory. If dataset is already downloaded, it is not downloaded again.\n            transform (callable, optional): A function/transform that takes in a PIL image and returns a \\\n                transformed version. E.g., :class:`torchvision.transforms.RandomCrop`.\n            target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n        .. note:: You need to put the ``train`` and ``val`` directories of ImageNet-1K manually in the `root` directory\n            since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and the image lists\n            automatically. In `root`, the following files will exist after downloading.\n            ::\n                train/\n                    n01440764/\n                    ...\n                val/\n                256_ObjectCategories/\n                    001.ak47/\n                    ...\n                image_list/\n                    caltech_256_list.txt\n                    ...\n        \"\"\"\n    image_list = {\n        \"C\": \"image_list/caltech_256_list.txt\",\n        \"I\": \"image_list/imagenet_val_85_list.txt\",\n    }\n    CLASSES = _CLASSES + ['unknown']\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):\n        assert task in self.image_list\n        data_list_file = os.path.join(root, self.image_list[task])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), download_list))\n        else:\n            # each item of download_list is (directory, archive, url); only the directory is checked\n            list(map(lambda args: check_exits(root, args[0]), download_list))\n\n        if not os.path.exists(os.path.join(root, 'val')):\n            print(\"Please put the train and val directories of ImageNet-1K manually under {} \"\n                  \"since ImageNet-1K is no longer publicly accessible.\".format(root))\n            exit(-1)\n\n        super(CaltechImageNetUniversal, self).__init__(root, CaltechImageNetUniversal.CLASSES,\n                                                       data_list_file=data_list_file, **kwargs)\n"
  },
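# Usage sketch for the two classes above (illustrative, not part of the repository).
# Assumption: "data/caltech_imagenet" is a hypothetical local path; as the docstrings
# note, only Caltech-256 and the image lists download automatically, so the ImageNet-1K
# train/ and val/ directories must already sit under root.
from tllib.vision.datasets.partial.caltech_imagenet import CaltechImageNet, CaltechImageNetUniversal

# 'C' is the Caltech-256 domain, 'I' the ImageNet-1K validation split
source = CaltechImageNet(root="data/caltech_imagenet", task="C", download=True)
target = CaltechImageNetUniversal(root="data/caltech_imagenet", task="I", download=False)
print(len(source.classes), len(target.classes))  # 256 257 ('unknown' is appended)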
  {
    "path": "tllib/vision/datasets/partial/imagenet_caltech.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom ..imagelist import ImageList\nfrom .._util import download as download_data, check_exits\n\n\n_CLASSES = [c[0] for c in [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),\n                          ('great white shark', 'white shark', 'man-eater', 'man-eating shark', 'Carcharodon carcharias'),\n                          ('tiger shark', 'Galeocerdo cuvieri'), ('hammerhead', 'hammerhead shark'),\n                          ('electric ray', 'crampfish', 'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),\n                          ('ostrich', 'Struthio camelus'), ('brambling', 'Fringilla montifringilla'),\n                          ('goldfinch', 'Carduelis carduelis'), ('house finch', 'linnet', 'Carpodacus mexicanus'),\n                          ('junco', 'snowbird'), ('indigo bunting', 'indigo finch', 'indigo bird', 'Passerina cyanea'),\n                          ('robin', 'American robin', 'Turdus migratorius'), ('bulbul',), ('jay',), ('magpie',),\n                          ('chickadee',), ('water ouzel', 'dipper'), ('kite',),\n                          ('bald eagle', 'American eagle', 'Haliaeetus leucocephalus'), ('vulture',),\n                          ('great grey owl', 'great gray owl', 'Strix nebulosa'),\n                          ('European fire salamander', 'Salamandra salamandra'), ('common newt', 'Triturus vulgaris'),\n                          ('eft',), ('spotted salamander', 'Ambystoma maculatum'),\n                          ('axolotl', 'mud puppy', 'Ambystoma mexicanum'), ('bullfrog', 'Rana catesbeiana'),\n                          ('tree frog', 'tree-frog'),\n                          ('tailed frog', 'bell toad', 'ribbed toad', 'tailed toad', 'Ascaphus trui'),\n                          ('loggerhead', 'loggerhead turtle', 'Caretta caretta'),\n                          ('leatherback turtle', 'leatherback', 'leathery turtle', 'Dermochelys coriacea'),\n                          ('mud turtle',), ('terrapin',), ('box turtle', 'box tortoise'), ('banded gecko',),\n                          ('common iguana', 'iguana', 'Iguana iguana'),\n                          ('American chameleon', 'anole', 'Anolis carolinensis'), ('whiptail', 'whiptail lizard'),\n                          ('agama',), ('frilled lizard', 'Chlamydosaurus kingi'), ('alligator lizard',),\n                          ('Gila monster', 'Heloderma suspectum'), ('green lizard', 'Lacerta viridis'),\n                          ('African chameleon', 'Chamaeleo chamaeleon'),\n                          ('Komodo dragon', 'Komodo lizard', 'dragon lizard', 'giant lizard', 'Varanus komodoensis'),\n                          ('African crocodile', 'Nile crocodile', 'Crocodylus niloticus'),\n                          ('American alligator', 'Alligator mississipiensis'), ('triceratops',),\n                          ('thunder snake', 'worm snake', 'Carphophis amoenus'),\n                          ('ringneck snake', 'ring-necked snake', 'ring snake'),\n                          ('hognose snake', 'puff adder', 'sand viper'), ('green snake', 'grass snake'),\n                          ('king snake', 'kingsnake'), ('garter snake', 'grass snake'), ('water snake',),\n                          ('vine snake',), ('night snake', 'Hypsiglena torquata'),\n                          ('boa constrictor', 'Constrictor constrictor'), ('rock python', 'rock snake', 'Python sebae'),\n                        
  ('Indian cobra', 'Naja naja'), ('green mamba',), ('sea snake',),\n                          ('horned viper', 'cerastes', 'sand viper', 'horned asp', 'Cerastes cornutus'),\n                          ('diamondback', 'diamondback rattlesnake', 'Crotalus adamanteus'),\n                          ('sidewinder', 'horned rattlesnake', 'Crotalus cerastes'), ('trilobite',),\n                          ('harvestman', 'daddy longlegs', 'Phalangium opilio'), ('scorpion',),\n                          ('black and gold garden spider', 'Argiope aurantia'), ('barn spider', 'Araneus cavaticus'),\n                          ('garden spider', 'Aranea diademata'), ('black widow', 'Latrodectus mactans'), ('tarantula',),\n                          ('wolf spider', 'hunting spider'), ('tick',), ('centipede',), ('black grouse',),\n                          ('ptarmigan',), ('ruffed grouse', 'partridge', 'Bonasa umbellus'),\n                          ('prairie chicken', 'prairie grouse', 'prairie fowl'), ('peacock',), ('quail',),\n                          ('partridge',), ('African grey', 'African gray', 'Psittacus erithacus'), ('macaw',),\n                          ('sulphur-crested cockatoo', 'Kakatoe galerita', 'Cacatua galerita'), ('lorikeet',),\n                          ('coucal',), ('bee eater',), ('hornbill',), ('hummingbird',), ('jacamar',), ('toucan',),\n                          ('drake',), ('red-breasted merganser', 'Mergus serrator'), ('goose',),\n                          ('black swan', 'Cygnus atratus'), ('tusker',), ('echidna', 'spiny anteater', 'anteater'), (\n                              'platypus', 'duckbill', 'duckbilled platypus', 'duck-billed platypus',\n                              'Ornithorhynchus anatinus'), ('wallaby', 'brush kangaroo'),\n                          ('koala', 'koala bear', 'kangaroo bear', 'native bear', 'Phascolarctos cinereus'), ('wombat',),\n                          ('jellyfish',), ('sea anemone', 'anemone'), ('brain coral',), ('flatworm', 'platyhelminth'),\n                          ('nematode', 'nematode worm', 'roundworm'), ('conch',), ('snail',), ('slug',),\n                          ('sea slug', 'nudibranch'), ('chiton', 'coat-of-mail shell', 'sea cradle', 'polyplacophore'),\n                          ('chambered nautilus', 'pearly nautilus', 'nautilus'), ('Dungeness crab', 'Cancer magister'),\n                          ('rock crab', 'Cancer irroratus'), ('fiddler crab',), (\n                              'king crab', 'Alaska crab', 'Alaskan king crab', 'Alaska king crab',\n                              'Paralithodes camtschatica'),\n                          ('American lobster', 'Northern lobster', 'Maine lobster', 'Homarus americanus'),\n                          ('spiny lobster', 'langouste', 'rock lobster', 'crawfish', 'crayfish', 'sea crawfish'),\n                          ('crayfish', 'crawfish', 'crawdad', 'crawdaddy'), ('hermit crab',), ('isopod',),\n                          ('white stork', 'Ciconia ciconia'), ('black stork', 'Ciconia nigra'), ('spoonbill',),\n                          ('flamingo',), ('little blue heron', 'Egretta caerulea'),\n                          ('American egret', 'great white heron', 'Egretta albus'), ('bittern',), ('crane',),\n                          ('limpkin', 'Aramus pictus'), ('European gallinule', 'Porphyrio porphyrio'),\n                          ('American coot', 'marsh hen', 'mud hen', 'water hen', 'Fulica americana'), ('bustard',),\n                          ('ruddy turnstone', 'Arenaria interpres'),\n                          ('red-backed sandpiper', 'dunlin', 'Erolia alpina'), ('redshank', 'Tringa totanus'),\n                          ('dowitcher',), ('oystercatcher', 'oyster catcher'), ('pelican',),\n                          ('king penguin', 'Aptenodytes patagonica'), ('albatross', 'mollymawk'),\n                          ('grey whale', 'gray whale', 'devilfish', 'Eschrichtius gibbosus', 'Eschrichtius robustus'),\n                          ('killer whale', 'killer', 'orca', 'grampus', 'sea wolf', 'Orcinus orca'),\n                          ('dugong', 'Dugong dugon'), ('sea lion',), ('Chihuahua',), ('Japanese spaniel',),\n                          ('Maltese dog', 'Maltese terrier', 'Maltese'), ('Pekinese', 'Pekingese', 'Peke'),\n                          ('Shih-Tzu',), ('Blenheim spaniel',), ('papillon',), ('toy terrier',),\n                          ('Rhodesian ridgeback',), ('Afghan hound', 'Afghan'), ('basset', 'basset hound'), ('beagle',),\n                          ('bloodhound', 'sleuthhound'), ('bluetick',), ('black-and-tan coonhound',),\n                          ('Walker hound', 'Walker foxhound'), ('English foxhound',), ('redbone',),\n                          ('borzoi', 'Russian wolfhound'), ('Irish wolfhound',), ('Italian greyhound',), ('whippet',),\n                          ('Ibizan hound', 'Ibizan Podenco'), ('Norwegian elkhound', 'elkhound'),\n                          ('otterhound', 'otter hound'), ('Saluki', 'gazelle hound'),\n                          ('Scottish deerhound', 'deerhound'), ('Weimaraner',),\n                          ('Staffordshire bullterrier', 'Staffordshire bull terrier'), (\n                              'American Staffordshire terrier', 'Staffordshire terrier', 'American pit bull terrier',\n                              'pit bull terrier'), ('Bedlington terrier',), ('Border terrier',),\n                          ('Kerry blue terrier',),\n                          ('Irish terrier',), ('Norfolk terrier',), ('Norwich terrier',), ('Yorkshire terrier',),\n                          ('wire-haired fox terrier',), ('Lakeland terrier',), ('Sealyham terrier', 'Sealyham'),\n                          ('Airedale', 'Airedale terrier'), ('cairn', 'cairn terrier'), ('Australian terrier',),\n                          ('Dandie Dinmont', 'Dandie Dinmont terrier'), ('Boston bull', 'Boston terrier'),\n                          ('miniature schnauzer',), ('giant schnauzer',), ('standard schnauzer',),\n                          ('Scotch terrier', 'Scottish terrier', 'Scottie'), ('Tibetan terrier', 'chrysanthemum dog'),\n                          ('silky terrier', 'Sydney silky'), ('soft-coated wheaten terrier',),\n                          ('West Highland white terrier',), ('Lhasa', 'Lhasa apso'), ('flat-coated retriever',),\n                          ('curly-coated retriever',), ('golden retriever',), ('Labrador retriever',),\n                          ('Chesapeake Bay retriever',), ('German short-haired pointer',),\n                          ('vizsla', 'Hungarian pointer'), ('English setter',), ('Irish setter', 'red setter'),\n                          ('Gordon setter',), ('Brittany spaniel',), ('clumber', 'clumber spaniel'),\n                          ('English springer', 'English springer spaniel'), ('Welsh springer spaniel',),\n                          ('cocker spaniel', 'English cocker spaniel', 'cocker'), ('Sussex spaniel',),\n                          ('Irish water spaniel',), ('kuvasz',), ('schipperke',), ('groenendael',), ('malinois',),\n                          ('briard',), 
('kelpie',), ('komondor',), ('Old English sheepdog', 'bobtail'),\n                          ('Shetland sheepdog', 'Shetland sheep dog', 'Shetland'), ('collie',), ('Border collie',),\n                          ('Bouvier des Flandres', 'Bouviers des Flandres'), ('Rottweiler',),\n                          ('German shepherd', 'German shepherd dog', 'German police dog', 'alsatian'),\n                          ('Doberman', 'Doberman pinscher'), ('miniature pinscher',), ('Greater Swiss Mountain dog',),\n                          ('Bernese mountain dog',), ('Appenzeller',), ('EntleBucher',), ('boxer',), ('bull mastiff',),\n                          ('Tibetan mastiff',), ('French bulldog',), ('Great Dane',), ('Saint Bernard', 'St Bernard'),\n                          ('Eskimo dog', 'husky'), ('malamute', 'malemute', 'Alaskan malamute'), ('Siberian husky',),\n                          ('dalmatian', 'coach dog', 'carriage dog'),\n                          ('affenpinscher', 'monkey pinscher', 'monkey dog'), ('basenji',), ('pug', 'pug-dog'),\n                          ('Leonberg',), ('Newfoundland', 'Newfoundland dog'), ('Great Pyrenees',),\n                          ('Samoyed', 'Samoyede'), ('Pomeranian',), ('chow', 'chow chow'), ('keeshond',),\n                          ('Brabancon griffon',), ('Pembroke', 'Pembroke Welsh corgi'),\n                          ('Cardigan', 'Cardigan Welsh corgi'), ('toy poodle',), ('miniature poodle',),\n                          ('standard poodle',), ('Mexican hairless',),\n                          ('timber wolf', 'grey wolf', 'gray wolf', 'Canis lupus'),\n                          ('white wolf', 'Arctic wolf', 'Canis lupus tundrarum'),\n                          ('red wolf', 'maned wolf', 'Canis rufus', 'Canis niger'),\n                          ('coyote', 'prairie wolf', 'brush wolf', 'Canis latrans'),\n                          ('dingo', 'warrigal', 'warragal', 'Canis dingo'), ('dhole', 'Cuon alpinus'),\n                          ('African hunting dog', 'hyena dog', 'Cape hunting dog', 'Lycaon pictus'),\n                          ('hyena', 'hyaena'), ('red fox', 'Vulpes vulpes'), ('kit fox', 'Vulpes macrotis'),\n                          ('Arctic fox', 'white fox', 'Alopex lagopus'),\n                          ('grey fox', 'gray fox', 'Urocyon cinereoargenteus'), ('tabby', 'tabby cat'), ('tiger cat',),\n                          ('Persian cat',), ('Siamese cat', 'Siamese'), ('Egyptian cat',),\n                          ('cougar', 'puma', 'catamount', 'mountain lion', 'painter', 'panther', 'Felis concolor'),\n                          ('lynx', 'catamount'), ('leopard', 'Panthera pardus'),\n                          ('snow leopard', 'ounce', 'Panthera uncia'),\n                          ('jaguar', 'panther', 'Panthera onca', 'Felis onca'),\n                          ('lion', 'king of beasts', 'Panthera leo'), ('tiger', 'Panthera tigris'),\n                          ('cheetah', 'chetah', 'Acinonyx jubatus'), ('brown bear', 'bruin', 'Ursus arctos'),\n                          ('American black bear', 'black bear', 'Ursus americanus', 'Euarctos americanus'),\n                          ('ice bear', 'polar bear', 'Ursus Maritimus', 'Thalarctos maritimus'),\n                          ('sloth bear', 'Melursus ursinus', 'Ursus ursinus'), ('mongoose',), ('meerkat', 'mierkat'),\n                          ('tiger beetle',), ('ladybug', 'ladybeetle', 'lady beetle', 'ladybird', 'ladybird beetle'),\n                          ('ground beetle', 'carabid beetle'), ('long-horned 
beetle', 'longicorn', 'longicorn beetle'),\n                          ('leaf beetle', 'chrysomelid'), ('dung beetle',), ('rhinoceros beetle',), ('weevil',),\n                          ('fly',), ('bee',), ('ant', 'emmet', 'pismire'), ('grasshopper', 'hopper'), ('cricket',),\n                          ('walking stick', 'walkingstick', 'stick insect'), ('cockroach', 'roach'),\n                          ('mantis', 'mantid'), ('cicada', 'cicala'), ('leafhopper',), ('lacewing', 'lacewing fly'), (\n                              'dragonfly', 'darning needle', \"devil's darning needle\", 'sewing needle', 'snake feeder',\n                              'snake doctor', 'mosquito hawk', 'skeeter hawk'), ('damselfly',), ('admiral',),\n                          ('ringlet', 'ringlet butterfly'),\n                          ('monarch', 'monarch butterfly', 'milkweed butterfly', 'Danaus plexippus'),\n                          ('cabbage butterfly',), ('sulphur butterfly', 'sulfur butterfly'),\n                          ('lycaenid', 'lycaenid butterfly'), ('starfish', 'sea star'), ('sea urchin',),\n                          ('sea cucumber', 'holothurian'), ('wood rabbit', 'cottontail', 'cottontail rabbit'),\n                          ('hare',), ('Angora', 'Angora rabbit'), ('hamster',), ('porcupine', 'hedgehog'),\n                          ('fox squirrel', 'eastern fox squirrel', 'Sciurus niger'), ('marmot',), ('beaver',),\n                          ('guinea pig', 'Cavia cobaya'), ('sorrel',), ('zebra',),\n                          ('hog', 'pig', 'grunter', 'squealer', 'Sus scrofa'), ('wild boar', 'boar', 'Sus scrofa'),\n                          ('warthog',), ('hippopotamus', 'hippo', 'river horse', 'Hippopotamus amphibius'), ('ox',),\n                          ('water buffalo', 'water ox', 'Asiatic buffalo', 'Bubalus bubalis'), ('bison',),\n                          ('ram', 'tup'), (\n                              'bighorn', 'bighorn sheep', 'cimarron', 'Rocky Mountain bighorn', 'Rocky Mountain sheep',\n                              'Ovis canadensis'), ('ibex', 'Capra ibex'), ('hartebeest',),\n                          ('impala', 'Aepyceros melampus'),\n                          ('gazelle',), ('Arabian camel', 'dromedary', 'Camelus dromedarius'), ('llama',), ('weasel',),\n                          ('mink',), ('polecat', 'fitch', 'foulmart', 'foumart', 'Mustela putorius'),\n                          ('black-footed ferret', 'ferret', 'Mustela nigripes'), ('otter',),\n                          ('skunk', 'polecat', 'wood pussy'), ('badger',), ('armadillo',),\n                          ('three-toed sloth', 'ai', 'Bradypus tridactylus'),\n                          ('orangutan', 'orang', 'orangutang', 'Pongo pygmaeus'), ('gorilla', 'Gorilla gorilla'),\n                          ('chimpanzee', 'chimp', 'Pan troglodytes'), ('gibbon', 'Hylobates lar'),\n                          ('siamang', 'Hylobates syndactylus', 'Symphalangus syndactylus'), ('guenon', 'guenon monkey'),\n                          ('patas', 'hussar monkey', 'Erythrocebus patas'), ('baboon',), ('macaque',), ('langur',),\n                          ('colobus', 'colobus monkey'), ('proboscis monkey', 'Nasalis larvatus'), ('marmoset',),\n                          ('capuchin', 'ringtail', 'Cebus capucinus'), ('howler monkey', 'howler'),\n                          ('titi', 'titi monkey'), ('spider monkey', 'Ateles geoffroyi'),\n                          ('squirrel monkey', 'Saimiri sciureus'),\n                          ('Madagascar cat', 'ring-tailed 
lemur', 'Lemur catta'),\n                          ('indri', 'indris', 'Indri indri', 'Indri brevicaudatus'),\n                          ('Indian elephant', 'Elephas maximus'), ('African elephant', 'Loxodonta africana'),\n                          ('lesser panda', 'red panda', 'panda', 'bear cat', 'cat bear', 'Ailurus fulgens'),\n                          ('giant panda', 'panda', 'panda bear', 'coon bear', 'Ailuropoda melanoleuca'),\n                          ('barracouta', 'snoek'), ('eel',),\n                          ('coho', 'cohoe', 'coho salmon', 'blue jack', 'silver salmon', 'Oncorhynchus kisutch'),\n                          ('rock beauty', 'Holocanthus tricolor'), ('anemone fish',), ('sturgeon',),\n                          ('gar', 'garfish', 'garpike', 'billfish', 'Lepisosteus osseus'), ('lionfish',),\n                          ('puffer', 'pufferfish', 'blowfish', 'globefish'), ('abacus',), ('abaya',),\n                          ('academic gown', 'academic robe', \"judge's robe\"),\n                          ('accordion', 'piano accordion', 'squeeze box'), ('acoustic guitar',),\n                          ('aircraft carrier', 'carrier', 'flattop', 'attack aircraft carrier'), ('airliner',),\n                          ('airship', 'dirigible'), ('altar',), ('ambulance',), ('amphibian', 'amphibious vehicle'),\n                          ('analog clock',), ('apiary', 'bee house'), ('apron',), (\n                              'ashcan', 'trash can', 'garbage can', 'wastebin', 'ash bin', 'ash-bin', 'ashbin',\n                              'dustbin',\n                              'trash barrel', 'trash bin'), ('assault rifle', 'assault gun'),\n                          ('backpack', 'back pack', 'knapsack', 'packsack', 'rucksack', 'haversack'),\n                          ('bakery', 'bakeshop', 'bakehouse'), ('balance beam', 'beam'), ('balloon',),\n                          ('ballpoint', 'ballpoint pen', 'ballpen', 'Biro'), ('Band Aid',), ('banjo',),\n                          ('bannister', 'banister', 'balustrade', 'balusters', 'handrail'), ('barbell',),\n                          ('barber chair',), ('barbershop',), ('barn',), ('barometer',), ('barrel', 'cask'),\n                          ('barrow', 'garden cart', 'lawn cart', 'wheelbarrow'), ('baseball',), ('basketball',),\n                          ('bassinet',), ('bassoon',), ('bathing cap', 'swimming cap'), ('bath towel',),\n                          ('bathtub', 'bathing tub', 'bath', 'tub'), (\n                              'beach wagon', 'station wagon', 'wagon', 'estate car', 'beach waggon', 'station waggon',\n                              'waggon'), ('beacon', 'lighthouse', 'beacon light', 'pharos'), ('beaker',),\n                          ('bearskin', 'busby', 'shako'), ('beer bottle',), ('beer glass',), ('bell cote', 'bell cot'),\n                          ('bib',), ('bicycle-built-for-two', 'tandem bicycle', 'tandem'), ('bikini', 'two-piece'),\n                          ('binder', 'ring-binder'), ('binoculars', 'field glasses', 'opera glasses'), ('birdhouse',),\n                          ('boathouse',), ('bobsled', 'bobsleigh', 'bob'), ('bolo tie', 'bolo', 'bola tie', 'bola'),\n                          ('bonnet', 'poke bonnet'), ('bookcase',), ('bookshop', 'bookstore', 'bookstall'),\n                          ('bottlecap',), ('bow',), ('bow tie', 'bow-tie', 'bowtie'),\n                          ('brass', 'memorial tablet', 'plaque'), ('brassiere', 'bra', 'bandeau'),\n                          ('breakwater', 'groin', 'groyne', 
'mole', 'bulwark', 'seawall', 'jetty'),\n                          ('breastplate', 'aegis', 'egis'), ('broom',), ('bucket', 'pail'), ('buckle',),\n                          ('bulletproof vest',), ('bullet train', 'bullet'), ('butcher shop', 'meat market'),\n                          ('cab', 'hack', 'taxi', 'taxicab'), ('caldron', 'cauldron'), ('candle', 'taper', 'wax light'),\n                          ('cannon',), ('canoe',), ('can opener', 'tin opener'), ('cardigan',), ('car mirror',),\n                          ('carousel', 'carrousel', 'merry-go-round', 'roundabout', 'whirligig'),\n                          (\"carpenter's kit\", 'tool kit'), ('carton',), ('car wheel',), (\n                              'cash machine', 'cash dispenser', 'automated teller machine', 'automatic teller machine',\n                              'automated teller', 'automatic teller', 'ATM'), ('cassette',), ('cassette player',),\n                          ('castle',), ('catamaran',), ('CD player',), ('cello', 'violoncello'),\n                          ('cellular telephone', 'cellular phone', 'cellphone', 'cell', 'mobile phone'), ('chain',),\n                          ('chainlink fence',), (\n                              'chain mail', 'ring mail', 'mail', 'chain armor', 'chain armour', 'ring armor',\n                              'ring armour'), ('chain saw', 'chainsaw'), ('chest',), ('chiffonier', 'commode'),\n                          ('chime', 'bell', 'gong'), ('china cabinet', 'china closet'), ('Christmas stocking',),\n                          ('church', 'church building'),\n                          ('cinema', 'movie theater', 'movie theatre', 'movie house', 'picture palace'),\n                          ('cleaver', 'meat cleaver', 'chopper'), ('cliff dwelling',), ('cloak',),\n                          ('clog', 'geta', 'patten', 'sabot'), ('cocktail shaker',), ('coffee mug',), ('coffeepot',),\n                          ('coil', 'spiral', 'volute', 'whorl', 'helix'), ('combination lock',),\n                          ('computer keyboard', 'keypad'), ('confectionery', 'confectionary', 'candy store'),\n                          ('container ship', 'containership', 'container vessel'), ('convertible',),\n                          ('corkscrew', 'bottle screw'), ('cornet', 'horn', 'trumpet', 'trump'), ('cowboy boot',),\n                          ('cowboy hat', 'ten-gallon hat'), ('cradle',), ('crane',), ('crash helmet',), ('crate',),\n                          ('crib', 'cot'), ('Crock Pot',), ('croquet ball',), ('crutch',), ('cuirass',),\n                          ('dam', 'dike', 'dyke'), ('desk',), ('desktop computer',), ('dial telephone', 'dial phone'),\n                          ('diaper', 'nappy', 'napkin'), ('digital clock',), ('digital watch',),\n                          ('dining table', 'board'), ('dishrag', 'dishcloth'),\n                          ('dishwasher', 'dish washer', 'dishwashing machine'), ('disk brake', 'disc brake'),\n                          ('dock', 'dockage', 'docking facility'), ('dogsled', 'dog sled', 'dog sleigh'), ('dome',),\n                          ('doormat', 'welcome mat'), ('drilling platform', 'offshore rig'),\n                          ('drum', 'membranophone', 'tympan'), ('drumstick',), ('dumbbell',), ('Dutch oven',),\n                          ('electric fan', 'blower'), ('electric guitar',), ('electric locomotive',),\n                          ('entertainment center',), ('envelope',), ('espresso maker',), ('face powder',),\n                          ('feather boa', 'boa'), 
('file', 'file cabinet', 'filing cabinet'), ('fireboat',),\n                          ('fire engine', 'fire truck'), ('fire screen', 'fireguard'), ('flagpole', 'flagstaff'),\n                          ('flute', 'transverse flute'), ('folding chair',), ('football helmet',), ('forklift',),\n                          ('fountain',), ('fountain pen',), ('four-poster',), ('freight car',), ('French horn', 'horn'),\n                          ('frying pan', 'frypan', 'skillet'), ('fur coat',), ('garbage truck', 'dustcart'),\n                          ('gasmask', 'respirator', 'gas helmet'),\n                          ('gas pump', 'gasoline pump', 'petrol pump', 'island dispenser'), ('goblet',), ('go-kart',),\n                          ('golf ball',), ('golfcart', 'golf cart'), ('gondola',), ('gong', 'tam-tam'), ('gown',),\n                          ('grand piano', 'grand'), ('greenhouse', 'nursery', 'glasshouse'),\n                          ('grille', 'radiator grille'), ('grocery store', 'grocery', 'food market', 'market'),\n                          ('guillotine',), ('hair slide',), ('hair spray',), ('half track',), ('hammer',), ('hamper',),\n                          ('hand blower', 'blow dryer', 'blow drier', 'hair dryer', 'hair drier'),\n                          ('hand-held computer', 'hand-held microcomputer'),\n                          ('handkerchief', 'hankie', 'hanky', 'hankey'), ('hard disc', 'hard disk', 'fixed disk'),\n                          ('harmonica', 'mouth organ', 'harp', 'mouth harp'), ('harp',), ('harvester', 'reaper'),\n                          ('hatchet',), ('holster',), ('home theater', 'home theatre'), ('honeycomb',),\n                          ('hook', 'claw'), ('hoopskirt', 'crinoline'), ('horizontal bar', 'high bar'),\n                          ('horse cart', 'horse-cart'), ('hourglass',), ('iPod',), ('iron', 'smoothing iron'),\n                          (\"jack-o'-lantern\",), ('jean', 'blue jean', 'denim'), ('jeep', 'landrover'),\n                          ('jersey', 'T-shirt', 'tee shirt'), ('jigsaw puzzle',), ('jinrikisha', 'ricksha', 'rickshaw'),\n                          ('joystick',), ('kimono',), ('knee pad',), ('knot',), ('lab coat', 'laboratory coat'),\n                          ('ladle',), ('lampshade', 'lamp shade'), ('laptop', 'laptop computer'),\n                          ('lawn mower', 'mower'), ('lens cap', 'lens cover'),\n                          ('letter opener', 'paper knife', 'paperknife'), ('library',), ('lifeboat',),\n                          ('lighter', 'light', 'igniter', 'ignitor'), ('limousine', 'limo'), ('liner', 'ocean liner'),\n                          ('lipstick', 'lip rouge'), ('Loafer',), ('lotion',),\n                          ('loudspeaker', 'speaker', 'speaker unit', 'loudspeaker system', 'speaker system'),\n                          ('loupe', \"jeweler's loupe\"), ('lumbermill', 'sawmill'), ('magnetic compass',),\n                          ('mailbag', 'postbag'), ('mailbox', 'letter box'), ('maillot',), ('maillot', 'tank suit'),\n                          ('manhole cover',), ('maraca',), ('marimba', 'xylophone'), ('mask',), ('matchstick',),\n                          ('maypole',), ('maze', 'labyrinth'), ('measuring cup',),\n                          ('medicine chest', 'medicine cabinet'), ('megalith', 'megalithic structure'),\n                          ('microphone', 'mike'), ('microwave', 'microwave oven'), ('military uniform',), ('milk can',),\n                          ('minibus',), ('miniskirt', 'mini'), ('minivan',), 
('missile',), ('mitten',),\n                          ('mixing bowl',), ('mobile home', 'manufactured home'), ('Model T',), ('modem',),\n                          ('monastery',), ('monitor',), ('moped',), ('mortar',), ('mortarboard',), ('mosque',),\n                          ('mosquito net',), ('motor scooter', 'scooter'),\n                          ('mountain bike', 'all-terrain bike', 'off-roader'), ('mountain tent',),\n                          ('mouse', 'computer mouse'), ('mousetrap',), ('moving van',), ('muzzle',), ('nail',),\n                          ('neck brace',), ('necklace',), ('nipple',), ('notebook', 'notebook computer'), ('obelisk',),\n                          ('oboe', 'hautboy', 'hautbois'), ('ocarina', 'sweet potato'),\n                          ('odometer', 'hodometer', 'mileometer', 'milometer'), ('oil filter',),\n                          ('organ', 'pipe organ'), ('oscilloscope', 'scope', 'cathode-ray oscilloscope', 'CRO'),\n                          ('overskirt',), ('oxcart',), ('oxygen mask',), ('packet',), ('paddle', 'boat paddle'),\n                          ('paddlewheel', 'paddle wheel'), ('padlock',), ('paintbrush',),\n                          ('pajama', 'pyjama', \"pj's\", 'jammies'), ('palace',), ('panpipe', 'pandean pipe', 'syrinx'),\n                          ('paper towel',), ('parachute', 'chute'), ('parallel bars', 'bars'), ('park bench',),\n                          ('parking meter',), ('passenger car', 'coach', 'carriage'), ('patio', 'terrace'),\n                          ('pay-phone', 'pay-station'), ('pedestal', 'plinth', 'footstall'),\n                          ('pencil box', 'pencil case'), ('pencil sharpener',), ('perfume', 'essence'), ('Petri dish',),\n                          ('photocopier',), ('pick', 'plectrum', 'plectron'), ('pickelhaube',),\n                          ('picket fence', 'paling'), ('pickup', 'pickup truck'), ('pier',),\n                          ('piggy bank', 'penny bank'), ('pill bottle',), ('pillow',), ('ping-pong ball',),\n                          ('pinwheel',), ('pirate', 'pirate ship'), ('pitcher', 'ewer'),\n                          ('plane', \"carpenter's plane\", 'woodworking plane'), ('planetarium',), ('plastic bag',),\n                          ('plate rack',), ('plow', 'plough'), ('plunger', \"plumber's helper\"),\n                          ('Polaroid camera', 'Polaroid Land camera'), ('pole',),\n                          ('police van', 'police wagon', 'paddy wagon', 'patrol wagon', 'wagon', 'black Maria'),\n                          ('poncho',), ('pool table', 'billiard table', 'snooker table'), ('pop bottle', 'soda bottle'),\n                          ('pot', 'flowerpot'), (\"potter's wheel\",), ('power drill',), ('prayer rug', 'prayer mat'),\n                          ('printer',), ('prison', 'prison house'), ('projectile', 'missile'), ('projector',),\n                          ('puck', 'hockey puck'), ('punching bag', 'punch bag', 'punching ball', 'punchball'),\n                          ('purse',), ('quill', 'quill pen'), ('quilt', 'comforter', 'comfort', 'puff'),\n                          ('racer', 'race car', 'racing car'), ('racket', 'racquet'), ('radiator',),\n                          ('radio', 'wireless'), ('radio telescope', 'radio reflector'), ('rain barrel',),\n                          ('recreational vehicle', 'RV', 'R.V.'), ('reel',), ('reflex camera',),\n                          ('refrigerator', 'icebox'), ('remote control', 'remote'),\n                          ('restaurant', 'eating house', 
'eating place', 'eatery'),\n                          ('revolver', 'six-gun', 'six-shooter'), ('rifle',), ('rocking chair', 'rocker'),\n                          ('rotisserie',), ('rubber eraser', 'rubber', 'pencil eraser'), ('rugby ball',),\n                          ('rule', 'ruler'), ('running shoe',), ('safe',), ('safety pin',),\n                          ('saltshaker', 'salt shaker'), ('sandal',), ('sarong',), ('sax', 'saxophone'), ('scabbard',),\n                          ('scale', 'weighing machine'), ('school bus',), ('schooner',), ('scoreboard',),\n                          ('screen', 'CRT screen'), ('screw',), ('screwdriver',), ('seat belt', 'seatbelt'),\n                          ('sewing machine',), ('shield', 'buckler'), ('shoe shop', 'shoe-shop', 'shoe store'),\n                          ('shoji',), ('shopping basket',), ('shopping cart',), ('shovel',), ('shower cap',),\n                          ('shower curtain',), ('ski',), ('ski mask',), ('sleeping bag',), ('slide rule', 'slipstick'),\n                          ('sliding door',), ('slot', 'one-armed bandit'), ('snorkel',), ('snowmobile',),\n                          ('snowplow', 'snowplough'), ('soap dispenser',), ('soccer ball',), ('sock',),\n                          ('solar dish', 'solar collector', 'solar furnace'), ('sombrero',), ('soup bowl',),\n                          ('space bar',), ('space heater',), ('space shuttle',), ('spatula',), ('speedboat',),\n                          ('spider web', \"spider's web\"), ('spindle',), ('sports car', 'sport car'),\n                          ('spotlight', 'spot'), ('stage',), ('steam locomotive',), ('steel arch bridge',),\n                          ('steel drum',), ('stethoscope',), ('stole',), ('stone wall',), ('stopwatch', 'stop watch'),\n                          ('stove',), ('strainer',), ('streetcar', 'tram', 'tramcar', 'trolley', 'trolley car'),\n                          ('stretcher',), ('studio couch', 'day bed'), ('stupa', 'tope'),\n                          ('submarine', 'pigboat', 'sub', 'U-boat'), ('suit', 'suit of clothes'), ('sundial',),\n                          ('sunglass',), ('sunglasses', 'dark glasses', 'shades'),\n                          ('sunscreen', 'sunblock', 'sun blocker'), ('suspension bridge',), ('swab', 'swob', 'mop'),\n                          ('sweatshirt',), ('swimming trunks', 'bathing trunks'), ('swing',),\n                          ('switch', 'electric switch', 'electrical switch'), ('syringe',), ('table lamp',),\n                          ('tank', 'army tank', 'armored combat vehicle', 'armoured combat vehicle'), ('tape player',),\n                          ('teapot',), ('teddy', 'teddy bear'), ('television', 'television system'), ('tennis ball',),\n                          ('thatch', 'thatched roof'), ('theater curtain', 'theatre curtain'), ('thimble',),\n                          ('thresher', 'thrasher', 'threshing machine'), ('throne',), ('tile roof',), ('toaster',),\n                          ('tobacco shop', 'tobacconist shop', 'tobacconist'), ('toilet seat',), ('torch',),\n                          ('totem pole',), ('tow truck', 'tow car', 'wrecker'), ('toyshop',), ('tractor',),\n                          ('trailer truck', 'tractor trailer', 'trucking rig', 'rig', 'articulated lorry', 'semi'),\n                          ('tray',), ('trench coat',), ('tricycle', 'trike', 'velocipede'), ('trimaran',), ('tripod',),\n                          ('triumphal arch',), ('trolleybus', 'trolley coach', 'trackless trolley'), ('trombone',),\n  
                        ('tub', 'vat'), ('turnstile',), ('typewriter keyboard',), ('umbrella',),\n                          ('unicycle', 'monocycle'), ('upright', 'upright piano'), ('vacuum', 'vacuum cleaner'),\n                          ('vase',), ('vault',), ('velvet',), ('vending machine',), ('vestment',), ('viaduct',),\n                          ('violin', 'fiddle'), ('volleyball',), ('waffle iron',), ('wall clock',),\n                          ('wallet', 'billfold', 'notecase', 'pocketbook'), ('wardrobe', 'closet', 'press'),\n                          ('warplane', 'military plane'),\n                          ('washbasin', 'handbasin', 'washbowl', 'lavabo', 'wash-hand basin'),\n                          ('washer', 'automatic washer', 'washing machine'), ('water bottle',), ('water jug',),\n                          ('water tower',), ('whiskey jug',), ('whistle',), ('wig',), ('window screen',),\n                          ('window shade',), ('Windsor tie',), ('wine bottle',), ('wing',), ('wok',), ('wooden spoon',),\n                          ('wool', 'woolen', 'woollen'),\n                          ('worm fence', 'snake fence', 'snake-rail fence', 'Virginia fence'), ('wreck',), ('yawl',),\n                          ('yurt',), ('web site', 'website', 'internet site', 'site'), ('comic book',),\n                          ('crossword puzzle', 'crossword'), ('street sign',),\n                          ('traffic light', 'traffic signal', 'stoplight'),\n                          ('book jacket', 'dust cover', 'dust jacket', 'dust wrapper'), ('menu',), ('plate',),\n                          ('guacamole',), ('consomme',), ('hot pot', 'hotpot'), ('trifle',), ('ice cream', 'icecream'),\n                          ('ice lolly', 'lolly', 'lollipop', 'popsicle'), ('French loaf',), ('bagel', 'beigel'),\n                          ('pretzel',), ('cheeseburger',), ('hotdog', 'hot dog', 'red hot'), ('mashed potato',),\n                          ('head cabbage',), ('broccoli',), ('cauliflower',), ('zucchini', 'courgette'),\n                          ('spaghetti squash',), ('acorn squash',), ('butternut squash',), ('cucumber', 'cuke'),\n                          ('artichoke', 'globe artichoke'), ('bell pepper',), ('cardoon',), ('mushroom',),\n                          ('Granny Smith',), ('strawberry',), ('orange',), ('lemon',), ('fig',),\n                          ('pineapple', 'ananas'), ('banana',), ('jackfruit', 'jak', 'jack'), ('custard apple',),\n                          ('pomegranate',), ('hay',), ('carbonara',), ('chocolate sauce', 'chocolate syrup'),\n                          ('dough',), ('meat loaf', 'meatloaf'), ('pizza', 'pizza pie'), ('potpie',), ('burrito',),\n                          ('red wine',), ('espresso',), ('cup',), ('eggnog',), ('alp',), ('bubble',),\n                          ('cliff', 'drop', 'drop-off'), ('coral reef',), ('geyser',), ('lakeside', 'lakeshore'),\n                          ('promontory', 'headland', 'head', 'foreland'), ('sandbar', 'sand bar'),\n                          ('seashore', 'coast', 'seacoast', 'sea-coast'), ('valley', 'vale'), ('volcano',),\n                          ('ballplayer', 'baseball player'), ('groom', 'bridegroom'), ('scuba diver',), ('rapeseed',),\n                          ('daisy',), (\"yellow lady's slipper\", 'yellow lady-slipper', 'Cypripedium calceolus',\n                                       'Cypripedium parviflorum'), ('corn',), ('acorn',),\n                          ('hip', 'rose hip', 'rosehip'), ('buckeye', 'horse chestnut', 'conker'), 
('coral fungus',),\n                          ('agaric',), ('gyromitra',), ('stinkhorn', 'carrion fungus'), ('earthstar',),\n                          ('hen-of-the-woods', 'hen of the woods', 'Polyporus frondosus', 'Grifola frondosa'),\n                          ('bolete',), ('ear', 'spike', 'capitulum'),\n                          ('toilet tissue', 'toilet paper', 'bathroom tissue')]]\n\n# archives shared by both dataset classes below (defined before use);\n# ImageNet-1K itself must be placed under `root` manually\ndownload_list = [\n    (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/a0d7ea37026946f98965/?dl=1\"),\n    (\"256_ObjectCategories\", \"256_ObjectCategories.tar\",\n     \"http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar\"),\n]\n\n\nclass ImageNetCaltech(ImageList):\n    \"\"\"ImageNet-Caltech is constructed from `Caltech-256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ and\n    `ImageNet-1K <http://image-net.org/>`_.\n\n    They share 84 common classes. ImageNet-Caltech keeps all classes of ImageNet-1K.\n    Labels follow ImageNet-1K (classes 0-999). The private classes of Caltech-256 are discarded.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'C'``: Caltech-256 and \\\n            ``'I'``: the ImageNet-1K training set.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that takes in a PIL image and returns a \\\n            transformed version. E.g., :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: You need to put the ``train`` and ``val`` directories of ImageNet-1K manually in the `root` directory\n        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and the image lists\n        automatically. In `root`, the following files will exist after downloading.\n        ::\n            train/\n                n01440764/\n                ...\n            val/\n            256_ObjectCategories/\n                001.ak47/\n                ...\n            image_list/\n                caltech_256_list.txt\n                ...\n    \"\"\"\n    image_list = {\n        \"I\": \"image_list/imagenet_train_1000_list.txt\",\n        \"C\": \"image_list/caltech_84_list.txt\",\n    }\n    CLASSES = _CLASSES\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):\n        assert task in self.image_list\n        data_list_file = os.path.join(root, self.image_list[task])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), download_list))\n        else:\n            # each item of download_list is (directory, archive, url); only the directory is checked\n            list(map(lambda args: check_exits(root, args[0]), download_list))\n\n        if not os.path.exists(os.path.join(root, 'train')):\n            print(\"Please put the train and val directories of ImageNet-1K manually under {} \"\n                  \"since ImageNet-1K is no longer publicly accessible.\".format(root))\n            exit(-1)\n\n        super(ImageNetCaltech, self).__init__(root, ImageNetCaltech.CLASSES, data_list_file=data_list_file, **kwargs)\n\n\nclass ImageNetCaltechUniversal(ImageList):\n    \"\"\"ImageNet-Caltech-Universal is constructed from `Caltech-256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_\n    and `ImageNet-1K <http://image-net.org/>`_.\n\n    They share 84 common classes. ImageNet-Caltech-Universal keeps all classes of ImageNet-1K.\n    Labels follow ImageNet-1K (classes 0-999). The private classes of Caltech-256 are grouped into class 1000\n    (\"unknown\"). Thus, ImageNetCaltechUniversal has 1001 classes in total.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'C'``: Caltech-256 and \\\n            ``'I'``: the ImageNet-1K training set.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that takes in a PIL image and returns a \\\n            transformed version. E.g., :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: You need to put the ``train`` and ``val`` directories of ImageNet-1K manually in the `root` directory\n        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and the image lists\n        automatically. In `root`, the following files will exist after downloading.\n        ::\n            train/\n                n01440764/\n                ...\n            val/\n            256_ObjectCategories/\n                001.ak47/\n                ...\n            image_list/\n                caltech_256_list.txt\n                ...\n    \"\"\"\n    image_list = {\n        \"I\": \"image_list/imagenet_train_1000_list.txt\",\n        \"C\": \"image_list/caltech_85_list.txt\",\n    }\n    CLASSES = _CLASSES + [\"unknown\"]\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):\n        assert task in self.image_list\n        data_list_file = os.path.join(root, self.image_list[task])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), download_list))\n        else:\n            # each item of download_list is (directory, archive, url); only the directory is checked\n            list(map(lambda args: check_exits(root, args[0]), download_list))\n\n        if not os.path.exists(os.path.join(root, 'train')):\n            print(\"Please put the train and val directories of ImageNet-1K manually under {} \"\n                  \"since ImageNet-1K is no longer publicly accessible.\".format(root))\n            exit(-1)\n\n        super(ImageNetCaltechUniversal, self).__init__(root, ImageNetCaltechUniversal.CLASSES, data_list_file=data_list_file, **kwargs)\n"
  },
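The two classes above share one construction pattern; a minimal usage sketch follows (not part of the source). It assumes `ImageNetCaltech` is importable from `tllib.vision.datasets` and that the ImageNet-1K `train/` and `val/` folders were placed under `root` manually, as the docstring requires.

```python
# Hedged sketch: build the ImageNet -> Caltech pair for partial domain adaptation.
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tllib.vision.datasets import ImageNetCaltech  # assumed export path

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
source = ImageNetCaltech("data/imagenet-caltech", task="I", transform=transform)  # ImageNet-1K train
target = ImageNetCaltech("data/imagenet-caltech", task="C", transform=transform)  # Caltech-256, 84 shared classes
source_loader = DataLoader(source, batch_size=32, shuffle=True, num_workers=2)
```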
  {
    "path": "tllib/vision/datasets/patchcamelyon.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass PatchCamelyon(ImageList):\n    \"\"\"\n    The `PatchCamelyon <https://patchcamelyon.grand-challenge.org/>`_ dataset contains \\\n        327680 images of histopathologic scans of lymph node sections. \\\n        The classification task consists in predicting the presence of metastatic tissue \\\n         in given image (i.e., two classes). All images are 96x96 pixels\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n    \"\"\"\n    CLASSES = ['0', '1']\n\n    def __init__(self, root, split, download=False,  **kwargs):\n        if download:\n            download_data(root, \"patch_camelyon\", \"patch_camelyon.tgz\", \"https://cloud.tsinghua.edu.cn/f/21360b3441a54274b843/?dl=1\")\n        else:\n            check_exits(root, \"patch_camelyon\")\n\n        root = os.path.join(root, \"patch_camelyon\")\n        super(PatchCamelyon, self).__init__(root, PatchCamelyon.CLASSES, os.path.join(root, \"imagelist\", \"{}.txt\".format(split)), **kwargs)\n\n"
  },
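A minimal sketch (not part of the source) of loading the split defined above; the module import path is assumed from this file's location.

```python
# Hedged sketch: load the PatchCamelyon train split and draw one batch.
import torchvision.transforms as T
from torch.utils.data import DataLoader
from tllib.vision.datasets.patchcamelyon import PatchCamelyon  # assumed import path

train_set = PatchCamelyon("data", split="train", download=True, transform=T.ToTensor())
loader = DataLoader(train_set, batch_size=64, shuffle=True)
images, labels = next(iter(loader))  # images: (64, 3, 96, 96); labels mark metastatic tissue (0/1)
```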
  {
    "path": "tllib/vision/datasets/regression/__init__.py",
    "content": "from .image_regression import ImageRegression\nfrom .dsprites import DSprites\nfrom .mpi3d import MPI3D"
  },
  {
    "path": "tllib/vision/datasets/regression/dsprites.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, Sequence\nimport os\nfrom .._util import download as download_data, check_exits\nfrom .image_regression import ImageRegression\n\n\nclass DSprites(ImageRegression):\n    \"\"\"`DSprites <https://github.com/deepmind/dsprites-dataset>`_ Dataset.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'C'``: Color, \\\n            ``'N'``: Noisy and ``'S'``: Scream.\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        factors (sequence[str]): Factors selected. Default: ('scale', 'position x', 'position y').\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            color/\n                ...\n            noisy/\n            scream/\n            image_list/\n                color_train.txt\n                noisy_train.txt\n                scream_train.txt\n                color_test.txt\n                noisy_test.txt\n                scream_test.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/4392ef903ed14017a042/?dl=1\"),\n        (\"color\", \"color.tgz\", \"https://cloud.tsinghua.edu.cn/f/6d243c589d384ff5a212/?dl=1\"),\n        (\"noisy\", \"noisy.tgz\", \"https://cloud.tsinghua.edu.cn/f/9a23ede3be1740328637/?dl=1\"),\n        (\"scream\", \"scream.tgz\", \"https://cloud.tsinghua.edu.cn/f/8fc4d34311bb4db6bcde/?dl=1\"),\n    ]\n    image_list = {\n        \"C\": \"color\",\n        \"N\": \"noisy\",\n        \"S\": \"scream\"\n    }\n    FACTORS = ('none', 'shape', 'scale', 'orientation', 'position x', 'position y')\n\n    def __init__(self, root: str, task: str, split: Optional[str] = 'train',\n                 factors: Sequence[str] = ('scale', 'position x', 'position y'),\n                 download: Optional[bool] = True, target_transform=None, **kwargs):\n        assert task in self.image_list\n        assert split in ['train', 'test']\n        for factor in factors:\n            assert factor in self.FACTORS\n\n        factor_index = [self.FACTORS.index(factor) for factor in factors]\n\n        if target_transform is None:\n            target_transform = lambda x: x[list(factor_index)]\n        else:\n            target_transform = lambda x: target_transform(x[list(factor_index)])\n\n        data_list_file = os.path.join(root, \"image_list\", \"{}_{}.txt\".format(self.image_list[task], split))\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(DSprites, self).__init__(root, factors, data_list_file=data_list_file, target_transform=target_transform,\n                                       **kwargs)\n"
  },
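A hedged usage sketch (not from the source); it assumes the import path from this file's location and shows that the regression target returned by `__getitem__` contains one float per selected factor.

```python
# Hedged sketch: DSprites with two selected factors.
from tllib.vision.datasets.regression.dsprites import DSprites  # assumed import path

train_set = DSprites("data/dsprites", task="C", split="train",
                     factors=("position x", "position y"), download=True)
img, target = train_set[0]
assert train_set.num_factors == len(target) == 2  # one regression value per factor
```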
  {
    "path": "tllib/vision/datasets/regression/image_regression.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import Optional, Callable, Tuple, Any, List, Sequence\nimport torchvision.datasets as datasets\nfrom torchvision.datasets.folder import default_loader\nimport numpy as np\n\n\nclass ImageRegression(datasets.VisionDataset):\n    \"\"\"A generic Dataset class for domain adaptation in image regression\n\n    Args:\n        root (str): Root directory of dataset\n        factors (sequence[str]): Factors selected. Default: ('scale', 'position x', 'position y').\n        data_list_file (str): File to read the image list from.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note::\n        In `data_list_file`, each line has `1+len(factors)` values in the following format.\n        ::\n            source_dir/dog_xxx.png x11, x12, ...\n            source_dir/cat_123.png x21, x22, ...\n            target_dir/dog_xxy.png x31, x32, ...\n            target_dir/cat_nsdf3.png x41, x42, ...\n\n        The first value is the relative path of an image, and the rest values are the ground truth of the corresponding factors.\n        If your data_list_file has different formats, please over-ride :meth:`ImageRegression.parse_data_file`.\n    \"\"\"\n    def __init__(self, root: str, factors: Sequence[str], data_list_file: str,\n                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):\n        super().__init__(root, transform=transform, target_transform=target_transform)\n        self.samples = self.parse_data_file(data_list_file)\n        self.factors = factors\n        self.loader = default_loader\n        self.data_list_file = data_list_file\n\n    def __getitem__(self, index: int) -> Tuple[Any, Tuple[float]]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        Returns:\n            (image, target) where target is a numpy float array.\n        \"\"\"\n        path, target = self.samples[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.target_transform is not None and target is not None:\n            target = self.target_transform(target)\n        return img, target\n\n    def __len__(self) -> int:\n        return len(self.samples)\n\n    def parse_data_file(self, file_name: str) -> List[Tuple[str, Any]]:\n        \"\"\"Parse file to data list\n\n        Args:\n            file_name (str): The path of data file\n\n        Returns:\n            List of (image path, (factors)) tuples\n        \"\"\"\n        with open(file_name, \"r\") as f:\n            data_list = []\n            for line in f.readlines():\n                data = line.split()\n                path = str(data[0])\n                target = np.array([float(d) for d in data[1:]], dtype=np.float)\n                if not os.path.isabs(path):\n                    path = os.path.join(self.root, path)\n                data_list.append((path, target))\n        return data_list\n\n    @property\n    def num_factors(self) -> int:\n        return len(self.factors)"
  },
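For clarity, here is what one line of a `data_list_file` parses into under `parse_data_file` above (standalone illustration only; the path and values are made up):

```python
# Illustration of the documented data-list format: an image path followed by
# one whitespace-separated float per factor.
import numpy as np

line = "source_dir/dog_xxx.png 0.5 0.25 0.75"
data = line.split()
path, target = data[0], np.array([float(d) for d in data[1:]], dtype=np.float64)
# -> path == "source_dir/dog_xxx.png", target == array([0.5, 0.25, 0.75])
```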
  {
    "path": "tllib/vision/datasets/regression/mpi3d.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Optional, Sequence\nimport os\nfrom .._util import download as download_data, check_exits\nfrom .image_regression import ImageRegression\n\n\nclass MPI3D(ImageRegression):\n    \"\"\"`MPI3D <https://arxiv.org/abs/1906.03292>`_ Dataset.\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'C'``: Color, \\\n            ``'N'``: Noisy and ``'S'``: Scream.\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        factors (sequence[str]): Factors selected. Default: ('horizontal axis', 'vertical axis').\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            real/\n                ...\n            realistic/\n            toy/\n            image_list/\n                real_train.txt\n                realistic_train.txt\n                toy_train.txt\n                real_test.txt\n                realistic_test.txt\n                toy_test.txt\n        \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/feacec494d5347b7a6aa/?dl=1\"),\n        (\"real\", \"real.tgz\", \"https://cloud.tsinghua.edu.cn/f/605dd842cd9d4071a0ae/?dl=1\"),\n        (\"realistic\", \"realistic.tgz\", \"https://cloud.tsinghua.edu.cn/f/05743f3071054cc29e25/?dl=1\"),\n        (\"toy\", \"toy.tgz\", \"https://cloud.tsinghua.edu.cn/f/1511dff7853d4abea38f/?dl=1\"),\n    ]\n    image_list = {\n        \"RL\": \"real\",\n        \"RC\": \"realistic\",\n        \"T\": \"toy\"\n    }\n    FACTORS = ('horizontal axis', 'vertical axis')\n\n    def __init__(self, root: str, task: str, split: Optional[str] = 'train',\n                 factors: Sequence[str] = ('horizontal axis', 'vertical axis'),\n                 download: Optional[bool] = True, target_transform=None, **kwargs):\n        assert task in self.image_list\n        assert split in ['train', 'test']\n        for factor in factors:\n            assert factor in self.FACTORS\n\n        factor_index = [self.FACTORS.index(factor) for factor in factors]\n\n        if target_transform is None:\n            target_transform = lambda x: x[list(factor_index)] / 39.\n        else:\n            target_transform = lambda x: target_transform(x[list(factor_index)]) / 39.\n\n        data_list_file = os.path.join(root, \"image_list\", \"{}_{}.txt\".format(self.image_list[task], split))\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(MPI3D, self).__init__(root, factors, data_list_file=data_list_file, target_transform=target_transform, **kwargs)\n\n"
  },
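A hedged sketch (not from the source) of the three MPI3D domains; the import path is assumed. The default `target_transform` divides by 39, the largest axis index, so targets lie in [0, 1].

```python
# Hedged sketch: construct two MPI3D domains for regression transfer.
from tllib.vision.datasets.regression.mpi3d import MPI3D  # assumed import path

real = MPI3D("data/mpi3d", task="RL", split="train", download=True)
toy = MPI3D("data/mpi3d", task="T", split="test", download=True)
img, target = real[0]  # target: two floats in [0, 1]
```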
  {
    "path": "tllib/vision/datasets/reid/__init__.py",
    "content": "from .market1501 import Market1501\nfrom .dukemtmc import DukeMTMC\nfrom .msmt17 import MSMT17\nfrom .personx import PersonX\nfrom .unreal import UnrealPerson\n\n__all__ = ['Market1501', 'DukeMTMC', 'MSMT17', 'PersonX', 'UnrealPerson']\n"
  },
  {
    "path": "tllib/vision/datasets/reid/basedataset.py",
    "content": "\"\"\"\nModified from https://github.com/yxgeee/MMT\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport os.path as osp\nimport numpy as np\n\n\nclass BaseDataset(object):\n    \"\"\"\n    Base class of reid dataset\n    \"\"\"\n\n    def get_imagedata_info(self, data):\n        pids, cams = [], []\n        for _, pid, camid in data:\n            pids += [pid]\n            cams += [camid]\n        pids = set(pids)\n        cams = set(cams)\n        num_pids = len(pids)\n        num_cams = len(cams)\n        num_imgs = len(data)\n        return num_pids, num_imgs, num_cams\n\n    def get_videodata_info(self, data, return_tracklet_stats=False):\n        pids, cams, tracklet_stats = [], [], []\n        for img_paths, pid, camid in data:\n            pids += [pid]\n            cams += [camid]\n            tracklet_stats += [len(img_paths)]\n        pids = set(pids)\n        cams = set(cams)\n        num_pids = len(pids)\n        num_cams = len(cams)\n        num_tracklets = len(data)\n        if return_tracklet_stats:\n            return num_pids, num_tracklets, num_cams, tracklet_stats\n        return num_pids, num_tracklets, num_cams\n\n    def print_dataset_statistics(self, train, query, galler):\n        raise NotImplementedError\n\n    def check_before_run(self, required_files):\n        \"\"\"Checks if required files exist before going deeper.\n        Args:\n            required_files (str or list): string file name(s).\n        \"\"\"\n        if isinstance(required_files, str):\n            required_files = [required_files]\n\n        for fpath in required_files:\n            if not osp.exists(fpath):\n                raise RuntimeError('\"{}\" is not found'.format(fpath))\n\n    @property\n    def images_dir(self):\n        return None\n\n\nclass BaseImageDataset(BaseDataset):\n    \"\"\"\n    Base class of image reid dataset\n    \"\"\"\n\n    def print_dataset_statistics(self, train, query, gallery):\n        num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)\n        num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)\n        num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)\n\n        print(\"Dataset statistics:\")\n        print(\"  ----------------------------------------\")\n        print(\"  subset   | # ids | # images | # cameras\")\n        print(\"  ----------------------------------------\")\n        print(\"  train    | {:5d} | {:8d} | {:9d}\".format(num_train_pids, num_train_imgs, num_train_cams))\n        print(\"  query    | {:5d} | {:8d} | {:9d}\".format(num_query_pids, num_query_imgs, num_query_cams))\n        print(\"  gallery  | {:5d} | {:8d} | {:9d}\".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))\n        print(\"  ----------------------------------------\")\n\n\nclass BaseVideoDataset(BaseDataset):\n    \"\"\"\n    Base class of video reid dataset\n    \"\"\"\n\n    def print_dataset_statistics(self, train, query, gallery):\n        num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \\\n            self.get_videodata_info(train, return_tracklet_stats=True)\n\n        num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \\\n            self.get_videodata_info(query, return_tracklet_stats=True)\n\n        num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \\\n            self.get_videodata_info(gallery, return_tracklet_stats=True)\n\n      
  tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats\n        min_num = np.min(tracklet_stats)\n        max_num = np.max(tracklet_stats)\n        avg_num = np.mean(tracklet_stats)\n\n        print(\"Dataset statistics:\")\n        print(\"  -------------------------------------------\")\n        print(\"  subset   | # ids | # tracklets | # cameras\")\n        print(\"  -------------------------------------------\")\n        print(\"  train    | {:5d} | {:11d} | {:9d}\".format(num_train_pids, num_train_tracklets, num_train_cams))\n        print(\"  query    | {:5d} | {:11d} | {:9d}\".format(num_query_pids, num_query_tracklets, num_query_cams))\n        print(\"  gallery  | {:5d} | {:11d} | {:9d}\".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))\n        print(\"  -------------------------------------------\")\n        print(\"  number of images per tracklet: {} ~ {}, average {:.2f}\".format(min_num, max_num, avg_num))\n        print(\"  -------------------------------------------\")\n"
  },
  {
    "path": "tllib/vision/datasets/reid/convert.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport os.path as osp\nfrom torch.utils.data import Dataset\nfrom PIL import Image\n\n\ndef convert_to_pytorch_dataset(dataset, root=None, transform=None, return_idxes=False):\n    class ReidDataset(Dataset):\n        def __init__(self, dataset, root, transform):\n            super(ReidDataset, self).__init__()\n            self.dataset = dataset\n            self.root = root\n            self.transform = transform\n            self.return_idxes = return_idxes\n\n        def __len__(self):\n            return len(self.dataset)\n\n        def __getitem__(self, index):\n            fname, pid, cid = self.dataset[index]\n            fpath = fname\n            if self.root is not None:\n                fpath = osp.join(self.root, fname)\n\n            img = Image.open(fpath).convert('RGB')\n\n            if self.transform is not None:\n                img = self.transform(img)\n\n            if not self.return_idxes:\n                return img, fname, pid, cid\n            else:\n                return img, fname, pid, cid, index\n\n    return ReidDataset(dataset, root, transform)\n"
  },
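Reid datasets expose raw `(path, pid, camid)` lists rather than torch Datasets; a hedged sketch (import paths assumed) of wrapping one with `convert_to_pytorch_dataset` defined above:

```python
# Hedged sketch: wrap the Market1501 train list into an indexable torch Dataset.
import torchvision.transforms as T
from tllib.vision.datasets.reid.market1501 import Market1501  # assumed import path
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset

market = Market1501("data/market1501", verbose=False)
train_set = convert_to_pytorch_dataset(
    market.train, transform=T.Compose([T.Resize((256, 128)), T.ToTensor()]))
img, fname, pid, cid = train_set[0]  # paths in market.train already include the root, so root=None works
```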
  {
    "path": "tllib/vision/datasets/reid/dukemtmc.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom .basedataset import BaseImageDataset\nfrom typing import Callable\nfrom PIL import Image\nimport os\nimport os.path as osp\nimport glob\nimport re\nfrom tllib.vision.datasets._util import download\n\n\nclass DukeMTMC(BaseImageDataset):\n    \"\"\"DukeMTMC-reID dataset from `Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking\n    (ECCV 2016) <https://arxiv.org/pdf/1609.01775v2.pdf>`_.\n\n    Dataset statistics:\n        - identities: 1404 (train + query)\n        - images:16522 (train) + 2228 (query) + 17661 (gallery)\n        - cameras: 8\n\n    Args:\n        root (str): Root directory of dataset\n        verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True\n    \"\"\"\n    dataset_dir = '.'\n    archive_name = 'DukeMTMC-reID.tgz'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/cb80f49905ee4e8eb9f0/?dl=1'\n\n    def __init__(self, root, verbose=True):\n        super(DukeMTMC, self).__init__()\n        download(root, self.dataset_dir, self.archive_name, self.dataset_url)\n        self.relative_dataset_dir = self.dataset_dir\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')\n\n        required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]\n        self.check_before_run(required_files)\n\n        train = self.process_dir(self.train_dir, relabel=True)\n        query = self.process_dir(self.query_dir, relabel=False)\n        gallery = self.process_dir(self.gallery_dir, relabel=False)\n\n        if verbose:\n            print(\"=> DukeMTMC-reID loaded\")\n            self.print_dataset_statistics(train, query, gallery)\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def process_dir(self, dir_path, relabel=False):\n        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n        pattern = re.compile(r'([-\\d]+)_c(\\d)')\n\n        pid_container = set()\n        for img_path in img_paths:\n            pid, _ = map(int, pattern.search(img_path).groups())\n            pid_container.add(pid)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n\n        dataset = []\n        for img_path in img_paths:\n            pid, cid = map(int, pattern.search(img_path).groups())\n            assert 1 <= cid <= 8\n            cid -= 1  # index starts from 0\n            if relabel:\n                pid = pid2label[pid]\n            dataset.append((img_path, pid, cid))\n\n        return dataset\n\n    def translate(self, transform: Callable, target_root: str):\n        \"\"\" Translate an image and save it into a specified directory\n\n        Args:\n            transform (callable): a transform function that maps images from one domain to another domain\n            target_root (str): the root directory to save images\n\n        \"\"\"\n        
os.makedirs(target_root, exist_ok=True)\n        translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)\n\n        translated_train_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/bounding_box_train')\n        translated_query_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/query')\n        translated_gallery_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/bounding_box_test')\n\n        print(\"Translating dataset with image to image transform...\")\n        self.translate_dir(transform, self.train_dir, translated_train_dir)\n        self.translate_dir(None, self.query_dir, translated_query_dir)\n        self.translate_dir(None, self.gallery_dir, translated_gallery_dir)\n        print(\"Translation process is done, save dataset to {}\".format(translated_dataset_dir))\n\n    def translate_dir(self, transform, origin_dir: str, target_dir: str):\n        image_list = os.listdir(origin_dir)\n        for image_name in image_list:\n            if not image_name.endswith(\".jpg\"):\n                continue\n            image_path = osp.join(origin_dir, image_name)\n            image = Image.open(image_path)\n            translated_image_path = osp.join(target_dir, image_name)\n            translated_image = image\n            if transform:\n                translated_image = transform(image)\n\n            os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)\n            translated_image.save(translated_image_path)\n"
  },
  {
    "path": "tllib/vision/datasets/reid/market1501.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom .basedataset import BaseImageDataset\nfrom typing import Callable\nfrom PIL import Image\nimport os\nimport os.path as osp\nimport glob\nimport re\nfrom tllib.vision.datasets._util import download\n\n\nclass Market1501(BaseImageDataset):\n    \"\"\"Market1501 dataset from `Scalable Person Re-identification: A Benchmark (ICCV 2015)\n    <https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=7410490>`_.\n\n    Dataset statistics:\n        - identities: 1501 (+1 for background)\n        - images: 12936 (train) + 3368 (query) + 15913 (gallery)\n        - cameras: 6\n\n    Args:\n        root (str): Root directory of dataset\n        verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True\n    \"\"\"\n    dataset_dir = 'Market-1501-v15.09.15'\n    archive_name = 'Market-1501-v15.09.15.tgz'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/29e5f015a7314531b645/?dl=1'\n\n    def __init__(self, root, verbose=True):\n        super(Market1501, self).__init__()\n        download(root, self.dataset_dir, self.archive_name, self.dataset_url)\n        self.relative_dataset_dir = self.dataset_dir\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')\n\n        required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]\n        self.check_before_run(required_files)\n\n        train = self.process_dir(self.train_dir, relabel=True)\n        query = self.process_dir(self.query_dir, relabel=False)\n        gallery = self.process_dir(self.gallery_dir, relabel=False)\n\n        if verbose:\n            print(\"=> Market1501 loaded\")\n            self.print_dataset_statistics(train, query, gallery)\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def process_dir(self, dir_path, relabel=False):\n        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n        pattern = re.compile(r'([-\\d]+)_c(\\d)')\n\n        pid_container = set()\n        for img_path in img_paths:\n            pid, _ = map(int, pattern.search(img_path).groups())\n            if pid == -1:\n                continue  # junk images are just ignored\n            pid_container.add(pid)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n\n        dataset = []\n        for img_path in img_paths:\n            pid, cid = map(int, pattern.search(img_path).groups())\n            if pid == -1:\n                continue  # junk images are just ignored\n            assert 0 <= pid <= 1501  # pid == 0 means background\n            assert 1 <= cid <= 6\n            cid -= 1  # index starts from 0\n            if relabel:\n                pid = pid2label[pid]\n            dataset.append((img_path, pid, cid))\n\n        return dataset\n\n    def translate(self, transform: Callable, target_root: str):\n        \"\"\" Translate an image and save it into a specified 
directory\n\n        Args:\n            transform (callable): a transform function that maps images from one domain to another domain\n            target_root (str): the root directory to save images\n\n        \"\"\"\n        os.makedirs(target_root, exist_ok=True)\n        translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)\n\n        translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train')\n        translated_query_dir = osp.join(translated_dataset_dir, 'query')\n        translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test')\n\n        print(\"Translating dataset with image to image transform...\")\n        self.translate_dir(transform, self.train_dir, translated_train_dir)\n        self.translate_dir(None, self.query_dir, translated_query_dir)\n        self.translate_dir(None, self.gallery_dir, translated_gallery_dir)\n        print(\"Translation process is done, save dataset to {}\".format(translated_dataset_dir))\n\n    def translate_dir(self, transform, origin_dir: str, target_dir: str):\n        image_list = os.listdir(origin_dir)\n        for image_name in image_list:\n            if not image_name.endswith(\".jpg\"):\n                continue\n            image_path = osp.join(origin_dir, image_name)\n            image = Image.open(image_path)\n            translated_image_path = osp.join(target_dir, image_name)\n            translated_image = image\n            if transform:\n                translated_image = transform(image)\n\n            os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)\n            translated_image.save(translated_image_path)\n"
  },
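A hedged sketch of `translate` above. `horizontal_flip` is a hypothetical stand-in for a trained image-to-image model (e.g. a CycleGAN generator); any callable mapping a PIL image to a PIL image works. Query and gallery images are copied unchanged by `translate()`.

```python
# Hedged sketch: translate the train split and write the dataset to a new root.
from PIL import Image
from tllib.vision.datasets.reid.market1501 import Market1501  # assumed import path

def horizontal_flip(image: Image.Image) -> Image.Image:
    return image.transpose(Image.FLIP_LEFT_RIGHT)  # placeholder "translation"

market = Market1501("data/market1501", verbose=False)
market.translate(horizontal_flip, target_root="data/market1501-translated")
```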
  {
    "path": "tllib/vision/datasets/reid/msmt17.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom .basedataset import BaseImageDataset\nfrom typing import Callable\nfrom PIL import Image\nimport os\nimport os.path as osp\nfrom tllib.vision.datasets._util import download\n\n\nclass MSMT17(BaseImageDataset):\n    \"\"\"MSMT17 dataset from `Person Transfer GAN to Bridge Domain Gap for Person Re-Identification (CVPR 2018)\n    <https://arxiv.org/pdf/1711.08565.pdf>`_.\n\n    Dataset statistics:\n        - identities: 4101\n        - images: 32621 (train) + 11659 (query) + 82161 (gallery)\n        - cameras: 15\n\n    Args:\n        root (str): Root directory of dataset\n        verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True\n    \"\"\"\n    dataset_dir = '.'\n    archive_name = 'MSMT17_V1.zip'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/c254ea490cfa4115940d/?dl=1'\n\n    def __init__(self, root, verbose=True):\n        super(MSMT17, self).__init__()\n        download(root, self.dataset_dir, self.archive_name, self.dataset_url)\n        self.relative_dataset_dir = self.dataset_dir\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')\n\n        required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]\n        self.check_before_run(required_files)\n\n        self.train = self.process_dir(self.train_dir)\n        self.query = self.process_dir(self.query_dir)\n        self.gallery = self.process_dir(self.gallery_dir)\n        if verbose:\n            print(\"=> MSMT17 loaded\")\n            self.print_dataset_statistics(self.train, self.query, self.gallery)\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def process_dir(self, dir_path):\n        image_list = os.listdir(dir_path)\n        dataset = []\n        pid_container = set()\n\n        for image_path in image_list:\n            pid, cid, _ = image_path.split('_')\n            pid = int(pid)\n            cid = int(cid[1:]) - 1  # index starts from 0\n            full_image_path = osp.join(dir_path, image_path)\n            dataset.append((full_image_path, pid, cid))\n            pid_container.add(pid)\n\n        # check if pid starts from 0 and increments with 1\n        for idx, pid in enumerate(pid_container):\n            assert idx == pid, \"See code comment for explanation\"\n        return dataset\n\n    def translate(self, transform: Callable, target_root: str):\n        \"\"\" Translate an image and save it into a specified directory\n\n        Args:\n            transform (callable): a transform function that maps images from one domain to another domain\n            target_root (str): the root directory to save images\n\n        \"\"\"\n        os.makedirs(target_root, exist_ok=True)\n        translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)\n\n        translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train')\n        translated_query_dir = osp.join(translated_dataset_dir, 'query')\n        
translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test')\n\n        print(\"Translating dataset with image to image transform...\")\n        self.translate_dir(transform, self.train_dir, translated_train_dir)\n        self.translate_dir(None, self.query_dir, translated_query_dir)\n        self.translate_dir(None, self.gallery_dir, translated_gallery_dir)\n        print(\"Translation process is done, save dataset to {}\".format(translated_dataset_dir))\n\n    def translate_dir(self, transform, origin_dir: str, target_dir: str):\n        image_list = os.listdir(origin_dir)\n        for image_name in image_list:\n            if not image_name.endswith(\".jpg\"):\n                continue\n            image_path = osp.join(origin_dir, image_name)\n            image = Image.open(image_path)\n            translated_image_path = osp.join(target_dir, image_name)\n            translated_image = image\n            if transform:\n                translated_image = transform(image)\n\n            os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)\n            translated_image.save(translated_image_path)\n"
  },
  {
    "path": "tllib/vision/datasets/reid/personx.py",
    "content": "\"\"\"\nModified from https://github.com/yxgeee/SpCL\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom .basedataset import BaseImageDataset\nfrom typing import Callable\nfrom PIL import Image\nimport os\nimport os.path as osp\nimport glob\nimport re\nfrom tllib.vision.datasets._util import download\n\n\nclass PersonX(BaseImageDataset):\n    \"\"\"PersonX dataset from `Dissecting Person Re-identification from the Viewpoint of Viewpoint (CVPR 2019)\n    <https://arxiv.org/pdf/1812.02162.pdf>`_.\n\n    Dataset statistics:\n        - identities: 1266\n        - images: 9840 (train) + 5136 (query) + 30816 (gallery)\n        - cameras: 6\n\n    Args:\n        root (str): Root directory of dataset\n        verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True\n    \"\"\"\n    dataset_dir = '.'\n    archive_name = 'PersonX.zip'\n    dataset_url = 'https://cloud.tsinghua.edu.cn/f/f506cd11d6b646729bd1/?dl=1'\n\n    def __init__(self, root, verbose=True):\n        super(PersonX, self).__init__()\n        download(root, self.dataset_dir, self.archive_name, self.dataset_url)\n        self.relative_dataset_dir = self.dataset_dir\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')\n\n        required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]\n        self.check_before_run(required_files)\n\n        train = self.process_dir(self.train_dir, relabel=True)\n        query = self.process_dir(self.query_dir, relabel=False)\n        gallery = self.process_dir(self.gallery_dir, relabel=False)\n\n        if verbose:\n            print(\"=> PersonX loaded\")\n            self.print_dataset_statistics(train, query, gallery)\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def process_dir(self, dir_path, relabel=False):\n        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n        pattern = re.compile(r'([-\\d]+)_c([-\\d]+)')\n        cam2label = {3: 1, 4: 2, 8: 3, 10: 4, 11: 5, 12: 6}\n\n        pid_container = set()\n        for img_path in img_paths:\n            pid, _ = map(int, pattern.search(img_path).groups())\n            pid_container.add(pid)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n\n        dataset = []\n        for img_path in img_paths:\n            pid, cid = map(int, pattern.search(img_path).groups())\n            assert (cid in cam2label.keys())\n            cid = cam2label[cid]\n            cid -= 1  # index starts from 0\n            if relabel:\n                pid = pid2label[pid]\n            dataset.append((img_path, pid, cid))\n\n        return dataset\n\n    def translate(self, transform: Callable, target_root: str):\n        \"\"\" Translate an image and save it into a specified directory\n\n        Args:\n            transform (callable): a transform function that maps images from one domain to another domain\n            target_root 
(str): the root directory to save images\n\n        \"\"\"\n        os.makedirs(target_root, exist_ok=True)\n        translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)\n\n        translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train')\n        translated_query_dir = osp.join(translated_dataset_dir, 'query')\n        translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test')\n\n        print(\"Translating dataset with image to image transform...\")\n        self.translate_dir(transform, self.train_dir, translated_train_dir)\n        self.translate_dir(None, self.query_dir, translated_query_dir)\n        self.translate_dir(None, self.gallery_dir, translated_gallery_dir)\n        print(\"Translation process is done, save dataset to {}\".format(translated_dataset_dir))\n\n    def translate_dir(self, transform, origin_dir: str, target_dir: str):\n        image_list = os.listdir(origin_dir)\n        for image_name in image_list:\n            if not image_name.endswith(\".jpg\"):\n                continue\n            image_path = osp.join(origin_dir, image_name)\n            image = Image.open(image_path)\n            translated_image_path = osp.join(target_dir, image_name)\n            translated_image = image\n            if transform:\n                translated_image = transform(image)\n\n            os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)\n            translated_image.save(translated_image_path)\n"
  },
  {
    "path": "tllib/vision/datasets/reid/unreal.py",
    "content": "\"\"\"\nModified from https://github.com/SikaStar/IDM\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom .basedataset import BaseImageDataset\nfrom typing import Callable\nimport os.path as osp\nfrom tllib.vision.datasets._util import download\n\n\nclass UnrealPerson(BaseImageDataset):\n    \"\"\"UnrealPerson dataset from `UnrealPerson: An Adaptive Pipeline towards Costless Person Re-identification\n    (CVPR 2021) <https://arxiv.org/pdf/2012.04268v2.pdf>`_.\n\n    Dataset statistics:\n        - identities: 3000\n        - images: 120,000\n        - cameras: 34\n\n    Args:\n        root (str): Root directory of dataset\n        verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True\n    \"\"\"\n    dataset_dir = '.'\n    download_list = [\n        (\"list_unreal_train.txt\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/a51b22fd760743e7bca6/?dl=1\"),\n        (\"unreal_v1.1\", \"unreal_v1.1.tar\", \"https://cloud.tsinghua.edu.cn/f/a8806bb3bf1744dda5b1/?dl=1\"),\n        (\"unreal_v1.2\", \"unreal_v1.2.tar\", \"https://cloud.tsinghua.edu.cn/f/449224485a654c5baa8f/?dl=1\"),\n        (\"unreal_v1.3\", \"unreal_v1.3.tar\", \"https://cloud.tsinghua.edu.cn/f/069f3162f74849c09c10/?dl=1\"),\n        (\"unreal_v2.1\", \"unreal_v2.1.tar\", \"https://cloud.tsinghua.edu.cn/f/a791aaa42674466eb183/?dl=1\"),\n        (\"unreal_v2.2\", \"unreal_v2.2.tar\", \"https://cloud.tsinghua.edu.cn/f/b601d9f54f964248bd0e/?dl=1\"),\n        (\"unreal_v2.3\", \"unreal_v2.3.tar\", \"https://cloud.tsinghua.edu.cn/f/311ec60e810b42d48d12/?dl=1\"),\n        (\"unreal_v3.1\", \"unreal_v3.1.tar\", \"https://cloud.tsinghua.edu.cn/f/d51b7c1d125e4632bcf9/?dl=1\"),\n        (\"unreal_v3.2\", \"unreal_v3.2.tar\", \"https://cloud.tsinghua.edu.cn/f/4efbd969ea2f4e8197e8/?dl=1\"),\n        (\"unreal_v3.3\", \"unreal_v3.3.tar\", \"https://cloud.tsinghua.edu.cn/f/a3cc3d9c460247848fb7/?dl=1\"),\n        (\"unreal_v4.1\", \"unreal_v4.1.tar\", \"https://cloud.tsinghua.edu.cn/f/ca05183ac9cd4be5a53b/?dl=1\"),\n        (\"unreal_v4.2\", \"unreal_v4.2.tar\", \"https://cloud.tsinghua.edu.cn/f/b90722cbd754496f9f40/?dl=1\"),\n        (\"unreal_v4.3\", \"unreal_v4.3.tar\", \"https://cloud.tsinghua.edu.cn/f/547ae646c3d346038297/?dl=1\"),\n    ]\n\n    def __init__(self, root, verbose=True):\n        super(UnrealPerson, self).__init__()\n        list(map(lambda args: download(root, *args), self.download_list))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_list = osp.join(self.dataset_dir, 'list_unreal_train.txt')\n\n        required_files = [self.dataset_dir, self.train_list]\n        self.check_before_run(required_files)\n\n        train = self.process_dir(self.train_list)\n        self.train = train\n        self.query = []\n        self.gallery = []\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n\n        if verbose:\n            print(\"=> UnrealPerson loaded\")\n            print(\"  ----------------------------------------\")\n            print(\"  subset   | # ids | # cams | # images\")\n            print(\"  ----------------------------------------\")\n            print(\"  train    | {:5d} | {:5d} | {:8d}\"\n                  .format(self.num_train_pids, self.num_train_cams, self.num_train_imgs))\n            print(\"  ----------------------------------------\")\n\n    def process_dir(self, list_file):\n        with open(list_file, 'r') as f:\n           
 lines = f.readlines()\n        dataset = []\n        pid_container = set()\n        for line in lines:\n            line = line.strip()\n            pid = line.split(' ')[1]\n            pid_container.add(pid)\n\n        pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}\n\n        for line in lines:\n            line = line.strip()\n            fname, pid, cid = line.split(' ')[0], line.split(' ')[1], int(line.split(' ')[2])\n            img_path = osp.join(self.dataset_dir, fname)\n            dataset.append((img_path, pid2label[pid], cid))\n\n        return dataset\n\n    def translate(self, transform: Callable, target_root: str):\n        raise NotImplementedError\n"
  },
  {
    "path": "tllib/vision/datasets/resisc45.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\n\nfrom torchvision.datasets.folder import ImageFolder\nimport random\n\n\nclass Resisc45(ImageFolder):\n    \"\"\"`Resisc45 <http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html>`_ dataset \\\n        is a scene classification task from remote sensing images. There are 45 classes, \\\n        containing 700 images each, including tennis court, ship, island, lake, \\\n        parking lot, sparse residential, or stadium. \\\n        The image size is RGB 256x256 pixels.\n\n    .. note:: You need to download the source data manually into `root` directory.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    \"\"\"\n    def __init__(self, root, split='train', download=False, **kwargs):\n        super(Resisc45, self).__init__(root, **kwargs)\n        random.seed(0)\n        random.shuffle(self.samples)\n        if split == 'train':\n            self.samples = self.samples[:25200]\n        else:\n            self.samples = self.samples[25200:]\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"Number of classes\"\"\"\n        return len(self.classes)\n"
  },
  {
    "path": "tllib/vision/datasets/retinopathy.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .imagelist import ImageList\n\n\nclass Retinopathy(ImageList):\n    \"\"\"`Retinopathy <https://www.kaggle.com/c/diabetic-retinopathy-detection/data>`_ dataset \\\n        consists of image-label pairs with high-resolution retina images, and labels that indicate \\\n        the presence of Diabetic Retinopahy (DR) in a 0-4 scale (No DR, Mild, Moderate, Severe, \\\n        or Proliferative DR).\n\n    .. note:: You need to download the source data manually into `root` directory.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    \"\"\"\n    CLASSES = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']\n\n    def __init__(self, root, split, download=False, **kwargs):\n\n        super(Retinopathy, self).__init__(os.path.join(root, split), Retinopathy.CLASSES, os.path.join(root, \"image_list\", \"{}.txt\".format(split)), **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/segmentation/__init__.py",
    "content": "from .segmentation_list import SegmentationList\nfrom .cityscapes import Cityscapes, FoggyCityscapes\nfrom .gta5 import GTA5\nfrom .synthia import Synthia\n\n__all__ = [\"SegmentationList\", \"Cityscapes\", \"GTA5\", \"Synthia\", \"FoggyCityscapes\"]\n"
  },
  {
    "path": "tllib/vision/datasets/segmentation/cityscapes.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .segmentation_list import SegmentationList\nfrom .._util import download as download_data\n\n\nclass Cityscapes(SegmentationList):\n    \"\"\"`Cityscapes <https://www.cityscapes-dataset.com/>`_ is a real-world semantic segmentation dataset collected\n    in driving scenarios.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``val``.\n        data_folder (str, optional): Sub-directory of the image. Default: 'leftImg8bit'.\n        label_folder (str, optional): Sub-directory of the label. Default: 'gtFine'.\n        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.\n        transforms (callable, optional): A function/transform that  takes in  (PIL image, label) pair \\\n            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.\n\n    .. note:: You need to download Cityscapes manually.\n        Ensure that there exist following files in the `root` directory before you using this class.\n        ::\n            leftImg8bit/\n                train/\n                val/\n                test/\n            gtFine/\n                train/\n                val/\n                test/\n    \"\"\"\n\n    CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign',\n               'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',\n               'bicycle']\n\n    ID_TO_TRAIN_ID = {\n        7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,\n        19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,\n        26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18\n    }\n    TRAIN_ID_TO_COLOR = [(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),\n                                  (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),\n                                  (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),\n                                  (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),\n                                  (0, 0, 230), (119, 11, 32), [0, 0, 0]]\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/a23536bb8a724e91af39/?dl=1\"),\n    ]\n    EVALUATE_CLASSES = CLASSES\n\n    def __init__(self, root, split='train', data_folder='leftImg8bit', label_folder='gtFine', **kwargs):\n        assert split in ['train', 'val']\n\n        # download meta information from Internet\n        list(map(lambda args: download_data(root, *args), self.download_list))\n        data_list_file = os.path.join(root, \"image_list\", \"{}.txt\".format(split))\n        self.split = split\n        super(Cityscapes, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file,\n                                         os.path.join(data_folder, split), os.path.join(label_folder, split),\n                                         id_to_train_id=Cityscapes.ID_TO_TRAIN_ID,\n                                         train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs)\n\n    def parse_label_file(self, label_list_file):\n        with open(label_list_file, \"r\") as f:\n            label_list = [line.strip().replace(\"leftImg8bit\", \"gtFine_labelIds\") for line in f.readlines()]\n        return label_list\n\n\nclass FoggyCityscapes(Cityscapes):\n    
\"\"\"`Foggy Cityscapes <https://www.cityscapes-dataset.com/>`_ is a real-world semantic segmentation dataset collected\n    in foggy driving scenarios.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``val``.\n        data_folder (str, optional): Sub-directory of the image. Default: 'leftImg8bit'.\n        label_folder (str, optional): Sub-directory of the label. Default: 'gtFine'.\n        beta (float, optional): The parameter for foggy. Choices includes: 0.005, 0.01, 0.02. Default: 0.02\n        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.\n        transforms (callable, optional): A function/transform that  takes in  (PIL image, label) pair \\\n            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.\n\n    .. note:: You need to download Cityscapes manually.\n        Ensure that there exist following files in the `root` directory before you using this class.\n        ::\n            leftImg8bit_foggy/\n                train/\n                val/\n                test/\n            gtFine/\n                train/\n                val/\n                test/\n    \"\"\"\n    def __init__(self, root, split='train', data_folder='leftImg8bit_foggy', label_folder='gtFine', beta=0.02, **kwargs):\n        assert beta in [0.02, 0.01, 0.005]\n        self.beta = beta\n        super(FoggyCityscapes, self).__init__(root, split, data_folder, label_folder, **kwargs)\n\n    def parse_data_file(self, file_name):\n        \"\"\"Parse file to image list\n\n        Args:\n            file_name (str): The path of data file\n\n        Returns:\n            List of image path\n        \"\"\"\n        with open(file_name, \"r\") as f:\n            data_list = [line.strip().replace(\"leftImg8bit\", \"leftImg8bit_foggy_beta_{}\".format(self.beta)) for line in f.readlines()]\n        return data_list\n"
  },
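A hedged sketch (not from the source) pairing the clear and foggy splits defined above; in practice a joint `(image, label)` transforms callable such as `tllib.vision.transforms.segmentation.Resize` (referenced in the docstrings) would also be passed.

```python
# Hedged sketch: clear vs. foggy validation splits over the same image list.
from tllib.vision.datasets.segmentation import Cityscapes, FoggyCityscapes

clear_val = Cityscapes("data/cityscapes", split="val")
foggy_val = FoggyCityscapes("data/cityscapes", split="val", beta=0.02)  # densest fog
print(clear_val.num_classes, len(foggy_val))  # 19 classes; same number of images
```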
  {
    "path": "tllib/vision/datasets/segmentation/gta5.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .segmentation_list import SegmentationList\nfrom .cityscapes import Cityscapes\nfrom .._util import download as download_data\n\n\nclass GTA5(SegmentationList):\n    \"\"\"`GTA5 <https://download.visinf.tu-darmstadt.de/data/from_games/>`_\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``.\n        data_folder (str, optional): Sub-directory of the image. Default: 'images'.\n        label_folder (str, optional): Sub-directory of the label. Default: 'labels'.\n        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.\n        transforms (callable, optional): A function/transform that  takes in  (PIL image, label) pair \\\n            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.\n\n    .. note:: You need to download GTA5 manually.\n        Ensure that there exist following directories in the `root` directory before you using this class.\n        ::\n            images/\n            labels/\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/f719733e339544e9a330/?dl=1\"),\n    ]\n\n    def __init__(self, root, split='train', data_folder='images', label_folder='labels', **kwargs):\n        assert split in ['train']\n        # download meta information from Internet\n        list(map(lambda args: download_data(root, *args), self.download_list))\n        data_list_file = os.path.join(root, \"image_list\", \"{}.txt\".format(split))\n        self.split = split\n        super(GTA5, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file, data_folder, label_folder,\n                                   id_to_train_id=Cityscapes.ID_TO_TRAIN_ID,\n                                   train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs)\n"
  },
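GTA5 reuses the Cityscapes class list, id mapping and palette, which makes it the standard synthetic source for adaptation to real Cityscapes scenes; a hedged sketch (import paths assumed):

```python
# Hedged sketch: the synthetic-to-real segmentation pair.
from tllib.vision.datasets.segmentation import GTA5, Cityscapes

source = GTA5("data/gta5", split="train")
target = Cityscapes("data/cityscapes", split="train")
assert source.num_classes == target.num_classes == 19  # shared label space
```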
  {
    "path": "tllib/vision/datasets/segmentation/segmentation_list.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import Sequence, Optional, Dict, Callable\nfrom PIL import Image\nimport tqdm\nimport numpy as np\nfrom torch.utils import data\nimport torch\n\n\nclass SegmentationList(data.Dataset):\n    \"\"\"A generic Dataset class for domain adaptation in image segmentation\n\n    Args:\n        root (str): Root directory of dataset\n        classes (seq[str]): The names of all the classes\n        data_list_file (str): File to read the image list from.\n        label_list_file (str): File to read the label list from.\n        data_folder (str): Sub-directory of the image.\n        label_folder (str): Sub-directory of the label.\n        mean (seq[float]): mean BGR value. Normalize and convert to the image if not None. Default: None.\n        id_to_train_id (dict, optional): the map between the id on the label and the actual train id.\n        train_id_to_color (seq, optional): the map between the train id and the color.\n        transforms (callable, optional): A function/transform that  takes in  (PIL Image, label) pair \\\n            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.\n\n    .. note:: In ``data_list_file``, each line is the relative path of an image.\n        If your data_list_file has different formats, please over-ride :meth:`~SegmentationList.parse_data_file`.\n        ::\n            source_dir/dog_xxx.png\n            target_dir/dog_xxy.png\n\n        In ``label_list_file``, each line is the relative path of an label.\n        If your label_list_file has different formats, please over-ride :meth:`~SegmentationList.parse_label_file`.\n\n    .. warning:: When mean is not None, please do not provide Normalize and ToTensor in transforms.\n\n    \"\"\"\n    def __init__(self, root: str, classes: Sequence[str], data_list_file: str, label_list_file: str,\n                 data_folder: str, label_folder: str,\n                 id_to_train_id: Optional[Dict] = None, train_id_to_color: Optional[Sequence] = None,\n                 transforms: Optional[Callable] = None):\n        self.root = root\n        self.classes = classes\n        self.data_list_file = data_list_file\n        self.label_list_file = label_list_file\n        self.data_folder = data_folder\n        self.label_folder = label_folder\n        self.ignore_label = 255\n        self.id_to_train_id = id_to_train_id\n        self.train_id_to_color = np.array(train_id_to_color)\n        self.data_list = self.parse_data_file(self.data_list_file)\n        self.label_list = self.parse_label_file(self.label_list_file)\n        self.transforms = transforms\n\n    def parse_data_file(self, file_name):\n        \"\"\"Parse file to image list\n\n        Args:\n            file_name (str): The path of data file\n\n        Returns:\n            List of image path\n        \"\"\"\n        with open(file_name, \"r\") as f:\n            data_list = [line.strip() for line in f.readlines()]\n        return data_list\n\n    def parse_label_file(self, file_name):\n        \"\"\"Parse file to label list\n\n        Args:\n            file_name (str): The path of data file\n\n        Returns:\n            List of label path\n        \"\"\"\n        with open(file_name, \"r\") as f:\n            label_list = [line.strip() for line in f.readlines()]\n        return label_list\n\n    def __len__(self):\n        return len(self.data_list)\n\n    def __getitem__(self, index):\n        
image_name = self.data_list[index]\n        label_name = self.label_list[index]\n        image = Image.open(os.path.join(self.root, self.data_folder, image_name)).convert('RGB')\n        label = Image.open(os.path.join(self.root, self.label_folder, label_name))\n        if self.transforms is not None:\n            image, label = self.transforms(image, label)\n\n        # remap label\n        if isinstance(label, torch.Tensor):\n            label = label.numpy()\n        label = np.asarray(label, np.int64)\n        label_copy = self.ignore_label * np.ones(label.shape, dtype=np.int64)\n        if self.id_to_train_id:\n            for k, v in self.id_to_train_id.items():\n                label_copy[label == k] = v\n\n        return image, label_copy.copy()\n\n    @property\n    def num_classes(self) -> int:\n        \"\"\"Number of classes\"\"\"\n        return len(self.classes)\n\n    def decode_target(self, target):\n        \"\"\" Decode label (each value is an integer) into the corresponding RGB value.\n\n        Args:\n            target (numpy.array): label in shape H x W\n\n        Returns:\n            RGB label (PIL Image) in shape H x W x 3\n        \"\"\"\n        target = target.copy()\n        target[target == 255] = self.num_classes # unknown label is black on the RGB label\n        target = self.train_id_to_color[target]\n        return Image.fromarray(target.astype(np.uint8))\n\n    def collect_image_paths(self):\n        \"\"\"Return a list of the absolute paths of all the images\"\"\"\n        return [os.path.join(self.root, self.data_folder, image_name) for image_name in self.data_list]\n\n    @staticmethod\n    def _save_pil_image(image, path):\n        os.makedirs(os.path.dirname(path), exist_ok=True)\n        image.save(path)\n\n    def translate(self, transform: Callable, target_root: str, color=False):\n        \"\"\" Translate all the images in the dataset and save them into the specified directory\n\n        Args:\n            transform (callable): a transform function that maps an (image, label) pair from one domain to another domain\n            target_root (str): the root directory to save images and labels\n\n        \"\"\"\n        os.makedirs(target_root, exist_ok=True)\n        for image_name, label_name in zip(tqdm.tqdm(self.data_list), self.label_list):\n            image_path = os.path.join(target_root, self.data_folder, image_name)\n            label_path = os.path.join(target_root, self.label_folder, label_name)\n            if os.path.exists(image_path) and os.path.exists(label_path):\n                continue\n            image = Image.open(os.path.join(self.root, self.data_folder, image_name)).convert('RGB')\n            label = Image.open(os.path.join(self.root, self.label_folder, label_name))\n\n            translated_image, translated_label = transform(image, label)\n            self._save_pil_image(translated_image, image_path)\n            self._save_pil_image(translated_label, label_path)\n            if color:\n                colored_label = self.decode_target(np.array(translated_label))\n                file_name, file_ext = os.path.splitext(label_name)\n                self._save_pil_image(colored_label, os.path.join(target_root, self.label_folder,\n                                                                 \"{}_color{}\".format(file_name, file_ext)))\n\n    @property\n    def evaluate_classes(self):\n        \"\"\"The names of the classes to be evaluated\"\"\"\n        return self.classes\n\n    @property\n    def ignore_classes(self):\n        \"\"\"The names of the classes to be ignored\"\"\"\n        return 
list(set(self.classes) - set(self.evaluate_classes))"
  },
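  {
    "path": "examples/sketches/segmentation_list_usage.py",
    "content": "# Hypothetical usage sketch for SegmentationList; not part of tllib. It\n# fabricates a one-image dataset on disk so the class can be exercised without\n# downloading anything; all file names, class names and colors below are made up.\nimport os\nimport tempfile\nimport numpy as np\nfrom PIL import Image\nfrom tllib.vision.datasets.segmentation.segmentation_list import SegmentationList\n\nroot = tempfile.mkdtemp()\nos.makedirs(os.path.join(root, 'images'))\nos.makedirs(os.path.join(root, 'labels'))\nImage.fromarray(np.zeros((64, 64, 3), dtype=np.uint8)).save(os.path.join(root, 'images', '0.png'))\nImage.fromarray(np.random.randint(0, 2, size=(64, 64), dtype=np.uint8)).save(os.path.join(root, 'labels', '0.png'))\nlist_file = os.path.join(root, 'list.txt')\nwith open(list_file, 'w') as f:\n    f.write('0.png\\n')\n\ndataset = SegmentationList(\n    root, classes=['road'], data_list_file=list_file, label_list_file=list_file,\n    data_folder='images', label_folder='labels',\n    id_to_train_id={1: 0},                          # raw id 1 -> train id 0; raw id 0 -> 255 (ignored)\n    train_id_to_color=[(128, 64, 128), (0, 0, 0)],  # one extra color entry for ignored pixels\n    transforms=lambda image, label: (image, label)) # identity transform for the sketch\nimage, label = dataset[0]\nprint(len(dataset), label.shape, label.dtype)  # 1 (64, 64) int64\n"
  },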
  {
    "path": "tllib/vision/datasets/segmentation/synthia.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom .segmentation_list import SegmentationList\nfrom .cityscapes import Cityscapes\nfrom .._util import download as download_data\n\n\nclass Synthia(SegmentationList):\n    \"\"\"`SYNTHIA <https://synthia-dataset.net/>`_\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``.\n        data_folder (str, optional): Sub-directory of the image. Default: 'RGB'.\n        label_folder (str, optional): Sub-directory of the label. Default: 'synthia_mapped_to_cityscapes'.\n        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.\n        transforms (callable, optional): A function/transform that  takes in  (PIL image, label) pair \\\n            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.\n\n    .. note:: You need to download GTA5 manually.\n        Ensure that there exist following directories in the `root` directory before you using this class.\n        ::\n            RGB/\n            synthia_mapped_to_cityscapes/\n    \"\"\"\n    ID_TO_TRAIN_ID = {\n        3: 0, 4: 1, 2: 2, 21: 3, 5: 4, 7: 5,\n        15: 6, 9: 7, 6: 8, 16: 9, 1: 10, 10: 11, 17: 12,\n        8: 13, 18: 14, 19: 15, 20: 16, 12: 17, 11: 18\n    }\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/1c652d518e0347e2800d/?dl=1\"),\n    ]\n\n    def __init__(self, root, split='train', data_folder='RGB', label_folder='synthia_mapped_to_cityscapes', **kwargs):\n        assert split in ['train']\n        # download meta information from Internet\n        list(map(lambda args: download_data(root, *args), self.download_list))\n        data_list_file = os.path.join(root, \"image_list\", \"{}.txt\".format(split))\n        super(Synthia, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file, data_folder,\n                                      label_folder, id_to_train_id=Synthia.ID_TO_TRAIN_ID,\n                                      train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs)\n\n    @property\n    def evaluate_classes(self):\n        return [\n            'road', 'sidewalk', 'building', 'traffic light', 'traffic sign',\n            'vegetation', 'sky', 'person', 'rider', 'car', 'bus', 'motorcycle', 'bicycle'\n        ]\n"
  },
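  {
    "path": "examples/sketches/synthia_id_remap_sketch.py",
    "content": "# Hypothetical sketch; not part of tllib. It applies Synthia.ID_TO_TRAIN_ID to\n# a fake raw label map exactly the way SegmentationList.__getitem__ does, so no\n# SYNTHIA download is needed to see how the remapping works.\nimport numpy as np\nfrom tllib.vision.datasets.segmentation.synthia import Synthia\n\nraw = np.array([[3, 4], [21, 0]], dtype=np.int64)  # raw SYNTHIA label ids\nremapped = 255 * np.ones_like(raw)                 # 255 is the ignore label\nfor k, v in Synthia.ID_TO_TRAIN_ID.items():\n    remapped[raw == k] = v\nprint(remapped)  # [[0 1] [3 255]]: id 0 has no train id, so it stays ignored\n"
  },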
  {
    "path": "tllib/vision/datasets/stanford_cars.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass StanfordCars(ImageList):\n    \"\"\"`The Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ \\\n    contains 16,185 images of 196 classes of cars. \\\n    Each category has been split roughly in a 50-50 split. \\\n    There are 8,144 images for training and 8,041 images for testing.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        sample_rate (int): The sampling rates to sample random ``training`` images for each category.\n            Choices include 100, 50, 30, 15. Default: 100.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            train/\n            test/\n            image_list/\n                train_100.txt\n                train_50.txt\n                train_30.txt\n                train_15.txt\n                test.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/aeeb690e9886442aa267/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/fd80c30c120a42a08fd3/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/01e6b279f20440cb8bf9/?dl=1\"),\n    ]\n    image_list = {\n        \"train\": \"image_list/train_100.txt\",\n        \"train100\": \"image_list/train_100.txt\",\n        \"train50\": \"image_list/train_50.txt\",\n        \"train30\": \"image_list/train_30.txt\",\n        \"train15\": \"image_list/train_15.txt\",\n        \"test\": \"image_list/test.txt\",\n        \"test100\": \"image_list/test.txt\",\n    }\n    CLASSES = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',\n               '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36',\n               '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53',\n               '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70',\n               '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87',\n               '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103',\n               '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118',\n               '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130', '131', '132', '133',\n               '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147', '148',\n               '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '160', '161', '162', 
'163',\n               '164', '165', '166', '167', '168', '169', '170', '171', '172', '173', '174', '175', '176', '177', '178',\n               '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '190', '191', '192', '193',\n               '194', '195', '196']\n\n    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,\n                 **kwargs):\n\n        if split == 'train':\n            list_name = 'train' + str(sample_rate)\n            assert list_name in self.image_list\n            data_list_file = os.path.join(root, self.image_list[list_name])\n        else:\n            data_list_file = os.path.join(root, self.image_list['test'])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            # each entry of download_list is a (file_name, archive_name, url) tuple\n            list(map(lambda args: check_exits(root, args[0]), self.download_list))\n\n        super(StanfordCars, self).__init__(root, StanfordCars.CLASSES, data_list_file=data_list_file, **kwargs)\n"
  },
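  {
    "path": "examples/sketches/stanford_cars_list_selection.py",
    "content": "# Hypothetical sketch; not part of tllib. It mirrors the (split, sample_rate)\n# logic in StanfordCars.__init__ to show which image list each setting selects;\n# the root path is an assumption and no data is read.\nimport os\nfrom tllib.vision.datasets.stanford_cars import StanfordCars\n\nroot = '/data/stanford_cars'  # assumed dataset root\nfor split, sample_rate in [('train', 100), ('train', 15), ('test', 100)]:\n    list_name = 'train' + str(sample_rate) if split == 'train' else 'test'\n    print(split, sample_rate, '->', os.path.join(root, StanfordCars.image_list[list_name]))\n# train 100 -> image_list/train_100.txt\n# train 15  -> image_list/train_15.txt\n# test 100  -> image_list/test.txt\n"
  },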
  {
    "path": "tllib/vision/datasets/stanford_dogs.py",
    "content": "\"\"\"\n@author: Yifei Ji\n@contact: jiyf990330@163.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass StanfordDogs(ImageList):\n    \"\"\"`The Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ \\\n    contains 20,580 images of 120 breeds of dogs from around the world. \\\n    Each category is composed of exactly 100 training examples and around 72 testing examples.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        sample_rate (int): The sampling rates to sample random ``training`` images for each category.\n            Choices include 100, 50, 30, 15. Default: 100.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            train/\n            test/\n            image_list/\n                train_100.txt\n                train_50.txt\n                train_30.txt\n                train_15.txt\n                test.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/7685b13c549a4591b429/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/9f19a6d1b14b4f1e8d13/?dl=1\"),\n        (\"test\", \"test.tgz\", \"https://cloud.tsinghua.edu.cn/f/a497b21e31cc4bfc9d45/?dl=1\"),\n    ]\n    image_list = {\n        \"train\": \"image_list/train_100.txt\",\n        \"train100\": \"image_list/train_100.txt\",\n        \"train50\": \"image_list/train_50.txt\",\n        \"train30\": \"image_list/train_30.txt\",\n        \"train15\": \"image_list/train_15.txt\",\n        \"test\": \"image_list/test.txt\",\n        \"test100\": \"image_list/test.txt\",\n    }\n    CLASSES = ['n02085620-Chihuahua', 'n02085782-Japanese_spaniel', 'n02085936-Maltese_dog', 'n02086079-Pekinese',\n               'n02086240-Shih-Tzu',\n               'n02086646-Blenheim_spaniel', 'n02086910-papillon', 'n02087046-toy_terrier',\n               'n02087394-Rhodesian_ridgeback',\n               'n02088094-Afghan_hound', 'n02088238-basset', 'n02088364-beagle', 'n02088466-bloodhound',\n               'n02088632-bluetick', 'n02089078-black-and-tan_coonhound',\n               'n02089867-Walker_hound', 'n02089973-English_foxhound', 'n02090379-redbone', 'n02090622-borzoi',\n               'n02090721-Irish_wolfhound', 'n02091032-Italian_greyhound',\n               'n02091134-whippet', 'n02091244-Ibizan_hound', 'n02091467-Norwegian_elkhound', 'n02091635-otterhound',\n               'n02091831-Saluki', 'n02092002-Scottish_deerhound',\n               'n02092339-Weimaraner', 'n02093256-Staffordshire_bullterrier',\n               'n02093428-American_Staffordshire_terrier', 'n02093647-Bedlington_terrier', 'n02093754-Border_terrier',\n               'n02093859-Kerry_blue_terrier', 'n02093991-Irish_terrier', 'n02094114-Norfolk_terrier',\n               'n02094258-Norwich_terrier', 
'n02094433-Yorkshire_terrier',\n               'n02095314-wire-haired_fox_terrier', 'n02095570-Lakeland_terrier', 'n02095889-Sealyham_terrier',\n               'n02096051-Airedale', 'n02096177-cairn', 'n02096294-Australian_terrier',\n               'n02096437-Dandie_Dinmont', 'n02096585-Boston_bull', 'n02097047-miniature_schnauzer',\n               'n02097130-giant_schnauzer', 'n02097209-standard_schnauzer',\n               'n02097298-Scotch_terrier', 'n02097474-Tibetan_terrier', 'n02097658-silky_terrier',\n               'n02098105-soft-coated_wheaten_terrier', 'n02098286-West_Highland_white_terrier',\n               'n02098413-Lhasa', 'n02099267-flat-coated_retriever', 'n02099429-curly-coated_retriever',\n               'n02099601-golden_retriever', 'n02099712-Labrador_retriever',\n               'n02099849-Chesapeake_Bay_retriever', 'n02100236-German_short-haired_pointer', 'n02100583-vizsla',\n               'n02100735-English_setter', 'n02100877-Irish_setter',\n               'n02101006-Gordon_setter', 'n02101388-Brittany_spaniel', 'n02101556-clumber',\n               'n02102040-English_springer', 'n02102177-Welsh_springer_spaniel', 'n02102318-cocker_spaniel',\n               'n02102480-Sussex_spaniel', 'n02102973-Irish_water_spaniel', 'n02104029-kuvasz', 'n02104365-schipperke',\n               'n02105056-groenendael', 'n02105162-malinois', 'n02105251-briard', 'n02105412-kelpie',\n               'n02105505-komondor', 'n02105641-Old_English_sheepdog', 'n02105855-Shetland_sheepdog',\n               'n02106030-collie', 'n02106166-Border_collie', 'n02106382-Bouvier_des_Flandres', 'n02106550-Rottweiler',\n               'n02106662-German_shepherd', 'n02107142-Doberman', 'n02107312-miniature_pinscher',\n               'n02107574-Greater_Swiss_Mountain_dog',\n               'n02107683-Bernese_mountain_dog', 'n02107908-Appenzeller', 'n02108000-EntleBucher', 'n02108089-boxer',\n               'n02108422-bull_mastiff', 'n02108551-Tibetan_mastiff',\n               'n02108915-French_bulldog', 'n02109047-Great_Dane', 'n02109525-Saint_Bernard', 'n02109961-Eskimo_dog',\n               'n02110063-malamute', 'n02110185-Siberian_husky',\n               'n02110627-affenpinscher', 'n02110806-basenji', 'n02110958-pug', 'n02111129-Leonberg',\n               'n02111277-Newfoundland', 'n02111500-Great_Pyrenees', 'n02111889-Samoyed', 'n02112018-Pomeranian',\n               'n02112137-chow', 'n02112350-keeshond', 'n02112706-Brabancon_griffon', 'n02113023-Pembroke',\n               'n02113186-Cardigan',\n               'n02113624-toy_poodle', 'n02113712-miniature_poodle', 'n02113799-standard_poodle',\n               'n02113978-Mexican_hairless', 'n02115641-dingo', 'n02115913-dhole', 'n02116738-African_hunting_dog']\n\n    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,\n                 **kwargs):\n\n        if split == 'train':\n            list_name = 'train' + str(sample_rate)\n            assert list_name in self.image_list\n            data_list_file = os.path.join(root, self.image_list[list_name])\n        else:\n            data_list_file = os.path.join(root, self.image_list['test'])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            # each entry of download_list is a (file_name, archive_name, url) tuple\n            list(map(lambda args: check_exits(root, args[0]), self.download_list))\n\n        super(StanfordDogs, self).__init__(root, StanfordDogs.CLASSES, data_list_file=data_list_file, **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/sun397.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport os\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass SUN397(ImageList):\n    \"\"\"`SUN397 <https://vision.princeton.edu/projects/2010/SUN/>`_  is a dataset for scene understanding\n    with 108,754 images in 397 scene categories. The number of images varies across categories,\n    but there are at least 100 images per category. Note that the authors construct 10 partitions,\n    where each partition contains 50 training images and 50 testing images per class. We adopt partition 1.\n\n    Args:\n        root (str): Root directory of dataset\n        split (str, optional): The dataset split, supports ``train``, or ``test``.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    \"\"\"\n    dataset_url = (\"SUN397\", \"SUN397.tar.gz\", \"http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz\")\n    image_list_url = (\n        \"SUN397/image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/dec0775147c144ea9f75/?dl=1\")\n\n    def __init__(self, root, split='train', download=True, **kwargs):\n        if download:\n            download_data(root, *self.dataset_url)\n            download_data(os.path.join(root, 'SUN397'), *self.image_list_url)\n        else:\n            check_exits(root, \"SUN397\")\n            check_exits(root, \"SUN397/image_list\")\n\n        classes = list([str(i) for i in range(397)])\n        root = os.path.join(root, 'SUN397')\n        super(SUN397, self).__init__(root, classes, os.path.join(root, 'image_list', '{}.txt'.format(split)), **kwargs)\n"
  },
  {
    "path": "tllib/vision/datasets/visda2017.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport os\nfrom typing import Optional\nfrom .imagelist import ImageList\nfrom ._util import download as download_data, check_exits\n\n\nclass VisDA2017(ImageList):\n    \"\"\"`VisDA-2017 <http://ai.bu.edu/visda-2017/assets/attachments/VisDA_2017.pdf>`_ Dataset\n\n    Args:\n        root (str): Root directory of dataset\n        task (str): The task (domain) to create dataset. Choices include ``'Synthetic'``: synthetic images and \\\n            ``'Real'``: real-world images.\n        download (bool, optional): If true, downloads the dataset from the internet and puts it \\\n            in root directory. If dataset is already downloaded, it is not downloaded again.\n        transform (callable, optional): A function/transform that  takes in an PIL image and returns a \\\n            transformed version. E.g, ``transforms.RandomCrop``.\n        target_transform (callable, optional): A function/transform that takes in the target and transforms it.\n\n    .. note:: In `root`, there will exist following files after downloading.\n        ::\n            train/\n                aeroplance/\n                    *.png\n                    ...\n            validation/\n            image_list/\n                train.txt\n                validation.txt\n    \"\"\"\n    download_list = [\n        (\"image_list\", \"image_list.zip\", \"https://cloud.tsinghua.edu.cn/f/c107de37b8094c5398dc/?dl=1\"),\n        (\"train\", \"train.tgz\", \"https://cloud.tsinghua.edu.cn/f/c5f3ce59139144ec8221/?dl=1\"),\n        (\"validation\", \"validation.tgz\", \"https://cloud.tsinghua.edu.cn/f/da70e4b1cf514ecea562/?dl=1\")\n    ]\n    image_list = {\n        \"Synthetic\": \"image_list/train.txt\",\n        \"Real\": \"image_list/validation.txt\"\n    }\n    CLASSES = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife',\n               'motorcycle', 'person', 'plant', 'skateboard', 'train', 'truck']\n\n    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):\n        assert task in self.image_list\n        data_list_file = os.path.join(root, self.image_list[task])\n\n        if download:\n            list(map(lambda args: download_data(root, *args), self.download_list))\n        else:\n            list(map(lambda file_name, _: check_exits(root, file_name), self.download_list))\n\n        super(VisDA2017, self).__init__(root, VisDA2017.CLASSES, data_list_file=data_list_file, **kwargs)\n\n    @classmethod\n    def domains(cls):\n        return list(cls.image_list.keys())"
  },
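  {
    "path": "examples/sketches/visda2017_usage.py",
    "content": "# Hypothetical sketch; not part of tllib. Builds the synthetic source domain\n# and real target domain of VisDA-2017 as is typical for domain adaptation;\n# the root path is an assumption, and download=True fetches the data on first use.\nimport torchvision.transforms as T\nfrom tllib.vision.datasets.visda2017 import VisDA2017\n\nprint(VisDA2017.domains())  # ['Synthetic', 'Real']\ntransform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])\nsource = VisDA2017(root='/data/visda2017', task='Synthetic', download=True, transform=transform)\ntarget = VisDA2017(root='/data/visda2017', task='Real', download=True, transform=transform)\nprint(len(source.classes))  # 12\n"
  },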
  {
    "path": "tllib/vision/models/__init__.py",
    "content": "from .resnet import *\nfrom .digits import *\n\n__all__ = ['resnet', 'digits']\n"
  },
  {
    "path": "tllib/vision/models/digits.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\n\n\nclass LeNet(nn.Sequential):\n    def __init__(self, num_classes=10):\n        super(LeNet, self).__init__(\n            nn.Conv2d(1, 20, kernel_size=5),\n            nn.MaxPool2d(2),\n            nn.ReLU(),\n            nn.Conv2d(20, 50, kernel_size=5),\n            nn.Dropout2d(p=0.5),\n            nn.MaxPool2d(2),\n            nn.ReLU(),\n            nn.Flatten(start_dim=1),\n            nn.Linear(50 * 4 * 4, 500),\n            nn.ReLU(),\n            nn.Dropout(p=0.5),\n        )\n        self.num_classes = num_classes\n        self.out_features = 500\n\n    def copy_head(self):\n        return nn.Linear(500, self.num_classes)\n\n\nclass DTN(nn.Sequential):\n    def __init__(self, num_classes=10):\n        super(DTN, self).__init__(\n            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),\n            nn.BatchNorm2d(64),\n            nn.Dropout2d(0.1),\n            nn.ReLU(),\n            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),\n            nn.BatchNorm2d(128),\n            nn.Dropout2d(0.3),\n            nn.ReLU(),\n            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),\n            nn.BatchNorm2d(256),\n            nn.Dropout2d(0.5),\n            nn.ReLU(),\n            nn.Flatten(start_dim=1),\n            nn.Linear(256 * 4 * 4, 512),\n            nn.BatchNorm1d(512),\n            nn.ReLU(),\n            nn.Dropout(),\n        )\n        self.num_classes = num_classes\n        self.out_features = 512\n\n    def copy_head(self):\n        return nn.Linear(512, self.num_classes)\n\n\n\ndef lenet(pretrained=False, **kwargs):\n    \"\"\"LeNet model from\n    `\"Gradient-based learning applied to document recognition\" <http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf>`_\n\n    Args:\n        num_classes (int): number of classes. Default: 10\n\n    .. note::\n        The input image size must be 28 x 28.\n\n    \"\"\"\n    return LeNet(**kwargs)\n\n\ndef dtn(pretrained=False, **kwargs):\n    \"\"\" DTN model\n\n    Args:\n        num_classes (int): number of classes. Default: 10\n\n    .. note::\n        The input image size must be 32 x 32.\n\n    \"\"\"\n    return DTN(**kwargs)"
  },
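  {
    "path": "examples/sketches/digits_models_shapes.py",
    "content": "# Hypothetical shape-check sketch; not part of tllib. LeNet expects 1x28x28\n# inputs and yields 500-d features; DTN expects 3x32x32 inputs and yields 512-d\n# features. copy_head() creates a fresh classification head over those features.\nimport torch\nfrom tllib.vision.models.digits import lenet, dtn\n\nmodel = lenet(num_classes=10).eval()\nf = model(torch.randn(2, 1, 28, 28))\nprint(f.shape, model.out_features)  # torch.Size([2, 500]) 500\nprint(model.copy_head()(f).shape)   # torch.Size([2, 10])\n\nmodel = dtn(num_classes=10).eval()\nf = model(torch.randn(2, 3, 32, 32))\nprint(f.shape, model.out_features)  # torch.Size([2, 512]) 512\n"
  },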
  {
    "path": "tllib/vision/models/keypoint_detection/__init__.py",
    "content": "from .pose_resnet import *\nfrom . import loss\n\n__all__ = ['pose_resnet']"
  },
  {
    "path": "tllib/vision/models/keypoint_detection/loss.py",
    "content": "\"\"\"\nModified from https://github.com/microsoft/human-pose-estimation.pytorch\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass JointsMSELoss(nn.Module):\n    \"\"\"\n    Typical MSE loss for keypoint detection.\n\n    Args:\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output. Default: ``'mean'``\n\n    Inputs:\n        - output (tensor): heatmap predictions\n        - target (tensor): heatmap labels\n        - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None.\n\n    Shape:\n        - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints,\n          H and W is the height and width of the heatmap respectively.\n        - target: :math:`(minibatch, K, H, W)`.\n        - target_weight: :math:`(minibatch, K)`.\n        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, K)`.\n\n    \"\"\"\n    def __init__(self, reduction='mean'):\n        super(JointsMSELoss, self).__init__()\n        self.criterion = nn.MSELoss(reduction='none')\n        self.reduction = reduction\n\n    def forward(self, output, target, target_weight=None):\n        B, K, _, _ = output.shape\n        heatmaps_pred = output.reshape((B, K, -1))\n        heatmaps_gt = target.reshape((B, K, -1))\n        loss = self.criterion(heatmaps_pred, heatmaps_gt) * 0.5\n        if target_weight is not None:\n            loss = loss * target_weight.view((B, K, 1))\n        if self.reduction == 'mean':\n            return loss.mean()\n        elif self.reduction == 'none':\n            return loss.mean(dim=-1)\n\n\nclass JointsKLLoss(nn.Module):\n    \"\"\"\n    KL Divergence for keypoint detection proposed by\n    `Regressive Domain Adaptation for Unsupervised Keypoint Detection <https://arxiv.org/abs/2103.06175>`_.\n\n    Args:\n        reduction (str, optional): Specifies the reduction to apply to the output:\n          ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,\n          ``'mean'``: the sum of the output will be divided by the number of\n          elements in the output. Default: ``'mean'``\n\n    Inputs:\n        - output (tensor): heatmap predictions\n        - target (tensor): heatmap labels\n        - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None.\n\n    Shape:\n        - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints,\n          H and W is the height and width of the heatmap respectively.\n        - target: :math:`(minibatch, K, H, W)`.\n        - target_weight: :math:`(minibatch, K)`.\n        - Output: scalar by default. 
If :attr:`reduction` is ``'none'``, then :math:`(minibatch, K)`.\n\n    \"\"\"\n    def __init__(self, reduction='mean', epsilon=0.):\n        super(JointsKLLoss, self).__init__()\n        self.criterion = nn.KLDivLoss(reduction='none')\n        self.reduction = reduction\n        self.epsilon = epsilon\n\n    def forward(self, output, target, target_weight=None):\n        B, K, _, _ = output.shape\n        heatmaps_pred = output.reshape((B, K, -1))\n        heatmaps_pred = F.log_softmax(heatmaps_pred, dim=-1)\n        heatmaps_gt = target.reshape((B, K, -1))\n        heatmaps_gt = heatmaps_gt + self.epsilon\n        heatmaps_gt = heatmaps_gt / heatmaps_gt.sum(dim=-1, keepdim=True)\n        loss = self.criterion(heatmaps_pred, heatmaps_gt).sum(dim=-1)\n        if target_weight is not None:\n            loss = loss * target_weight.view((B, K))\n        if self.reduction == 'mean':\n            return loss.mean()\n        elif self.reduction == 'none':\n            # already summed over the spatial dimension, shape (minibatch, K)\n            return loss\n"
  },
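  {
    "path": "examples/sketches/keypoint_losses_usage.py",
    "content": "# Hypothetical sketch; not part of tllib. Runs both keypoint losses on random\n# heatmaps to show the expected input and output shapes; B/K/H/W are arbitrary.\nimport torch\nfrom tllib.vision.models.keypoint_detection.loss import JointsMSELoss, JointsKLLoss\n\nB, K, H, W = 4, 17, 64, 48\npred = torch.randn(B, K, H, W)  # predicted heatmaps\ngt = torch.rand(B, K, H, W)     # ground-truth heatmaps\nvisible = torch.ones(B, K)      # all keypoints visible\n\nprint(JointsMSELoss()(pred, gt, visible))               # scalar\nprint(JointsMSELoss(reduction='none')(pred, gt).shape)  # torch.Size([4, 17])\nprint(JointsKLLoss(epsilon=1e-6)(pred, gt, visible))    # scalar\n"
  },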
  {
    "path": "tllib/vision/models/keypoint_detection/pose_resnet.py",
    "content": "\"\"\"\nModified from https://github.com/microsoft/human-pose-estimation.pytorch\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\nfrom ..resnet import _resnet, Bottleneck\n\n\nclass Upsampling(nn.Sequential):\n    \"\"\"\n    3-layers deconvolution used in `Simple Baseline <https://arxiv.org/abs/1804.06208>`_.\n    \"\"\"\n    def __init__(self, in_channel=2048, hidden_dims=(256, 256, 256), kernel_sizes=(4, 4, 4), bias=False):\n        assert len(hidden_dims) == len(kernel_sizes), \\\n            'ERROR: len(hidden_dims) is different len(kernel_sizes)'\n\n        layers = []\n        for hidden_dim, kernel_size in zip(hidden_dims, kernel_sizes):\n            if kernel_size == 4:\n                padding = 1\n                output_padding = 0\n            elif kernel_size == 3:\n                padding = 1\n                output_padding = 1\n            elif kernel_size == 2:\n                padding = 0\n                output_padding = 0\n            else:\n                raise NotImplementedError(\"kernel_size is {}\".format(kernel_size))\n\n            layers.append(\n                nn.ConvTranspose2d(\n                    in_channels=in_channel,\n                    out_channels=hidden_dim,\n                    kernel_size=kernel_size,\n                    stride=2,\n                    padding=padding,\n                    output_padding=output_padding,\n                    bias=bias))\n            layers.append(nn.BatchNorm2d(hidden_dim))\n            layers.append(nn.ReLU(inplace=True))\n            in_channel = hidden_dim\n\n        super(Upsampling, self).__init__(*layers)\n\n        # init following Simple Baseline\n        for name, m in self.named_modules():\n            if isinstance(m, nn.ConvTranspose2d):\n                nn.init.normal_(m.weight, std=0.001)\n                if bias:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n\nclass PoseResNet(nn.Module):\n    \"\"\"\n    `Simple Baseline <https://arxiv.org/abs/1804.06208>`_ for keypoint detection.\n\n    Args:\n        backbone (torch.nn.Module): Backbone to extract 2-d features from data\n        upsampling (torch.nn.Module): Layer to upsample image feature to heatmap size\n        feature_dim (int): The dimension of the features from upsampling layer.\n        num_keypoints (int): Number of keypoints\n        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. 
Default: False\n    \"\"\"\n    def __init__(self, backbone, upsampling, feature_dim, num_keypoints, finetune=False):\n        super(PoseResNet, self).__init__()\n        self.backbone = backbone\n        self.upsampling = upsampling\n        self.head = nn.Conv2d(in_channels=feature_dim, out_channels=num_keypoints, kernel_size=1, stride=1, padding=0)\n        self.finetune = finetune\n        for m in self.head.modules():\n            nn.init.normal_(m.weight, std=0.001)\n            nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        x = self.backbone(x)\n        x = self.upsampling(x)\n        x = self.head(x)\n        return x\n\n    def get_parameters(self, lr=1.):\n        return [\n            {'params': self.backbone.parameters(), 'lr': 0.1 * lr if self.finetune else lr},\n            {'params': self.upsampling.parameters(), 'lr': lr},\n            {'params': self.head.parameters(), 'lr': lr},\n        ]\n\n\ndef _pose_resnet(arch, num_keypoints, block, layers, pretrained_backbone, deconv_with_bias, finetune=False, progress=True, **kwargs):\n    backbone = _resnet(arch, block, layers, pretrained_backbone, progress, **kwargs)\n    upsampling = Upsampling(backbone.out_features, bias=deconv_with_bias)\n    model = PoseResNet(backbone, upsampling, 256, num_keypoints, finetune)\n    return model\n\n\ndef pose_resnet101(num_keypoints, pretrained_backbone=True, deconv_with_bias=False, finetune=False, progress=True, **kwargs):\n    \"\"\"Constructs a Simple Baseline model with a ResNet-101 backbone.\n\n    Args:\n        num_keypoints (int): number of keypoints\n        pretrained_backbone (bool, optional): If True, returns a model pre-trained on ImageNet. Default: True.\n        deconv_with_bias (bool, optional): Whether use bias in the deconvolution layer. Default: False\n        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: False\n        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True\n    \"\"\"\n    return _pose_resnet('resnet101', num_keypoints, Bottleneck, [3, 4, 23, 3], pretrained_backbone, deconv_with_bias, finetune, progress, **kwargs)"
  },
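  {
    "path": "examples/sketches/pose_resnet_usage.py",
    "content": "# Hypothetical sketch; not part of tllib. Builds Simple Baseline without\n# downloading ImageNet weights and checks the heatmap resolution: the backbone\n# downsamples by 32 and the three stride-2 deconvolutions upsample by 8,\n# so a 256x256 input produces 64x64 heatmaps.\nimport torch\nfrom tllib.vision.models.keypoint_detection.pose_resnet import pose_resnet101\n\nmodel = pose_resnet101(num_keypoints=21, pretrained_backbone=False)\nheatmaps = model(torch.randn(1, 3, 256, 256))\nprint(heatmaps.shape)  # torch.Size([1, 21, 64, 64])\n\n# get_parameters() returns per-group learning rates (10x smaller for the\n# backbone when finetune=True), ready to hand to an optimizer:\noptimizer = torch.optim.SGD(model.get_parameters(lr=0.1), lr=0.1, momentum=0.9)\n"
  },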
  {
    "path": "tllib/vision/models/object_detection/__init__.py",
    "content": "from . import meta_arch\nfrom . import roi_heads\nfrom . import proposal_generator\nfrom . import backbone\n"
  },
  {
    "path": "tllib/vision/models/object_detection/backbone/__init__.py",
    "content": "from .vgg import VGG, build_vgg_fpn_backbone\r\n"
  },
  {
    "path": "tllib/vision/models/object_detection/backbone/mmdetection/vgg.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\n# Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/vgg.py\nfrom mmcv.runner import load_checkpoint\nimport torch.nn as nn\nfrom .weight_init import constant_init, kaiming_init, normal_init\n\n\ndef conv3x3(in_planes, out_planes, dilation=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        padding=dilation,\n        dilation=dilation)\n\n\ndef make_vgg_layer(inplanes,\n                   planes,\n                   num_blocks,\n                   dilation=1,\n                   with_bn=False,\n                   ceil_mode=False):\n    layers = []\n    for _ in range(num_blocks):\n        layers.append(conv3x3(inplanes, planes, dilation))\n        if with_bn:\n            layers.append(nn.BatchNorm2d(planes))\n        layers.append(nn.ReLU(inplace=True))\n        inplanes = planes\n    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))\n\n    return layers\n\n\nclass VGG(nn.Module):\n    \"\"\"VGG backbone.\n    Args:\n        depth (int): Depth of vgg, from {11, 13, 16, 19}.\n        with_bn (bool): Use BatchNorm or not.\n        num_classes (int): number of classes for classification.\n        num_stages (int): VGG stages, normally 5.\n        dilations (Sequence[int]): Dilation of each stage.\n        out_indices (Sequence[int]): Output from which stages.\n        frozen_stages (int): Stages to be frozen (all param fixed). -1 means\n            not freezing any parameters.\n        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze\n            running stats (mean and var).\n        bn_frozen (bool): Whether to freeze weight and bias of BN layers.\n    \"\"\"\n\n    arch_settings = {\n        11: (1, 1, 2, 2, 2),\n        13: (2, 2, 2, 2, 2),\n        16: (2, 2, 3, 3, 3),\n        19: (2, 2, 4, 4, 4)\n    }\n\n    def __init__(self,\n                 depth,\n                 with_bn=False,\n                 num_classes=-1,\n                 num_stages=5,\n                 dilations=(1, 1, 1, 1, 1),\n                 out_indices=(0, 1, 2, 3, 4),\n                 frozen_stages=-1,\n                 bn_eval=True,\n                 bn_frozen=False,\n                 ceil_mode=False,\n                 with_last_pool=True):\n        super(VGG, self).__init__()\n        if depth not in self.arch_settings:\n            raise KeyError('invalid depth {} for vgg'.format(depth))\n        assert num_stages >= 1 and num_stages <= 5\n        stage_blocks = self.arch_settings[depth]\n        self.stage_blocks = stage_blocks[:num_stages]\n        assert len(dilations) == num_stages\n        assert max(out_indices) <= num_stages\n\n        self.num_classes = num_classes\n        self.out_indices = out_indices\n        self.frozen_stages = frozen_stages\n        self.bn_eval = bn_eval\n        self.bn_frozen = bn_frozen\n\n        self.inplanes = 3\n        start_idx = 0\n        vgg_layers = []\n        self.range_sub_modules = []\n        for i, num_blocks in enumerate(self.stage_blocks):\n            num_modules = num_blocks * (2 + with_bn) + 1\n            end_idx = start_idx + num_modules\n            dilation = dilations[i]\n            planes = 64 * 2**i if i < 4 else 512\n            vgg_layer = make_vgg_layer(\n                self.inplanes,\n                planes,\n                num_blocks,\n                dilation=dilation,\n                with_bn=with_bn,\n                
ceil_mode=ceil_mode)\n            vgg_layers.extend(vgg_layer)\n            self.inplanes = planes\n            self.range_sub_modules.append([start_idx, end_idx])\n            start_idx = end_idx\n        if not with_last_pool:\n            vgg_layers.pop(-1)\n            self.range_sub_modules[-1][1] -= 1\n        self.module_name = 'features'\n        self.add_module(self.module_name, nn.Sequential(*vgg_layers))\n\n        if self.num_classes > 0:\n            self.classifier = nn.Sequential(\n                nn.Linear(512 * 7 * 7, 4096),\n                nn.ReLU(True),\n                nn.Dropout(),\n                nn.Linear(4096, 4096),\n                nn.ReLU(True),\n                nn.Dropout(),\n                nn.Linear(4096, num_classes),\n            )\n\n        # initialize the model by random\n        self.init_weights()\n        # Optionally freeze (requires_grad=False) parts of the backbone\n        self._freeze_backbone(self.frozen_stages)\n\n    def _freeze_backbone(self, freeze_at):\n        if freeze_at < 0:\n            return\n\n        vgg_layers = getattr(self, self.module_name)\n        for i in range(freeze_at):\n            for j in range(*self.range_sub_modules[i]):\n                mod = vgg_layers[j]\n                mod.eval()\n                for param in mod.parameters():\n                    param.requires_grad = False\n\n    def init_weights(self, pretrained=None):\n        if isinstance(pretrained, str):\n            load_checkpoint(self, pretrained, strict=False)\n        elif pretrained is None:\n            for m in self.modules():\n                if isinstance(m, nn.Conv2d):\n                    kaiming_init(m)\n                elif isinstance(m, nn.BatchNorm2d):\n                    constant_init(m, 1)\n                elif isinstance(m, nn.Linear):\n                    normal_init(m, std=0.01)\n        else:\n            raise TypeError('pretrained must be a str or None')\n\n    def forward(self, x):\n        outs = []\n        vgg_layers = getattr(self, self.module_name)\n        for i, num_blocks in enumerate(self.stage_blocks):\n            for j in range(*self.range_sub_modules[i]):\n                vgg_layer = vgg_layers[j]\n                x = vgg_layer(x)\n            if i in self.out_indices:\n                outs.append(x)\n        if self.num_classes > 0:\n            x = x.view(x.size(0), -1)\n            x = self.classifier(x)\n            outs.append(x)\n        if len(outs) == 1:\n            return outs[0]\n        else:\n            return tuple(outs)"
  },
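  {
    "path": "examples/sketches/mmdet_vgg_backbone_shapes.py",
    "content": "# Hypothetical sketch; not part of tllib. It needs mmcv (the VGG module imports\n# mmcv.runner) and detectron2 (pulled in by the object_detection package\n# __init__). Builds the VGG-16 backbone with the last three stages as outputs\n# and prints the per-stage feature shapes.\nimport torch\nfrom tllib.vision.models.object_detection.backbone.mmdetection.vgg import VGG\n\nbody = VGG(depth=16, out_indices=(2, 3, 4), frozen_stages=2)\nouts = body(torch.randn(1, 3, 224, 224))\nfor out in outs:\n    print(out.shape)\n# strides 8 / 16 / 32:\n# torch.Size([1, 256, 28, 28])\n# torch.Size([1, 512, 14, 14])\n# torch.Size([1, 512, 7, 7])\n"
  },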
  {
    "path": "tllib/vision/models/object_detection/backbone/mmdetection/weight_init.py",
    "content": "# Copyright (c) Open-MMLab. All rights reserved.\n# Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/weight_init.py\n\nimport numpy as np\nimport torch.nn as nn\n\n\ndef constant_init(module, val, bias=0):\n    nn.init.constant_(module.weight, val)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)\n\n\ndef xavier_init(module, gain=1, bias=0, distribution='normal'):\n    assert distribution in ['uniform', 'normal']\n    if distribution == 'uniform':\n        nn.init.xavier_uniform_(module.weight, gain=gain)\n    else:\n        nn.init.xavier_normal_(module.weight, gain=gain)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)\n\n\ndef normal_init(module, mean=0, std=1, bias=0):\n    nn.init.normal_(module.weight, mean, std)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)\n\n\ndef uniform_init(module, a=0, b=1, bias=0):\n    nn.init.uniform_(module.weight, a, b)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)\n\n\ndef kaiming_init(module,\n                 a=0,\n                 mode='fan_out',\n                 nonlinearity='relu',\n                 bias=0,\n                 distribution='normal'):\n    assert distribution in ['uniform', 'normal']\n    if distribution == 'uniform':\n        nn.init.kaiming_uniform_(\n            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)\n    else:\n        nn.init.kaiming_normal_(\n            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)\n\n\ndef caffe2_xavier_init(module, bias=0):\n    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch\n    # Acknowledgment to FAIR's internal code\n    kaiming_init(\n        module,\n        a=1,\n        mode='fan_in',\n        nonlinearity='leaky_relu',\n        distribution='uniform')"
  },
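  {
    "path": "examples/sketches/weight_init_usage.py",
    "content": "# Hypothetical sketch; not part of tllib, and importing this subpackage pulls\n# in detectron2/mmcv, so those must be installed. Applies the init helpers to a\n# small module the same way VGG.init_weights() does; the layer sizes are arbitrary.\nimport torch.nn as nn\nfrom tllib.vision.models.object_detection.backbone.mmdetection.weight_init import (\n    constant_init, kaiming_init, normal_init)\n\nnet = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Flatten(), nn.Linear(8, 4))\nfor m in net.modules():\n    if isinstance(m, nn.Conv2d):\n        kaiming_init(m)           # fan_out / relu init for conv layers\n    elif isinstance(m, nn.BatchNorm2d):\n        constant_init(m, 1)       # weight = 1, bias = 0\n    elif isinstance(m, nn.Linear):\n        normal_init(m, std=0.01)  # small gaussian for linear layers\n"
  },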
  {
    "path": "tllib/vision/models/object_detection/backbone/vgg.py",
    "content": "# referece from https://github.com/chengchunhsu/EveryPixelMatters\r\nimport torch\r\nimport torch.nn.functional as F\r\nfrom torch import nn\r\nfrom detectron2.modeling.backbone import Backbone, BACKBONE_REGISTRY\r\nfrom .mmdetection.vgg import VGG\r\n\r\n\r\nclass FPN(nn.Module):\r\n    \"\"\"\r\n    Module that adds FPN on top of a list of feature maps.\r\n    The feature maps are currently supposed to be in increasing depth\r\n    order, and must be consecutive\r\n    \"\"\"\r\n\r\n    def __init__(\r\n        self, in_channels_list, out_channels, conv_block, top_blocks=None\r\n    ):\r\n        \"\"\"\r\n        Arguments:\r\n            in_channels_list (list[int]): number of channels for each feature map that\r\n                will be fed\r\n            out_channels (int): number of channels of the FPN representation\r\n            top_blocks (nn.Module or None): if provided, an extra operation will\r\n                be performed on the output of the last (smallest resolution)\r\n                FPN output, and the result will extend the result list\r\n        \"\"\"\r\n        super(FPN, self).__init__()\r\n        self.inner_blocks = []\r\n        self.layer_blocks = []\r\n        for idx, in_channels in enumerate(in_channels_list, 1):\r\n            inner_block = \"fpn_inner{}\".format(idx)\r\n            layer_block = \"fpn_layer{}\".format(idx)\r\n\r\n            if in_channels == 0:\r\n                continue\r\n            inner_block_module = conv_block(in_channels, out_channels, 1)\r\n            layer_block_module = conv_block(out_channels, out_channels, 3, 1)\r\n            self.add_module(inner_block, inner_block_module)\r\n            self.add_module(layer_block, layer_block_module)\r\n            self.inner_blocks.append(inner_block)\r\n            self.layer_blocks.append(layer_block)\r\n        self.top_blocks = top_blocks\r\n\r\n    def forward(self, x):\r\n        \"\"\"\r\n        Arguments:\r\n            x (list[Tensor]): feature maps for each feature level.\r\n        Returns:\r\n            results (tuple[Tensor]): feature maps after FPN layers.\r\n                They are ordered from highest resolution first.\r\n        \"\"\"\r\n        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])\r\n        results = []\r\n        results.append(getattr(self, self.layer_blocks[-1])(last_inner))\r\n        for feature, inner_block, layer_block in zip(\r\n            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]\r\n        ):\r\n            if not inner_block:\r\n                continue\r\n            # inner_top_down = F.interpolate(last_inner, scale_factor=2, mode=\"nearest\")\r\n            inner_lateral = getattr(self, inner_block)(feature)\r\n            # TODO use size instead of scale to make it robust to different sizes\r\n            inner_top_down = F.upsample(last_inner, size=inner_lateral.shape[-2:],\r\n            mode='bilinear', align_corners=False)\r\n            last_inner = inner_lateral + inner_top_down\r\n            results.insert(0, getattr(self, layer_block)(last_inner))\r\n\r\n        if isinstance(self.top_blocks, LastLevelP6P7):\r\n            last_results = self.top_blocks(x[-1], results[-1])\r\n            results.extend(last_results)\r\n        elif isinstance(self.top_blocks, LastLevelMaxPool):\r\n            last_results = self.top_blocks(results[-1])\r\n            results.extend(last_results)\r\n\r\n        return tuple(results)\r\n\r\n\r\nclass LastLevelMaxPool(nn.Module):\r\n    def 
forward(self, x):\r\n        return [F.max_pool2d(x, 1, 2, 0)]\r\n\r\n\r\nclass LastLevelP6P7(nn.Module):\r\n    \"\"\"\r\n    This module is used in RetinaNet to generate extra layers, P6 and P7.\r\n    \"\"\"\r\n    def __init__(self, in_channels, out_channels):\r\n        super(LastLevelP6P7, self).__init__()\r\n        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)\r\n        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)\r\n        for module in [self.p6, self.p7]:\r\n            nn.init.kaiming_uniform_(module.weight, a=1)\r\n            nn.init.constant_(module.bias, 0)\r\n        self.use_P5 = in_channels == out_channels\r\n\r\n    def forward(self, c5, p5):\r\n        x = p5 if self.use_P5 else c5\r\n        p6 = self.p6(x)\r\n        p7 = self.p7(F.relu(p6))\r\n        return [p6, p7]\r\n\r\n\r\nclass _NewEmptyTensorOp(torch.autograd.Function):\r\n    @staticmethod\r\n    def forward(ctx, x, new_shape):\r\n        ctx.shape = x.shape\r\n        return x.new_empty(new_shape)\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad):\r\n        shape = ctx.shape\r\n        return _NewEmptyTensorOp.apply(grad, shape), None\r\n\r\n\r\nclass Conv2d(torch.nn.Conv2d):\r\n    def forward(self, x):\r\n        if x.numel() > 0:\r\n            return super(Conv2d, self).forward(x)\r\n        # get output shape\r\n\r\n        output_shape = [\r\n            (i + 2 * p - (di * (k - 1) + 1)) // d + 1\r\n            for i, p, di, k, d in zip(\r\n                x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride\r\n            )\r\n        ]\r\n        output_shape = [x.shape[0], self.weight.shape[0]] + output_shape\r\n        return _NewEmptyTensorOp.apply(x, output_shape)\r\n\r\n\r\ndef conv_with_kaiming_uniform():\r\n    def make_conv(\r\n        in_channels, out_channels, kernel_size, stride=1, dilation=1\r\n    ):\r\n        conv = Conv2d(\r\n            in_channels,\r\n            out_channels,\r\n            kernel_size=kernel_size,\r\n            stride=stride,\r\n            padding=dilation * (kernel_size - 1) // 2,\r\n            dilation=dilation,\r\n            bias=True\r\n        )\r\n        # Caffe2 implementation uses XavierFill, which in fact\r\n        # corresponds to kaiming_uniform_ in PyTorch\r\n        nn.init.kaiming_uniform_(conv.weight, a=1)\r\n        nn.init.constant_(conv.bias, 0)\r\n        module = [conv,]\r\n        if len(module) > 1:\r\n            return nn.Sequential(*module)\r\n        return conv\r\n\r\n    return make_conv\r\n\r\n\r\nclass VGGFPN(Backbone):\r\n    def __init__(self, body, fpn):\r\n        super(VGGFPN, self).__init__()\r\n        self.body = body\r\n        self.fpn = fpn\r\n        self._out_features = [\"p3\", \"p4\", \"p5\", \"p6\", \"p7\"]\r\n        self._out_feature_channels = {\r\n            \"p3\": 256, \"p4\": 256, \"p5\": 256, \"p6\": 256, \"p7\": 256\r\n        }\r\n        self._out_feature_strides = {\r\n            \"p3\": 8, \"p4\": 16, \"p5\": 32, \"p6\": 64, \"p7\": 128\r\n        }\r\n\r\n    def forward(self, x):\r\n        # print(x.shape)\r\n        f = self.body(x)\r\n        f = self.fpn(f)\r\n        return {\r\n            name: feature for name, feature in zip(self._out_features, f)\r\n        }\r\n\r\n\r\n@BACKBONE_REGISTRY.register()\r\ndef build_vgg_fpn_backbone(cfg, input_shape):\r\n    body = VGG(depth=16, with_last_pool=True, frozen_stages=2)\r\n    body.init_weights(cfg.MODEL.WEIGHTS)\r\n    in_channels_stage2 = 128  # default: cfg.MODEL.RESNETS.RES2_OUT_CHANNELS 
(256)\r\n    out_channels = 256  # default: cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS (256)\r\n    in_channels_p6p7 = out_channels\r\n    fpn = FPN(\r\n        in_channels_list=[\r\n            0,\r\n            0,\r\n            in_channels_stage2 * 2,\r\n            in_channels_stage2 * 4,\r\n            in_channels_stage2 * 4,  # in_channels_stage2 * 8\r\n        ],\r\n        out_channels=out_channels,\r\n        conv_block=conv_with_kaiming_uniform(),\r\n        top_blocks=LastLevelP6P7(in_channels_p6p7, out_channels),\r\n    )\r\n    model = VGGFPN(body, fpn)\r\n    model.out_channels = out_channels\r\n    return model"
  },
  {
    "path": "tllib/vision/models/object_detection/meta_arch/__init__.py",
    "content": "from .rcnn import TLGeneralizedRCNN\nfrom .retinanet import TLRetinaNet"
  },
  {
    "path": "tllib/vision/models/object_detection/meta_arch/rcnn.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Tuple, Dict\nimport torch\nfrom detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN as GeneralizedRCNNBase, get_event_storage\nfrom detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY\n\n\n@META_ARCH_REGISTRY.register()\nclass TLGeneralizedRCNN(GeneralizedRCNNBase):\n    \"\"\"\n    Generalized R-CNN for Transfer Learning.\n    Similar to that in in Supervised Learning, TLGeneralizedRCNN has the following three components:\n    1. Per-image feature extraction (aka backbone)\n    2. Region proposal generation\n    3. Per-region feature extraction and prediction\n\n    Different from that in Supervised Learning, TLGeneralizedRCNN\n    1. accepts unlabeled images during training (return no losses)\n    2. return both detection outputs, features, and losses during training\n\n    Args:\n        backbone: a backbone module, must follow detectron2's backbone interface\n        proposal_generator: a module that generates proposals using backbone features\n        roi_heads: a ROI head that performs per-region computation\n        pixel_mean, pixel_std: list or tuple with #channels element,\n            representing the per-channel mean and std to be used to normalize\n            the input image\n        input_format: describe the meaning of channels of input. Needed by visualization\n        vis_period: the period to run visualization. Set to 0 to disable.\n        finetune (bool): whether finetune the detector or train from scratch. Default: True\n\n    Inputs:\n        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.\n          Each item in the list contains the inputs for one image.\n          For now, each item in the list is a dict that contains:\n            * image: Tensor, image in (C, H, W) format.\n            * instances (optional): groundtruth :class:`Instances`\n            * proposals (optional): :class:`Instances`, precomputed proposals.\n            * \"height\", \"width\" (int): the output resolution of the model, used in inference.\n              See :meth:`postprocess` for details.\n        - labeled (bool, optional): whether has ground-truth label\n\n    Outputs:\n        - outputs: A list of dict where each dict is the output for one input image.\n          The dict contains a key \"instances\" whose value is a :class:`Instances`\n          and a key \"features\" whose value is the features of middle layers.\n          The :class:`Instances` object has the following keys:\n          \"pred_boxes\", \"pred_classes\", \"scores\", \"pred_masks\", \"pred_keypoints\"\n        - losses: A dict of different losses\n    \"\"\"\n\n    def __init__(self, *args, finetune=False, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.finetune = finetune\n\n    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]], labeled=True):\n        \"\"\"\"\"\"\n        if not self.training:\n            return self.inference(batched_inputs)\n\n        images = self.preprocess_image(batched_inputs)\n        if \"instances\" in batched_inputs[0] and labeled:\n            gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\n        else:\n            gt_instances = None\n\n        features = self.backbone(images.tensor)\n\n        if self.proposal_generator is not None:\n            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances, labeled)\n        else:\n            
assert \"proposals\" in batched_inputs[0]\n            proposals = [x[\"proposals\"].to(self.device) for x in batched_inputs]\n            proposal_losses = {}\n\n        outputs, detector_losses = self.roi_heads(images, features, proposals, gt_instances, labeled)\n        if self.vis_period > 0:\n            storage = get_event_storage()\n            if storage.iter % self.vis_period == 0:\n                self.visualize_training(batched_inputs, proposals)\n\n        losses = {}\n        losses.update(detector_losses)\n        losses.update(proposal_losses)\n        outputs['features'] = features\n        return outputs, losses\n\n    def get_parameters(self, lr=1.):\n        \"\"\"Return a parameter list which decides optimization hyper-parameters,\n            such as the learning rate of each layer\n        \"\"\"\n        return [\n            (self.backbone, 0.1 * lr if self.finetune else lr),\n            (self.proposal_generator, lr),\n            (self.roi_heads, lr),\n        ]"
  },
  {
    "path": "tllib/vision/models/object_detection/meta_arch/retinanet.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Dict, List, Tuple\nimport torch\nfrom torch import Tensor, nn\n\nfrom detectron2.modeling.meta_arch.retinanet import RetinaNet as RetinaNetBase\nfrom detectron2.modeling import detector_postprocess\n\n\nclass TLRetinaNet(RetinaNetBase):\n    \"\"\"\n    RetinaNet for Transfer Learning.\n\n    Different from that in Supervised Learning, TLRetinaNet\n    1. accepts unlabeled images during training (return no losses)\n    2. return both detection outputs, features, and losses during training\n\n    Args:\n        backbone: a backbone module, must follow detectron2's backbone interface\n        head (nn.Module): a module that predicts logits and regression deltas\n            for each level from a list of per-level features\n        head_in_features (Tuple[str]): Names of the input feature maps to be used in head\n        anchor_generator (nn.Module): a module that creates anchors from a\n            list of features. Usually an instance of :class:`AnchorGenerator`\n        box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to\n            instance boxes\n        anchor_matcher (Matcher): label the anchors by matching them with ground truth.\n        num_classes (int): number of classes. Used to label background proposals.\n\n        # Loss parameters:\n        focal_loss_alpha (float): focal_loss_alpha\n        focal_loss_gamma (float): focal_loss_gamma\n        smooth_l1_beta (float): smooth_l1_beta\n        box_reg_loss_type (str): Options are \"smooth_l1\", \"giou\"\n\n        # Inference parameters:\n        test_score_thresh (float): Inference cls score threshold, only anchors with\n            score > INFERENCE_TH are considered for inference (to improve speed)\n        test_topk_candidates (int): Select topk candidates before NMS\n        test_nms_thresh (float): Overlap threshold used for non-maximum suppression\n            (suppress boxes with IoU >= this threshold)\n        max_detections_per_image (int):\n            Maximum number of detections to return per image during inference\n            (100 is based on the limit established for the COCO dataset).\n\n        # Input parameters\n        pixel_mean (Tuple[float]):\n            Values to be used for image normalization (BGR order).\n            To train on images of different number of channels, set different mean & std.\n            Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]\n        pixel_std (Tuple[float]):\n            When using pre-trained models in Detectron1 or any MSRA models,\n            std has been absorbed into its conv1 weights, so the std needs to be set 1.\n            Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)\n        vis_period (int):\n            The period (in terms of steps) for minibatch visualization at train time.\n            Set to 0 to disable.\n        input_format (str): Whether the model needs RGB, YUV, HSV etc.\n        finetune (bool): whether finetune the detector or train from scratch. 
Default: False\n\n    Inputs:\n        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.\n          Each item in the list contains the inputs for one image.\n          For now, each item in the list is a dict that contains:\n            * image: Tensor, image in (C, H, W) format.\n            * instances (optional): groundtruth :class:`Instances`\n            * \"height\", \"width\" (int): the output resolution of the model, used in inference.\n              See :meth:`postprocess` for details.\n        - labeled (bool, optional): whether the inputs have ground-truth labels\n\n    Outputs:\n        - outputs: A list of dict where each dict is the output for one input image.\n          The dict contains a key \"instances\" whose value is a :class:`Instances`\n          and a key \"features\" whose value is the features of middle layers.\n          The :class:`Instances` object has the following keys:\n          \"pred_boxes\", \"pred_classes\", \"scores\", \"pred_masks\", \"pred_keypoints\"\n        - losses: A dict of different losses\n    \"\"\"\n    def __init__(self, *args, finetune=False, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.finetune = finetune\n\n    def forward(self, batched_inputs: Tuple[Dict[str, Tensor]], labeled=True):\n        \"\"\"\"\"\"\n        images = self.preprocess_image(batched_inputs)\n        features = self.backbone(images.tensor)\n        features = [features[f] for f in self.head_in_features]\n        predictions = self.head(features)\n\n        if self.training:\n            if labeled:\n                assert not torch.jit.is_scripting(), \"Not supported\"\n                assert \"instances\" in batched_inputs[0], \"Instance annotations are missing in training!\"\n                gt_instances = [x[\"instances\"].to(self.device) for x in batched_inputs]\n                losses = self.forward_training(images, features, predictions, gt_instances)\n            else:\n                losses = {}\n            outputs = {\"features\": features}\n            return outputs, losses\n        else:\n            results = self.forward_inference(images, features, predictions)\n            if torch.jit.is_scripting():\n                return results\n\n            processed_results = []\n            for results_per_image, input_per_image, image_size in zip(\n                results, batched_inputs, images.image_sizes\n            ):\n                height = input_per_image.get(\"height\", image_size[0])\n                width = input_per_image.get(\"width\", image_size[1])\n                r = detector_postprocess(results_per_image, height, width)\n                processed_results.append({\"instances\": r})\n            return processed_results\n\n    def get_parameters(self, lr=1.):\n        \"\"\"Return a parameter list that decides optimization hyper-parameters,\n            such as the learning rate of each layer.\n        \"\"\"\n        return [\n            (self.backbone.bottom_up, 0.1 * lr if self.finetune else lr),\n            (self.backbone.fpn_lateral4, lr),\n            (self.backbone.fpn_output4, lr),\n            (self.backbone.fpn_lateral5, lr),\n            (self.backbone.fpn_output5, lr),\n            (self.backbone.top_block, lr),\n            (self.head, lr),\n        ]"
  },
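  {
    "path": "examples_sketch/optimizer_param_groups.py",
    "content": "\"\"\"\nHedged usage sketch (a hypothetical file, not part of tllib): shows how the\n(module, lr) pairs returned by ``get_parameters`` of the detection models\nabove could be turned into torch.optim parameter groups. The helper name and\nbase learning rate are illustrative assumptions.\n\"\"\"\nimport torch\n\n\ndef build_param_groups(model, base_lr=0.01):\n    # ``get_parameters(lr)`` already scales each module's learning rate,\n    # e.g. 0.1 * lr for a fine-tuned backbone.\n    return [{\"params\": module.parameters(), \"lr\": lr}\n            for module, lr in model.get_parameters(lr=base_lr)]\n\n# e.g. optimizer = torch.optim.SGD(build_param_groups(model), lr=0.01,\n#                                  momentum=0.9, weight_decay=1e-4)\n"
  },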
  {
    "path": "tllib/vision/models/object_detection/proposal_generator/__init__.py",
    "content": "from .rpn import TLRPN"
  },
  {
    "path": "tllib/vision/models/object_detection/proposal_generator/rpn.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom typing import Dict, Optional, List\n\nimport torch\nfrom detectron2.structures import ImageList, Instances\nfrom detectron2.modeling.proposal_generator import (\n    RPN,\n    PROPOSAL_GENERATOR_REGISTRY,\n)\n\n\n@PROPOSAL_GENERATOR_REGISTRY.register()\nclass TLRPN(RPN):\n    \"\"\"\n    Region Proposal Network, introduced by `Faster R-CNN`.\n\n    Args:\n        in_features (list[str]): list of names of input features to use\n        head (nn.Module): a module that predicts logits and regression deltas\n            for each level from a list of per-level features\n        anchor_generator (nn.Module): a module that creates anchors from a\n            list of features. Usually an instance of :class:`AnchorGenerator`\n        anchor_matcher (Matcher): label the anchors by matching them with ground truth.\n        box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to\n            instance boxes\n        batch_size_per_image (int): number of anchors per image to sample for training\n        positive_fraction (float): fraction of foreground anchors to sample for training\n        pre_nms_topk (tuple[float]): (train, test) that represents the\n            number of top k proposals to select before NMS, in\n            training and testing.\n        post_nms_topk (tuple[float]): (train, test) that represents the\n            number of top k proposals to select after NMS, in\n            training and testing.\n        nms_thresh (float): NMS threshold used to de-duplicate the predicted proposals\n        min_box_size (float): remove proposal boxes with any side smaller than this threshold,\n            in the unit of input image pixels\n        anchor_boundary_thresh (float): legacy option\n        loss_weight (float|dict): weights to use for losses. Can be single float for weighting\n            all rpn losses together, or a dict of individual weightings. Valid dict keys are:\n                \"loss_rpn_cls\" - applied to classification loss\n                \"loss_rpn_loc\" - applied to box regression loss\n        box_reg_loss_type (str): Loss type to use. Supported losses: \"smooth_l1\", \"giou\".\n        smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to\n            use L1 loss. Only used when `box_reg_loss_type` is \"smooth_l1\"\n\n    Inputs:\n        - images (ImageList): input images of length `N`\n        - features (dict[str, Tensor]): input data as a mapping from feature\n          map name to tensor. Axis 0 represents the number of images `N` in\n          the input data; axes 1-3 are channels, height, and width, which may\n          vary between feature maps (e.g., if a feature pyramid is used).\n        - gt_instances (list[Instances], optional): a length `N` list of `Instances`s.\n          Each `Instances` stores ground-truth instances for the corresponding image.\n        - labeled (bool, optional): whether has ground-truth label. 
Default: True\n\n    Outputs:\n        - proposals: list[Instances]: contains fields \"proposal_boxes\", \"objectness_logits\"\n        - losses: dict[str, Tensor], which is empty when no ground-truth labels are available\n    \"\"\"\n    def __init__(self, *args, **kwargs):\n        super(TLRPN, self).__init__(*args, **kwargs)\n\n    def forward(\n        self,\n        images: ImageList,\n        features: Dict[str, torch.Tensor],\n        gt_instances: Optional[List[Instances]] = None,\n        labeled: Optional[bool] = True\n    ):\n        features = [features[f] for f in self.in_features]\n        anchors = self.anchor_generator(features)\n\n        pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)\n        # Transpose the Hi*Wi*A dimension to the middle:\n        pred_objectness_logits = [\n            # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)\n            score.permute(0, 2, 3, 1).flatten(1)\n            for score in pred_objectness_logits\n        ]\n        pred_anchor_deltas = [\n            # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B)\n            x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1])\n            .permute(0, 3, 4, 1, 2)\n            .flatten(1, -2)\n            for x in pred_anchor_deltas\n        ]\n\n        if self.training and labeled:\n            gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)\n            losses = self.losses(\n                anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes\n            )\n        else:\n            losses = {}\n        proposals = self.predict_proposals(\n            anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes\n        )\n        return proposals, losses\n"
  },
  {
    "path": "tllib/vision/models/object_detection/roi_heads/__init__.py",
    "content": "from .roi_heads import TLRes5ROIHeads, TLStandardROIHeads"
  },
  {
    "path": "tllib/vision/models/object_detection/roi_heads/roi_heads.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch\nfrom typing import List, Dict\nfrom detectron2.structures import Instances\nfrom detectron2.modeling.roi_heads import (\n    ROI_HEADS_REGISTRY,\n    Res5ROIHeads,\n    StandardROIHeads,\n    select_foreground_proposals,\n)\n\n\n@ROI_HEADS_REGISTRY.register()\nclass TLRes5ROIHeads(Res5ROIHeads):\n    \"\"\"\n    The ROIHeads in a typical \"C4\" R-CNN model, where\n    the box and mask head share the cropping and\n    the per-region feature computation by a Res5 block.\n\n    Args:\n        in_features (list[str]): list of backbone feature map names to use for\n            feature extraction\n        pooler (ROIPooler): pooler to extra region features from backbone\n        res5 (nn.Sequential): a CNN to compute per-region features, to be used by\n            ``box_predictor`` and ``mask_head``. Typically this is a \"res5\"\n            block from a ResNet.\n        box_predictor (nn.Module): make box predictions from the feature.\n            Should have the same interface as :class:`FastRCNNOutputLayers`.\n        mask_head (nn.Module): transform features to make mask predictions\n\n    Inputs:\n        - images (ImageList):\n        - features (dict[str,Tensor]): input data as a mapping from feature\n          map name to tensor. Axis 0 represents the number of images `N` in\n          the input data; axes 1-3 are channels, height, and width, which may\n          vary between feature maps (e.g., if a feature pyramid is used).\n        - proposals (list[Instances]): length `N` list of `Instances`. The i-th\n          `Instances` contains object proposals for the i-th input image,\n          with fields \"proposal_boxes\" and \"objectness_logits\".\n        - targets (list[Instances], optional): length `N` list of `Instances`. The i-th\n          `Instances` contains the ground-truth per-instance annotations\n          for the i-th input image.  Specify `targets` during training only.\n          It may have the following fields:\n            - gt_boxes: the bounding box of each instance.\n            - gt_classes: the label for each instance with a category ranging in [0, #class].\n            - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.\n            - gt_keypoints: NxKx3, the groud-truth keypoints for each instance.\n        - labeled (bool, optional): whether has ground-truth label. Default: True\n\n    Outputs:\n        - list[Instances]: length `N` list of `Instances` containing the\n          detected instances. Returned during inference only; may be [] during training.\n\n        - dict[str->Tensor]:\n          mapping from a named loss to a tensor storing the loss. 
Used during training only.\n    \"\"\"\n    def __init__(self, *args, **kwargs):\n        super(TLRes5ROIHeads, self).__init__(*args, **kwargs)\n\n    def forward(self, images, features, proposals, targets=None, labeled=True):\n        \"\"\"\"\"\"\n        del images\n\n        if self.training:\n            if labeled:\n                proposals = self.label_and_sample_proposals(proposals, targets)\n            else:\n                proposals = self.sample_unlabeled_proposals(proposals)\n        del targets\n\n        proposal_boxes = [x.proposal_boxes for x in proposals]\n        box_features = self._shared_roi_transform(\n            [features[f] for f in self.in_features], proposal_boxes\n        )\n        predictions = self.box_predictor(box_features.mean(dim=[2, 3]))\n\n        if self.training:\n            del features\n            if labeled:\n                losses = self.box_predictor.losses(predictions, proposals)\n                if self.mask_on:\n                    proposals, fg_selection_masks = select_foreground_proposals(\n                        proposals, self.num_classes\n                    )\n                    # Since the ROI feature transform is shared between boxes and masks,\n                    # we don't need to recompute features. The mask loss is only defined\n                    # on foreground proposals, so we need to select out the foreground\n                    # features.\n                    mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]\n                    # del box_features\n                    losses.update(self.mask_head(mask_features, proposals))\n            else:\n                losses = {}\n            outputs = {\n                'predictions': predictions[0],\n                'box_features': box_features\n            }\n            return outputs, losses\n        else:\n            pred_instances, _ = self.box_predictor.inference(predictions, proposals)\n            pred_instances = self.forward_with_given_boxes(features, pred_instances)\n            return pred_instances, {}\n\n    @torch.no_grad()\n    def sample_unlabeled_proposals(\n        self, proposals: List[Instances]\n    ) -> List[Instances]:\n        \"\"\"\n        Prepare some unlabeled proposals.\n        It returns top ``self.batch_size_per_image`` samples from proposals\n\n        Args:\n            proposals (list[Instances]): length `N` list of `Instances`. 
The i-th\n                `Instances` contains object proposals for the i-th input image,\n                with fields \"proposal_boxes\" and \"objectness_logits\".\n\n        Returns:\n            length `N` list of `Instances`s containing the proposals sampled for training.\n        \"\"\"\n        return [proposal[:self.batch_size_per_image] for proposal in proposals]\n\n\n@ROI_HEADS_REGISTRY.register()\nclass TLStandardROIHeads(StandardROIHeads):\n    \"\"\"\n    It's \"standard\" in the sense that there is no ROI transform sharing\n    or feature sharing between tasks.\n    Each head independently processes the input features with its own pooler and head.\n\n    Args:\n        box_in_features (list[str]): list of feature names to use for the box head.\n        box_pooler (ROIPooler): pooler to extract region features for the box head\n        box_head (nn.Module): transform features to make box predictions\n        box_predictor (nn.Module): make box predictions from the feature.\n            Should have the same interface as :class:`FastRCNNOutputLayers`.\n        mask_in_features (list[str]): list of feature names to use for the mask\n            pooler or mask head. None if not using mask head.\n        mask_pooler (ROIPooler): pooler to extract region features from image features.\n            The mask head will then take region features to make predictions.\n            If None, the mask head will directly take the dict of image features\n            defined by `mask_in_features`\n        mask_head (nn.Module): transform features to make mask predictions\n        keypoint_in_features, keypoint_pooler, keypoint_head: similar to ``mask_*``.\n        train_on_pred_boxes (bool): whether to use proposal boxes or\n            predicted boxes from the box head to train other heads.\n\n    Inputs:\n        - images (ImageList):\n        - features (dict[str,Tensor]): input data as a mapping from feature\n          map name to tensor. Axis 0 represents the number of images `N` in\n          the input data; axes 1-3 are channels, height, and width, which may\n          vary between feature maps (e.g., if a feature pyramid is used).\n        - proposals (list[Instances]): length `N` list of `Instances`. The i-th\n          `Instances` contains object proposals for the i-th input image,\n          with fields \"proposal_boxes\" and \"objectness_logits\".\n        - targets (list[Instances], optional): length `N` list of `Instances`. The i-th\n          `Instances` contains the ground-truth per-instance annotations\n          for the i-th input image.  Specify `targets` during training only.\n          It may have the following fields:\n            - gt_boxes: the bounding box of each instance.\n            - gt_classes: the label for each instance with a category ranging in [0, #class].\n            - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.\n            - gt_keypoints: NxKx3, the ground-truth keypoints for each instance.\n        - labeled (bool, optional): whether the inputs have ground-truth labels. Default: True\n\n    Outputs:\n        - list[Instances]: length `N` list of `Instances` containing the\n          detected instances. Returned during inference only; may be [] during training.\n\n        - dict[str->Tensor]:\n          mapping from a named loss to a tensor storing the loss. 
Used during training only.\n    \"\"\"\n    def __init__(self, *args, **kwargs):\n        super(TLStandardROIHeads, self).__init__(*args, **kwargs)\n\n    def forward(self, images, features, proposals, targets=None, labeled=True):\n        \"\"\"\"\"\"\n        del images\n        if self.training:\n            if labeled:\n                proposals = self.label_and_sample_proposals(proposals, targets)\n            else:\n                proposals = self.sample_unlabeled_proposals(proposals)\n        del targets\n\n        if self.training:\n            outputs, losses = self._forward_box(features, proposals, labeled)\n            if labeled:\n                # Usually the original proposals used by the box head are used by the mask, keypoint\n                # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes\n                # predicted by the box head.\n                losses.update(self._forward_mask(features, proposals))\n                losses.update(self._forward_keypoint(features, proposals))\n            return outputs, losses\n        else:\n            pred_instances = self._forward_box(features, proposals)\n            # During inference cascaded prediction is used: the mask and keypoints heads are only\n            # applied to the top scoring box detections.\n            pred_instances = self.forward_with_given_boxes(features, pred_instances)\n            return pred_instances, {}\n\n    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances], labeled: bool = True):\n        \"\"\"\n        Forward logic of the box prediction branch. If `self.train_on_pred_boxes is True`,\n            the function puts predicted boxes in the `proposal_boxes` field of the `proposals` argument.\n\n        Args:\n            features (dict[str, Tensor]): mapping from feature map names to tensor.\n                Same as in :meth:`ROIHeads.forward`.\n            proposals (list[Instances]): the per-image object proposals with\n                their matching ground truth.\n                Each has fields \"proposal_boxes\", and \"objectness_logits\",\n                \"gt_classes\", \"gt_boxes\".\n            labeled (bool, optional): whether the proposals have ground-truth annotations.\n                Box losses are only computed when True. Default: True\n\n        Returns:\n            In training, a dict of outputs and a dict of losses (empty if `labeled` is False).\n            In inference, a list of `Instances`, the predicted instances.\n        \"\"\"\n        features = [features[f] for f in self.box_in_features]\n        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])\n        box_features = self.box_head(box_features)\n        predictions = self.box_predictor(box_features)\n\n        if self.training:\n            if labeled:\n                losses = self.box_predictor.losses(predictions, proposals)\n                # proposals is modified in-place below, so losses must be computed first.\n                if self.train_on_pred_boxes:\n                    with torch.no_grad():\n                        pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(\n                            predictions, proposals\n                        )\n                        for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):\n                            proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)\n            else:\n                losses = {}\n            outputs = {\n                'predictions': predictions[0],\n                'box_features': box_features\n            }\n            return outputs, losses\n        else:\n            pred_instances, _ = self.box_predictor.inference(predictions, proposals)\n            return pred_instances\n\n    
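# Usage sketch (hedged; the actual trainer lives outside this file): during training this head\n    # is typically called as ``outputs, losses = roi_heads(images, features, proposals, targets, labeled)``.\n    # With ``labeled=False``, no losses are produced and only the top ``batch_size_per_image``\n    # proposals per image are kept via ``sample_unlabeled_proposals`` below.\n\n    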
@torch.no_grad()\n    def sample_unlabeled_proposals(\n        self, proposals: List[Instances]\n    ) -> List[Instances]:\n        \"\"\"\n        Prepare some unlabeled proposals.\n        It returns top ``self.batch_size_per_image`` samples from proposals\n\n        Args:\n            proposals (list[Instances]): length `N` list of `Instances`. The i-th\n                `Instances` contains object proposals for the i-th input image,\n                with fields \"proposal_boxes\" and \"objectness_logits\".\n\n        Returns:\n            length `N` list of `Instances`s containing the proposals sampled for training.\n        \"\"\"\n        return [proposal[:self.batch_size_per_image] for proposal in proposals]\n\n\n\n"
  },
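  {
    "path": "examples_sketch/build_tl_detector.py",
    "content": "\"\"\"\nHedged configuration sketch (a hypothetical file, not part of tllib): since\nTLRPN and the TL ROI heads register themselves in detectron2's registries,\nthey can be selected by name from a config. The config path below is an\nillustrative assumption; see the repo's detection examples for the real setup.\n\"\"\"\nfrom detectron2.config import get_cfg\nfrom detectron2.modeling import build_model\n\n# importing the classes runs the @..._REGISTRY.register() decorators\nfrom tllib.vision.models.object_detection.proposal_generator import TLRPN  # noqa: F401\nfrom tllib.vision.models.object_detection.roi_heads import TLRes5ROIHeads  # noqa: F401\n\ncfg = get_cfg()\ncfg.merge_from_file(\"path/to/faster_rcnn_config.yaml\")  # hypothetical config file\ncfg.MODEL.PROPOSAL_GENERATOR.NAME = \"TLRPN\"\ncfg.MODEL.ROI_HEADS.NAME = \"TLRes5ROIHeads\"  # or \"TLStandardROIHeads\"\nmodel = build_model(cfg)\n"
  },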
  {
    "path": "tllib/vision/models/reid/__init__.py",
    "content": "from .resnet import *\n\n__all__ = ['resnet']\n"
  },
  {
    "path": "tllib/vision/models/reid/identifier.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom typing import List, Dict, Optional\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\n\n\nclass ReIdentifier(nn.Module):\n    r\"\"\"Person reIdentifier from `Bag of Tricks and A Strong Baseline for Deep Person Re-identification (CVPR 2019)\n    <https://arxiv.org/pdf/1903.07071.pdf>`_.\n    Given 2-d features :math:`f` from backbone network, the authors pass :math:`f` through another `BatchNorm1d` layer\n    and get :math:`bn\\_f`, which will then pass through a `Linear` layer to output predictions. During training, we\n    use :math:`f` to compute triplet loss. While during testing, :math:`bn\\_f` is used as feature. This may be a little\n    confusing. The figures in the origin paper will help you understand better.\n    \"\"\"\n\n    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,\n                 bottleneck_dim: Optional[int] = -1, finetune=True, pool_layer=None):\n        super(ReIdentifier, self).__init__()\n        if pool_layer is None:\n            self.pool_layer = nn.Sequential(\n                nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n                nn.Flatten()\n            )\n        else:\n            self.pool_layer = pool_layer\n        self.backbone = backbone\n        self.num_classes = num_classes\n        if bottleneck is None:\n            feature_bn = nn.BatchNorm1d(backbone.out_features)\n            self.bottleneck = feature_bn\n            self._features_dim = backbone.out_features\n        else:\n            feature_bn = nn.BatchNorm1d(bottleneck_dim)\n            self.bottleneck = nn.Sequential(\n                bottleneck,\n                feature_bn\n            )\n            self._features_dim = bottleneck_dim\n\n        self.head = nn.Linear(self.features_dim, num_classes, bias=False)\n        self.finetune = finetune\n\n        # initialize feature_bn and head\n        feature_bn.bias.requires_grad_(False)\n        init.constant_(feature_bn.weight, 1)\n        init.constant_(feature_bn.bias, 0)\n        init.normal_(self.head.weight, std=0.001)\n\n    @property\n    def features_dim(self) -> int:\n        \"\"\"The dimension of features before the final `head` layer\"\"\"\n        return self._features_dim\n\n    def forward(self, x: torch.Tensor):\n        \"\"\"\"\"\"\n        f = self.pool_layer(self.backbone(x))\n        bn_f = self.bottleneck(f)\n        if not self.training:\n            return bn_f\n        predictions = self.head(bn_f)\n        return predictions, f\n\n    def get_parameters(self, base_lr=1.0, rate=0.1) -> List[Dict]:\n        \"\"\"A parameter list which decides optimization hyper-parameters,\n            such as the relative learning rate of each layer\n        \"\"\"\n        params = [\n            {\"params\": self.backbone.parameters(), \"lr\": rate * base_lr if self.finetune else 1.0 * base_lr},\n            {\"params\": self.bottleneck.parameters(), \"lr\": 1.0 * base_lr},\n            {\"params\": self.head.parameters(), \"lr\": 1.0 * base_lr},\n        ]\n\n        return params\n"
  },
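  {
    "path": "examples_sketch/reid_identifier_usage.py",
    "content": "\"\"\"\nHedged usage sketch (a hypothetical file, not part of tllib): illustrates the\ntwo forward modes of ``ReIdentifier`` documented above. The 256x128 input crop\nand the 751 identities (Market-1501's training split) are illustrative\nassumptions.\n\"\"\"\nimport torch\n\nfrom tllib.vision.models.reid.identifier import ReIdentifier\nfrom tllib.vision.models.reid.resnet import reid_resnet50\n\nbackbone = reid_resnet50(pretrained=False)\nmodel = ReIdentifier(backbone, num_classes=751)\n\nimages = torch.randn(4, 3, 256, 128)\nmodel.train()\npredictions, f = model(images)  # logits for the cls loss, raw features f for the triplet loss\nmodel.eval()\nbn_f = model(images)            # batch-normalized features used as the test-time descriptor\n"
  },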
  {
    "path": "tllib/vision/models/reid/loss.py",
    "content": "\"\"\"\nModified from https://github.com/yxgeee/MMT\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef pairwise_euclidean_distance(x, y):\n    \"\"\"Compute pairwise euclidean distance between two sets of features\"\"\"\n    m, n = x.size(0), y.size(0)\n    dist_mat = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) + \\\n               torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() \\\n               - 2 * torch.matmul(x, y.t())\n    # for numerical stability\n    dist_mat = dist_mat.clamp(min=1e-12).sqrt()\n    return dist_mat\n\n\ndef hard_examples_mining(dist_mat, identity_mat, return_idxes=False):\n    r\"\"\"Select hard positives and hard negatives according to `In defense of the Triplet Loss for Person\n    Re-Identification (ICCV 2017) <https://arxiv.org/pdf/1703.07737v2.pdf>`_\n\n    Args:\n        dist_mat (tensor): pairwise distance matrix between two sets of features\n        identity_mat (tensor): a matrix of shape :math:`(N, M)`. If two images :math:`P[i]` of set :math:`P` and\n            :math:`Q[j]` of set :math:`Q` come from the same person, then :math:`identity\\_mat[i, j] = 1`,\n            otherwise :math:`identity\\_mat[i, j] = 0`\n        return_idxes (bool, optional): if True, also return indexes of hard examples. Default: False\n    \"\"\"\n    # the implementation here is a little tricky, dist_mat contains pairwise distance between probe image and other\n    # images in current mini-batch. As we want to select positive examples of the same person, we add a constant\n    # negative offset on other images before sorting. As a result, images of the **same** person will rank first.\n    sorted_dist_mat, sorted_idxes = torch.sort(dist_mat + (-1e7) * (1 - identity_mat), dim=1,\n                                               descending=True)\n    dist_ap = sorted_dist_mat[:, 0]\n    hard_positive_idxes = sorted_idxes[:, 0]\n\n    # the implementation here is similar to above code, we add a constant positive offset on images of same person\n    # before sorting. Besides, we sort in ascending order. As a result, images of **different** persons will rank first.\n    sorted_dist_mat, sorted_idxes = torch.sort(dist_mat + 1e7 * identity_mat, dim=1,\n                                               descending=False)\n    dist_an = sorted_dist_mat[:, 0]\n    hard_negative_idxes = sorted_idxes[:, 0]\n    if return_idxes:\n        return dist_ap, dist_an, hard_positive_idxes, hard_negative_idxes\n    return dist_ap, dist_an\n\n\nclass CrossEntropyLossWithLabelSmooth(nn.Module):\n    r\"\"\"Cross entropy loss with label smooth from `Rethinking the Inception Architecture for Computer Vision\n    (CVPR 2016) <https://arxiv.org/pdf/1512.00567.pdf>`_.\n\n    Given one-hot labels :math:`labels \\in R^C`, where :math:`C` is the number of classes,\n    smoothed labels are calculated as\n\n    .. 
math::\n        smoothed\\_labels = (1 - \\epsilon) \\times labels + \\epsilon \\times \\frac{1}{C}\n\n    We use smoothed labels when calculating cross entropy loss and this can be helpful for preventing over-fitting.\n\n    Args:\n        num_classes (int): number of classes.\n        epsilon (float): a float number that controls the smoothness.\n\n    Inputs:\n        - y (tensor): unnormalized classifier predictions, :math:`y`\n        - labels (tensor): ground truth labels, :math:`labels`\n\n    Shape:\n        - y: :math:`(minibatch, C)`, where :math:`C` is the number of classes\n        - labels: :math:`(minibatch, )`\n    \"\"\"\n\n    def __init__(self, num_classes, epsilon=0.1):\n        super(CrossEntropyLossWithLabelSmooth, self).__init__()\n        self.num_classes = num_classes\n        self.epsilon = epsilon\n        self.log_softmax = nn.LogSoftmax(dim=1).cuda()\n\n    def forward(self, y, labels):\n        log_prob = self.log_softmax(y)\n        labels = torch.zeros_like(log_prob).scatter_(1, labels.unsqueeze(1), 1)\n        labels = (1 - self.epsilon) * labels + self.epsilon / self.num_classes\n        loss = (- labels * log_prob).mean(0).sum()\n        return loss\n\n\nclass TripletLoss(nn.Module):\n    \"\"\"Triplet loss augmented with batch hard from `In defense of the Triplet Loss for Person Re-Identification\n    (ICCV 2017) <https://arxiv.org/pdf/1703.07737v2.pdf>`_.\n\n    Args:\n        margin (float): margin of triplet loss\n        normalize_feature (bool, optional): if True, normalize features into unit norm first before computing loss.\n            Default: False.\n    \"\"\"\n\n    def __init__(self, margin, normalize_feature=False):\n        super(TripletLoss, self).__init__()\n        self.margin = margin\n        self.normalize_feature = normalize_feature\n        self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda()\n\n    def forward(self, f, labels):\n        if self.normalize_feature:\n            # equivalent to cosine similarity\n            f = F.normalize(f)\n        dist_mat = pairwise_euclidean_distance(f, f)\n\n        n = dist_mat.size(0)\n        identity_mat = labels.expand(n, n).eq(labels.expand(n, n).t()).float()\n\n        dist_ap, dist_an = hard_examples_mining(dist_mat, identity_mat)\n        y = torch.ones_like(dist_ap)\n        loss = self.margin_loss(dist_an, dist_ap, y)\n        return loss\n\n\nclass TripletLossXBM(nn.Module):\n    r\"\"\"Triplet loss augmented with batch hard from `In defense of the Triplet Loss for Person Re-Identification\n    (ICCV 2017) <https://arxiv.org/pdf/1703.07737v2.pdf>`_. The only difference from triplet loss lies in that\n    both features from current mini batch and external storage (XBM) are involved.\n\n    Args:\n        margin (float, optional): margin of triplet loss. 
Default: 0.3\n        normalize_feature (bool, optional): if True, normalize features into unit norm first before computing loss.\n            Default: False\n\n    Inputs:\n        - f (tensor): features of current mini batch, :math:`f`\n        - labels (tensor): identity labels for current mini batch, :math:`labels`\n        - xbm_f (tensor): features collected from XBM, :math:`xbm\\_f`\n        - xbm_labels (tensor): corresponding identity labels of xbm_f, :math:`xbm\\_labels`\n\n    Shape:\n        - f: :math:`(minibatch, F)`, where :math:`F` is the feature dimension\n        - labels: :math:`(minibatch, )`\n        - xbm_f: :math:`(minibatch, F)`\n        - xbm_labels: :math:`(minibatch, )`\n    \"\"\"\n\n    def __init__(self, margin=0.3, normalize_feature=False):\n        super(TripletLossXBM, self).__init__()\n        self.margin = margin\n        self.normalize_feature = normalize_feature\n        self.ranking_loss = nn.MarginRankingLoss(margin=margin)\n\n    def forward(self, f, labels, xbm_f, xbm_labels):\n        if self.normalize_feature:\n            # equivalent to cosine similarity\n            f = F.normalize(f)\n            xbm_f = F.normalize(xbm_f)\n\n        dist_mat = pairwise_euclidean_distance(f, xbm_f)\n\n        # hard examples mining\n        n, m = f.size(0), xbm_f.size(0)\n        identity_mat = labels.expand(m, n).t().eq(xbm_labels.expand(n, m)).float()\n        dist_ap, dist_an = hard_examples_mining(dist_mat, identity_mat)\n\n        # Compute ranking hinge loss\n        y = torch.ones_like(dist_an)\n        loss = self.ranking_loss(dist_an, dist_ap, y)\n\n        return loss\n\n\nclass SoftTripletLoss(nn.Module):\n    r\"\"\"Soft triplet loss from `Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised\n    Domain Adaptation on Person Re-identification (ICLR 2020) <https://arxiv.org/pdf/2001.01526.pdf>`_.\n    Consider a triplet :math:`x,x_p,x_n` (anchor, positive, negative), corresponding features are :math:`f,f_p,f_n`.\n    We optimize for a smaller distance between :math:`f` and :math:`f_p` and a larger distance\n    between :math:`f` and :math:`f_n`. Inner product is adopted as their similarity measure, soft triplet loss is thus\n    defined as\n\n    .. math::\n        loss = \\mathcal{L}_{\\text{bce}}(\\frac{\\text{exp}(f^Tf_p)}{\\text{exp}(f^Tf_p)+\\text{exp}(f^Tf_n)}, 1)\n\n    where :math:`\\mathcal{L}_{\\text{bce}}` means binary cross entropy loss. We denote the first term in above loss function\n    as :math:`T`. When features from another teacher network can be obtained, we can calculate :math:`T_{teacher}` as\n    labels, resulting in the following soft version\n\n    .. math::\n        loss = \\mathcal{L}_{\\text{bce}}(T, T_{teacher})\n\n    Args:\n        margin (float, optional): margin of triplet loss. If None, soft labels from another network will be adopted when\n            computing loss. 
Default: None.\n        normalize_feature (bool, optional): if True, normalize features into unit norm first before computing loss.\n            Default: False.\n    \"\"\"\n\n    def __init__(self, margin=None, normalize_feature=False):\n        super(SoftTripletLoss, self).__init__()\n        self.margin = margin\n        self.normalize_feature = normalize_feature\n\n    def forward(self, features_1, features_2, labels):\n        if self.normalize_feature:\n            # equal to cosine similarity\n            features_1 = F.normalize(features_1)\n            features_2 = F.normalize(features_2)\n\n        dist_mat = pairwise_euclidean_distance(features_1, features_1)\n        assert dist_mat.size(0) == dist_mat.size(1)\n\n        n = dist_mat.size(0)\n        identity_mat = labels.expand(n, n).eq(labels.expand(n, n).t()).float()\n\n        dist_ap, dist_an, ap_idxes, an_idxes = hard_examples_mining(dist_mat, identity_mat, return_idxes=True)\n        assert dist_an.size(0) == dist_ap.size(0)\n        triple_dist = torch.stack((dist_ap, dist_an), dim=1)\n        triple_dist = F.log_softmax(triple_dist, dim=1)\n        if self.margin is not None:\n            loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean()\n            return loss\n\n        dist_mat_ref = pairwise_euclidean_distance(features_2, features_2)\n        dist_ap_ref = torch.gather(dist_mat_ref, 1, ap_idxes.view(n, 1).expand(n, n))[:, 0]\n        dist_an_ref = torch.gather(dist_mat_ref, 1, an_idxes.view(n, 1).expand(n, n))[:, 0]\n        triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1)\n        triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach()\n\n        loss = (- triple_dist_ref * triple_dist).sum(dim=1).mean()\n        return loss\n\n\nclass CrossEntropyLoss(nn.Module):\n    r\"\"\"We use :math:`C` to denote the number of classes, :math:`N` to denote mini-batch\n    size, this criterion expects unnormalized predictions :math:`y\\_{logits}` of shape :math:`(N, C)` and\n    :math:`target\\_{logits}` of the same shape :math:`(N, C)`. Then we first normalize them into\n    probability distributions among classes\n\n    .. math::\n        y = \\text{softmax}(y\\_{logits})\n    .. math::\n        target = \\text{softmax}(target\\_{logits})\n\n    Final objective is calculated as\n\n    .. math::\n        \\text{loss} = \\frac{1}{N} \\sum_{i=1}^{N} \\sum_{j=1}^C -target_i^j \\times \\text{log} (y_i^j)\n    \"\"\"\n\n    def __init__(self):\n        super(CrossEntropyLoss, self).__init__()\n        self.log_softmax = nn.LogSoftmax(dim=1).cuda()\n\n    def forward(self, y, labels):\n        log_prob = self.log_softmax(y)\n        loss = (- F.softmax(labels, dim=1).detach() * log_prob).sum(dim=1).mean()\n        return loss\n"
  },
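  {
    "path": "examples_sketch/reid_losses_usage.py",
    "content": "\"\"\"\nHedged usage sketch (a hypothetical file, not part of tllib): batch-hard\ntriplet loss plus label-smoothed cross entropy on a toy PK-sampled batch\n(P=2 identities, K=4 images each, an illustrative convention). Both loss\nmodules move submodules to CUDA in __init__, so a GPU is assumed.\n\"\"\"\nimport torch\n\nfrom tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, TripletLoss\n\nf = torch.randn(8, 2048, device=\"cuda\")  # raw features (2nd training output of ReIdentifier)\ny = torch.randn(8, 751, device=\"cuda\")   # identity logits (1st training output)\nlabels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], device=\"cuda\")\n\ntriplet_loss = TripletLoss(margin=0.3, normalize_feature=True)(f, labels)\ncls_loss = CrossEntropyLossWithLabelSmooth(num_classes=751)(y, labels)\nloss = cls_loss + triplet_loss\n"
  },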
  {
    "path": "tllib/vision/models/reid/resnet.py",
    "content": "\"\"\"\n@author: Baixu Chen\n@contact: cbx_99_hasta@outlook.com\n\"\"\"\nfrom tllib.vision.models.resnet import ResNet, load_state_dict_from_url, model_urls, BasicBlock, Bottleneck\n\n__all__ = ['reid_resnet18', 'reid_resnet34', 'reid_resnet50', 'reid_resnet101']\n\n\nclass ReidResNet(ResNet):\n    r\"\"\"Modified `ResNet` architecture for ReID from `Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised\n    Domain Adaptation on Person Re-identification (ICLR 2020) <https://arxiv.org/pdf/2001.01526.pdf>`_. We change stride\n    of :math:`layer4\\_group1\\_conv2, layer4\\_group1\\_downsample1` to 1. During forward pass, we will not activate\n    `self.relu`. Please refer to source code for details.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(ReidResNet, self).__init__(*args, **kwargs)\n        self.layer4[0].conv2.stride = (1, 1)\n        self.layer4[0].downsample[0].stride = (1, 1)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        # x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        return x\n\n\ndef _reid_resnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = ReidResNet(block, layers, **kwargs)\n    if pretrained:\n        model_dict = model.state_dict()\n        pretrained_dict = load_state_dict_from_url(model_urls[arch],\n                                                   progress=progress)\n        # remove keys from pretrained dict that doesn't appear in model dict\n        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n        model.load_state_dict(pretrained_dict, strict=False)\n    return model\n\n\ndef reid_resnet18(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Constructs a Reid-ResNet-18 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _reid_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                        **kwargs)\n\n\ndef reid_resnet34(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Constructs a Reid-ResNet-34 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _reid_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                        **kwargs)\n\n\ndef reid_resnet50(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Constructs a Reid-ResNet-50 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _reid_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                        **kwargs)\n\n\ndef reid_resnet101(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Constructs a Reid-ResNet-101 model.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _reid_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                        **kwargs)\n"
  },
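  {
    "path": "examples_sketch/reid_resnet_shapes.py",
    "content": "\"\"\"\nHedged shape check (a hypothetical file, not part of tllib): the stride-1\nchange in layer4 of ReidResNet keeps the final feature map at 1/16 of the\ninput resolution instead of 1/32. The 256x128 input below is an illustrative\nre-id crop size.\n\"\"\"\nimport torch\n\nfrom tllib.vision.models.reid.resnet import reid_resnet50\n\nmodel = reid_resnet50(pretrained=False)\nx = torch.randn(1, 3, 256, 128)\n# conv1 (/2) -> maxpool (/2) -> layer2 (/2) -> layer3 (/2); layer4 keeps stride 1\nprint(model(x).shape)  # torch.Size([1, 2048, 16, 8])\n"
  },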
  {
    "path": "tllib/vision/models/resnet.py",
    "content": "\"\"\"\nModified based on torchvision.models.resnet.\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\n\nimport torch.nn as nn\nfrom torchvision import models\nfrom torch.hub import load_state_dict_from_url\nfrom torchvision.models.resnet import BasicBlock, Bottleneck, model_urls\nimport copy\n\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',\n           'wide_resnet50_2', 'wide_resnet101_2']\n\n\nclass ResNet(models.ResNet):\n    \"\"\"ResNets without fully connected layer\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(ResNet, self).__init__(*args, **kwargs)\n        self._out_features = self.fc.in_features\n\n    def forward(self, x):\n        \"\"\"\"\"\"\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        # x = self.avgpool(x)\n        # x = torch.flatten(x, 1)\n        # x = x.view(-1, self._out_features)\n        return x\n\n    @property\n    def out_features(self) -> int:\n        \"\"\"The dimension of output features\"\"\"\n        return self._out_features\n\n    def copy_head(self) -> nn.Module:\n        \"\"\"Copy the origin fully connected layer\"\"\"\n        return copy.deepcopy(self.fc)\n\n\ndef _resnet(arch, block, layers, pretrained, progress, **kwargs):\n    model = ResNet(block, layers, **kwargs)\n    if pretrained:\n        model_dict = model.state_dict()\n        pretrained_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        # remove keys from pretrained dict that doesn't appear in model dict\n        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n        model.load_state_dict(pretrained_dict, strict=False)\n    return model\n\n\ndef resnet18(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-18 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet34(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet50(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-50 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet101(pretrained=False, progress=True, **kwargs):\n    
r\"\"\"ResNet-101 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet152(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNet-152 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnext50_32x4d(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNeXt-50 32x4d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef resnext101_32x8d(pretrained=False, progress=True, **kwargs):\n    r\"\"\"ResNeXt-101 32x8d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet50_2(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Wide ResNet-50-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet101_2(pretrained=False, progress=True, **kwargs):\n    r\"\"\"Wide ResNet-101-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n"
  },
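  {
    "path": "examples_sketch/resnet_backbone_usage.py",
    "content": "\"\"\"\nHedged usage sketch (a hypothetical file, not part of tllib): the ResNets in\nthis package return spatial feature maps (avgpool/fc are skipped), so a\ndownstream head adds its own pooling. Shapes assume a 224x224 input.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom tllib.vision.models.resnet import resnet50\n\nbackbone = resnet50(pretrained=False)\nx = torch.randn(2, 3, 224, 224)\nfeat = backbone(x)\nprint(feat.shape)             # torch.Size([2, 2048, 7, 7])\nprint(backbone.out_features)  # 2048, the dimension after global pooling\n\npooled = nn.AdaptiveAvgPool2d(1)(feat).flatten(1)  # (2, 2048), ready for a classifier\nhead = backbone.copy_head()   # fresh copy of the original fc layer\n"
  },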
  {
    "path": "tllib/vision/models/segmentation/__init__.py",
    "content": "from .deeplabv2 import *\n\n__all__ = ['deeplabv2']"
  },
  {
    "path": "tllib/vision/models/segmentation/deeplabv2.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nimport torch.nn as nn\nfrom torchvision.models.utils import load_state_dict_from_url\n\n\nmodel_urls = {\n    'deeplabv2_resnet101': 'https://cloud.tsinghua.edu.cn/f/2d9a7fc43ce34f76803a/?dl=1'\n}\n\naffine_par = True\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)  # change\n        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)\n        for i in self.bn1.parameters():\n            i.requires_grad = False\n\n        padding = dilation\n\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # change\n                               padding=padding, bias=False, dilation=dilation)\n        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)\n        for i in self.bn2.parameters():\n            i.requires_grad = False\n\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)\n        for i in self.bn3.parameters():\n            i.requires_grad = False\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ASPP_V2(nn.Module):\n    def __init__(self, inplanes, dilation_series, padding_series, num_classes):\n        super(ASPP_V2, self).__init__()\n        self.conv2d_list = nn.ModuleList()\n        for dilation, padding in zip(dilation_series, padding_series):\n            self.conv2d_list.append(\n                nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation,\n                          bias=True))\n        for m in self.conv2d_list:\n            m.weight.data.normal_(0, 0.01)\n\n    def forward(self, x):\n        out = self.conv2d_list[0](x)\n        for i in range(len(self.conv2d_list) - 1):\n            out += self.conv2d_list[i + 1](x)\n\n        return out\n\n\nclass ResNet(nn.Module):\n    def __init__(self, block, layers):\n        self.inplanes = 64\n        super(ResNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)\n        for i in self.bn1.parameters():\n            i.requires_grad = False\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change\n\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                
m.weight.data.normal_(0, 0.01)\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))\n            # freeze the affine parameters of the downsample BatchNorm as well; this must\n            # stay inside the branch, since downsample may otherwise be None\n            for i in downsample._modules['1'].parameters():\n                i.requires_grad = False\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, dilation=dilation))\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        return x\n\n\nclass Deeplab(nn.Module):\n    def __init__(self, backbone, classifier, num_classes):\n        super(Deeplab, self).__init__()\n        self.backbone = backbone\n        self.classifier = classifier\n        self.num_classes = num_classes\n\n    def forward(self, x):\n        x = self.backbone(x)\n        y = self.classifier(x)\n        return y\n\n    def get_1x_lr_params_NOscale(self):\n        \"\"\"\n        This generator returns all the parameters of the net except for\n        the last classification layer. Note that requires_grad is set to False\n        for every batchnorm layer above, therefore this function does not return\n        any batchnorm parameters\n        \"\"\"\n        layers = [self.backbone.conv1, self.backbone.bn1,\n                self.backbone.layer1, self.backbone.layer2, self.backbone.layer3, self.backbone.layer4]\n        for layer in layers:\n            for module in layer.modules():\n                for param in module.parameters():\n                    if param.requires_grad:\n                        yield param\n\n    def get_10x_lr_params(self):\n        \"\"\"\n        This generator returns all the parameters of the last layer of the net,\n        which does the classification of pixels into classes\n        \"\"\"\n        for param in self.classifier.parameters():\n            yield param\n\n    def get_parameters(self, lr=1.):\n        return [\n            {'params': self.get_1x_lr_params_NOscale(), 'lr': 0.1 * lr},\n            {'params': self.get_10x_lr_params(), 'lr': lr}\n        ]\n\n\ndef deeplabv2_resnet101(num_classes=19, pretrained_backbone=True):\n    \"\"\"Constructs a DeepLabV2 model with a ResNet-101 backbone.\n\n     Args:\n         num_classes (int, optional): number of classes. Default: 19\n         pretrained_backbone (bool, optional): If True, returns a model pre-trained on ImageNet. 
Default: True.\n     \"\"\"\n    backbone = ResNet(Bottleneck, [3, 4, 23, 3])\n    if pretrained_backbone:\n        # download from Internet\n        saved_state_dict = load_state_dict_from_url(model_urls['deeplabv2_resnet101'], map_location=lambda storage, loc: storage, file_name=\"deeplabv2_resnet101.pth\")\n        new_params = backbone.state_dict().copy()\n        for i in saved_state_dict:\n            i_parts = i.split('.')\n            if not i_parts[1] == 'layer5':\n                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]\n        backbone.load_state_dict(new_params)\n    classifier = ASPP_V2(2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)\n    return Deeplab(backbone, classifier, num_classes)\n"
  },
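  {
    "path": "examples_sketch/deeplabv2_usage.py",
    "content": "\"\"\"\nHedged usage sketch (a hypothetical file, not part of tllib): the DeepLabV2\nmodel above returns raw logits at 1/8 of the input resolution (layer3/layer4\nare dilated with stride 1), so callers typically upsample back to the input\nsize. The 512x1024 input and 19 classes follow the Cityscapes convention.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\nfrom tllib.vision.models.segmentation.deeplabv2 import deeplabv2_resnet101\n\nmodel = deeplabv2_resnet101(num_classes=19, pretrained_backbone=False)\nx = torch.randn(1, 3, 512, 1024)\nlogits = model(x)  # (1, 19, 64, 128), i.e. 1/8 resolution\nlogits = F.interpolate(logits, size=x.shape[-2:], mode=\"bilinear\", align_corners=False)\n"
  },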
  {
    "path": "tllib/vision/transforms/__init__.py",
    "content": "import math\nimport random\nfrom PIL import Image\nimport numpy as np\nimport torch\nfrom torchvision.transforms import Normalize\n\n\nclass ResizeImage(object):\n    \"\"\"Resize the input PIL Image to the given size.\n\n    Args:\n        size (sequence or int): Desired output size. If size is a sequence like\n          (h, w), output size will be matched to this. If size is an int,\n          output size will be (size, size)\n    \"\"\"\n\n    def __init__(self, size):\n        if isinstance(size, int):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n\n    def __call__(self, img):\n        th, tw = self.size\n        return img.resize((th, tw))\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(size={0})'.format(self.size)\n\n\nclass MultipleApply:\n    \"\"\"Apply a list of transformations to an image and get multiple transformed images.\n\n    Args:\n        transforms (list or tuple): list of transformations\n\n    Example:\n        \n        >>> transform1 = T.Compose([\n        ...     ResizeImage(256),\n        ...     T.RandomCrop(224)\n        ... ])\n        >>> transform2 = T.Compose([\n        ...     ResizeImage(256),\n        ...     T.RandomCrop(224),\n        ... ])\n        >>> multiply_transform = MultipleApply([transform1, transform2])\n    \"\"\"\n\n    def __init__(self, transforms):\n        self.transforms = transforms\n\n    def __call__(self, image):\n        return [t(image) for t in self.transforms]\n\n    def __repr__(self):\n        format_string = self.__class__.__name__ + '('\n        for t in self.transforms:\n            format_string += '\\n'\n            format_string += '    {0}'.format(t)\n        format_string += '\\n)'\n        return format_string\n\n\nclass Denormalize(Normalize):\n    \"\"\"DeNormalize a tensor image with mean and standard deviation.\n    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``\n    channels, this transform will denormalize each channel of the input\n    ``torch.*Tensor`` i.e.,\n    ``output[channel] = input[channel] * std[channel] + mean[channel]``\n\n    .. 
note::\n        This transform acts out of place, i.e., it does not mutate the input tensor.\n\n    Args:\n        mean (sequence): Sequence of means for each channel.\n        std (sequence): Sequence of standard deviations for each channel.\n\n    \"\"\"\n\n    def __init__(self, mean, std):\n        mean = np.array(mean)\n        std = np.array(std)\n        super().__init__((-mean / std).tolist(), (1 / std).tolist())\n\n\nclass NormalizeAndTranspose:\n    \"\"\"\n    First, convert the image from RGB to BGR and subtract the per-channel mean.\n    Then, convert the shape (H x W x C) to shape (C x H x W).\n    \"\"\"\n\n    def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)):\n        self.mean = np.array(mean, dtype=np.float32)\n\n    def __call__(self, image):\n        if isinstance(image, Image.Image):\n            image = np.asarray(image, np.float32)\n            # change to BGR\n            image = image[:, :, ::-1]\n            # normalize\n            image -= self.mean\n            image = image.transpose((2, 0, 1)).copy()\n        elif isinstance(image, torch.Tensor):\n            # change to BGR\n            image = image[:, :, [2, 1, 0]]\n            # normalize\n            image -= torch.from_numpy(self.mean).to(image.device)\n            image = image.permute((2, 0, 1))\n        else:\n            raise NotImplementedError(type(image))\n        return image\n\n\nclass DeNormalizeAndTranspose:\n    \"\"\"\n    First, convert a tensor image from the shape (C x H x W) to shape (H x W x C).\n    Then, add back the per-channel mean and convert BGR to RGB.\n    \"\"\"\n\n    def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)):\n        self.mean = np.array(mean, dtype=np.float32)\n\n    def __call__(self, image):\n        image = image.transpose((1, 2, 0))\n        # denormalize\n        image += self.mean\n        # change to RGB\n        image = image[:, :, ::-1]\n        return image\n\n\nclass RandomErasing(object):\n    \"\"\"Random erasing augmentation from `Random Erasing Data Augmentation (AAAI 2020)\n    <https://arxiv.org/pdf/1708.04896.pdf>`_. 
This augmentation randomly selects a rectangle region in an image\n    and erases its pixels.\n\n    Args:\n        probability (float): The probability that the Random Erasing operation will be performed.\n        sl (float): Minimum proportion of erased area against input image.\n        sh (float): Maximum proportion of erased area against input image.\n        r1 (float): Minimum aspect ratio of erased area.\n        mean (sequence): Per-channel values used to fill the erased area.\n    \"\"\"\n\n    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):\n        self.probability = probability\n        self.mean = mean\n        self.sl = sl\n        self.sh = sh\n        self.r1 = r1\n\n    def __call__(self, img):\n        if random.uniform(0, 1) >= self.probability:\n            return img\n\n        for _ in range(100):\n            area = img.size()[1] * img.size()[2]\n\n            target_area = random.uniform(self.sl, self.sh) * area\n            aspect_ratio = random.uniform(self.r1, 1 / self.r1)\n\n            h = int(round(math.sqrt(target_area * aspect_ratio)))\n            w = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if w < img.size()[2] and h < img.size()[1]:\n                x1 = random.randint(0, img.size()[1] - h)\n                y1 = random.randint(0, img.size()[2] - w)\n                if img.size()[0] == 3:\n                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]\n                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]\n                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]\n                else:\n                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]\n                return img\n\n        return img\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(p={})'.format(self.probability)\n
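\n\n# Usage sketch (illustrative): Denormalize inverts torchvision's Normalize.\n# Normalize computes (x - mean) / std, so constructing Normalize with\n# mean' = -mean / std and std' = 1 / std yields x * std + mean, i.e. it\n# restores the original pixel values (up to float32 rounding).\n# >>> mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]\n# >>> x = torch.rand(3, 224, 224)\n# >>> x_restored = Denormalize(mean, std)(Normalize(mean, std)(x))\n# >>> torch.allclose(x, x_restored, atol=1e-5)\n# True\n"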
  },
  {
    "path": "tllib/vision/transforms/keypoint_detection.py",
    "content": "\"\"\"\r\n@author: Junguang Jiang\r\n@contact: JiangJunguang1123@outlook.com\r\n\"\"\"\r\n# TODO needs better documentation\r\nimport numpy as np\r\nfrom PIL import ImageFilter, Image\r\nimport torchvision.transforms.functional as F\r\nimport torchvision.transforms.transforms as T\r\nimport numbers\r\nimport random\r\nimport math\r\nimport warnings\r\nfrom typing import ClassVar\r\n\r\n\r\ndef wrapper(transform: ClassVar):\r\n    \"\"\" Wrap a transform for classification to a transform for keypoint detection.\r\n    Note that the keypoint detection label will keep the same before and after wrapper.\r\n\r\n    Args:\r\n        transform (class, callable): transform for classification\r\n\r\n    Returns:\r\n        transform for keypoint detection\r\n    \"\"\"\r\n    class WrapperTransform(transform):\r\n        def __call__(self, image, **kwargs):\r\n            image = super().__call__(image)\r\n            return image, kwargs\r\n    return WrapperTransform\r\n\r\n\r\nToTensor = wrapper(T.ToTensor)\r\nNormalize = wrapper(T.Normalize)\r\nColorJitter = wrapper(T.ColorJitter)\r\n\r\n\r\ndef resize(image: Image.Image, size: int, interpolation=Image.BILINEAR,\r\n           keypoint2d: np.ndarray=None, intrinsic_matrix: np.ndarray=None):\r\n    width, height = image.size\r\n    assert width == height\r\n    factor = float(size) / float(width)\r\n    image = F.resize(image, size, interpolation)\r\n    keypoint2d = np.copy(keypoint2d)\r\n    keypoint2d *= factor\r\n    intrinsic_matrix = np.copy(intrinsic_matrix)\r\n    intrinsic_matrix[0][0] *= factor\r\n    intrinsic_matrix[0][2] *= factor\r\n    intrinsic_matrix[1][1] *= factor\r\n    intrinsic_matrix[1][2] *= factor\r\n    return image, keypoint2d, intrinsic_matrix\r\n\r\n\r\ndef crop(image: Image.Image, top, left, height, width, keypoint2d: np.ndarray):\r\n    image = F.crop(image, top, left, height, width)\r\n    keypoint2d = np.copy(keypoint2d)\r\n    keypoint2d[:, 0] -= left\r\n    keypoint2d[:, 1] -= top\r\n    return image, keypoint2d\r\n\r\n\r\ndef resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR,\r\n                 keypoint2d: np.ndarray=None, intrinsic_matrix: np.ndarray=None):\r\n    \"\"\"Crop the given PIL Image and resize it to desired size.\r\n\r\n    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.\r\n\r\n    Args:\r\n        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.\r\n        top (int): Vertical component of the top left corner of the crop box.\r\n        left (int): Horizontal component of the top left corner of the crop box.\r\n        height (int): Height of the crop box.\r\n        width (int): Width of the crop box.\r\n        size (sequence or int): Desired output size. Same semantics as ``resize``.\r\n        interpolation (int, optional): Desired interpolation. Default is\r\n            ``PIL.Image.BILINEAR``.\r\n    Returns:\r\n        PIL Image: Cropped image.\r\n    \"\"\"\r\n    assert isinstance(img, Image.Image), 'img should be PIL Image'\r\n    img, keypoint2d = crop(img, top, left, height, width, keypoint2d)\r\n    img, keypoint2d, intrinsic_matrix = resize(img, size, interpolation, keypoint2d, intrinsic_matrix)\r\n    return img, keypoint2d, intrinsic_matrix\r\n\r\n\r\ndef center_crop(image, output_size, keypoint2d: np.ndarray):\r\n    \"\"\"Crop the given PIL Image and resize it to desired size.\r\n\r\n    Args:\r\n        img (PIL Image): Image to be cropped. 
(0,0) denotes the top left corner of the image.\r\n        output_size (sequence or int): (height, width) of the crop box. If int,\r\n            it is used for both directions\r\n\r\n    Returns:\r\n        PIL Image: Cropped image.\r\n    \"\"\"\r\n    width, height = image.size\r\n    crop_height, crop_width = output_size\r\n    crop_top = int(round((height - crop_height) / 2.))\r\n    crop_left = int(round((width - crop_width) / 2.))\r\n    return crop(image, crop_top, crop_left, crop_height, crop_width, keypoint2d)\r\n\r\n\r\ndef hflip(image: Image.Image, keypoint2d: np.ndarray):\r\n    width, height = image.size\r\n    image = F.hflip(image)\r\n    keypoint2d = np.copy(keypoint2d)\r\n    keypoint2d[:, 0] = width - 1. - keypoint2d[:, 0]\r\n    return image, keypoint2d\r\n\r\n\r\ndef rotate(image: Image.Image, angle, keypoint2d: np.ndarray):\r\n    image = F.rotate(image, angle)\r\n\r\n    angle = -np.deg2rad(angle)\r\n    keypoint2d = np.copy(keypoint2d)\r\n    rotation_matrix = np.array([\r\n        [np.cos(angle), -np.sin(angle)],\r\n        [np.sin(angle), np.cos(angle)]\r\n    ])\r\n    width, height = image.size\r\n    keypoint2d[:, 0] = keypoint2d[:, 0] - width / 2\r\n    keypoint2d[:, 1] = keypoint2d[:, 1] - height / 2\r\n    keypoint2d = np.matmul(rotation_matrix, keypoint2d.T).T\r\n    keypoint2d[:, 0] = keypoint2d[:, 0] + width / 2\r\n    keypoint2d[:, 1] = keypoint2d[:, 1] + height / 2\r\n    return image, keypoint2d\r\n\r\n\r\ndef resize_pad(img, keypoint2d, size, interpolation=Image.BILINEAR):\r\n    w, h = img.size\r\n    if w < h:\r\n        oh = size\r\n        ow = int(size * w / h)\r\n        img = img.resize((ow, oh), interpolation)\r\n        pad_top = pad_bottom = 0\r\n        pad_left = math.floor((size - ow) / 2)\r\n        pad_right = math.ceil((size - ow) / 2)\r\n        keypoint2d = keypoint2d * oh / h\r\n        keypoint2d[:, 0] += (size - ow) / 2\r\n    else:\r\n        ow = size\r\n        oh = int(size * h / w)\r\n        img = img.resize((ow, oh), interpolation)\r\n        pad_top = math.floor((size - oh) / 2)\r\n        pad_bottom = math.ceil((size - oh) / 2)\r\n        pad_left = pad_right = 0\r\n        keypoint2d = keypoint2d * ow / w\r\n        keypoint2d[:, 1] += (size - oh) / 2\r\n        keypoint2d[:, 0] += (size - ow) / 2\r\n    img = np.asarray(img)\r\n\r\n    img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), 'constant', constant_values=0)\r\n    return Image.fromarray(img), keypoint2d\r\n\r\n\r\nclass Compose(object):\r\n    \"\"\"Composes several transforms together.\r\n\r\n    Args:\r\n        transforms (list of ``Transform`` objects): list of transforms to compose.\r\n    \"\"\"\r\n\r\n    def __init__(self, transforms):\r\n        self.transforms = transforms\r\n\r\n    def __call__(self, image, **kwargs):\r\n        for t in self.transforms:\r\n            image, kwargs = t(image, **kwargs)\r\n        return image, kwargs\r\n\r\n\r\nclass GaussianBlur(object):\r\n    def __init__(self, low=0, high=0.8):\r\n        self.low = low\r\n        self.high = high\r\n\r\n    def __call__(self, image: Image, **kwargs):\r\n        radius = np.random.uniform(low=self.low, high=self.high)\r\n        image = image.filter(ImageFilter.GaussianBlur(radius))\r\n        return image, kwargs\r\n\r\n\r\nclass Resize(object):\r\n    \"\"\"Resize the input PIL Image to the given size.\r\n    \"\"\"\r\n\r\n    def __init__(self, size, interpolation=Image.BILINEAR):\r\n        assert isinstance(size, int)\r\n        self.size = size\r\n  
      self.interpolation = interpolation\r\n\r\n    def __call__(self, image, keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, **kwargs):\r\n        image, keypoint2d, intrinsic_matrix = resize(image, self.size, self.interpolation, keypoint2d, intrinsic_matrix)\r\n        kwargs.update(keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\r\n        if 'depth' in kwargs:\r\n            kwargs['depth'] = F.resize(kwargs['depth'], self.size)\r\n        return image, kwargs\r\n\r\n\r\nclass ResizePad(object):\r\n    \"\"\"Resize the image so that its longer side matches the given size, then pad\r\n    the shorter side with zeros to obtain a square output of that size.\r\n    \"\"\"\r\n    def __init__(self, size, interpolation=Image.BILINEAR):\r\n        self.size = size\r\n        self.interpolation = interpolation\r\n\r\n    def __call__(self, img, keypoint2d, **kwargs):\r\n        image, keypoint2d = resize_pad(img, keypoint2d, self.size, self.interpolation)\r\n        kwargs.update(keypoint2d=keypoint2d)\r\n        return image, kwargs\r\n\r\n\r\nclass CenterCrop(object):\r\n    \"\"\"Crops the given PIL Image at the center.\r\n    \"\"\"\r\n\r\n    def __init__(self, size):\r\n        if isinstance(size, numbers.Number):\r\n            self.size = (int(size), int(size))\r\n        else:\r\n            self.size = size\r\n\r\n    def __call__(self, image, keypoint2d, **kwargs):\r\n        \"\"\"\r\n        Args:\r\n            image (PIL Image): Image to be cropped.\r\n\r\n        Returns:\r\n            PIL Image: Cropped image.\r\n        \"\"\"\r\n        image, keypoint2d = center_crop(image, self.size, keypoint2d)\r\n        kwargs.update(keypoint2d=keypoint2d)\r\n        if 'depth' in kwargs:\r\n            kwargs['depth'] = F.center_crop(kwargs['depth'], self.size)\r\n        return image, kwargs\r\n\r\n\r\nclass RandomRotation(object):\r\n    \"\"\"Rotate the image by angle.\r\n\r\n    Args:\r\n        degrees (sequence or float or int): Range of degrees to select from.\r\n            If degrees is a number instead of sequence like (min, max), the range of degrees\r\n            will be (-degrees, +degrees).\r\n    \"\"\"\r\n\r\n    def __init__(self, degrees):\r\n        if isinstance(degrees, numbers.Number):\r\n            if degrees < 0:\r\n                raise ValueError(\"If degrees is a single number, it must be positive.\")\r\n            self.degrees = (-degrees, degrees)\r\n        else:\r\n            if len(degrees) != 2:\r\n                raise ValueError(\"If degrees is a sequence, it must be of len 2.\")\r\n            self.degrees = degrees\r\n\r\n    @staticmethod\r\n    def get_params(degrees):\r\n        \"\"\"Get parameters for ``rotate`` for a random rotation.\r\n\r\n        Returns:\r\n            sequence: params to be passed to ``rotate`` for random rotation.\r\n        \"\"\"\r\n        angle = random.uniform(degrees[0], degrees[1])\r\n\r\n        return angle\r\n\r\n    def __call__(self, image, keypoint2d, **kwargs):\r\n        \"\"\"\r\n        Args:\r\n            image (PIL Image): Image to be rotated.\r\n\r\n        Returns:\r\n            PIL Image: Rotated image.\r\n        \"\"\"\r\n        angle = self.get_params(self.degrees)\r\n\r\n        image, keypoint2d = rotate(image, angle, keypoint2d)\r\n        kwargs.update(keypoint2d=keypoint2d)\r\n        if 'depth' in kwargs:\r\n            kwargs['depth'] = F.rotate(kwargs['depth'], angle)\r\n        return image, kwargs\r\n\r\n\r\nclass RandomResizedCrop(object):\r\n    \"\"\"Crop the given PIL Image to random size and aspect 
ratio.\r\n\r\n    A crop of random size (default: 0.6 to 1.3 of the original area) and a fixed\r\n    aspect ratio of 1 (i.e. a square crop) is made. This crop is finally resized\r\n    to the given size.\r\n\r\n    Args:\r\n        size: expected output size of each edge\r\n        scale: range of size of the origin size cropped\r\n        interpolation: Default: PIL.Image.BILINEAR\r\n    \"\"\"\r\n\r\n    def __init__(self, size, scale=(0.6, 1.3), interpolation=Image.BILINEAR):\r\n        self.size = size\r\n        if scale[0] > scale[1]:\r\n            warnings.warn(\"range should be of kind (min, max)\")\r\n\r\n        self.interpolation = interpolation\r\n        self.scale = scale\r\n\r\n    @staticmethod\r\n    def get_params(img, scale):\r\n        \"\"\"Get parameters for ``crop`` for a random sized crop.\r\n\r\n        Args:\r\n            img (PIL Image): Image to be cropped.\r\n            scale (tuple): range of size of the origin size cropped\r\n\r\n        Returns:\r\n            tuple: params (i, j, h, w) to be passed to ``crop`` for a random\r\n                sized crop.\r\n        \"\"\"\r\n        width, height = img.size\r\n        area = height * width\r\n\r\n        for _ in range(10):\r\n            target_area = random.uniform(*scale) * area\r\n            aspect_ratio = 1\r\n\r\n            w = int(round(math.sqrt(target_area * aspect_ratio)))\r\n            h = int(round(math.sqrt(target_area / aspect_ratio)))\r\n\r\n            if 0 < w <= width and 0 < h <= height:\r\n                i = random.randint(0, height - h)\r\n                j = random.randint(0, width - w)\r\n                return i, j, h, w\r\n\r\n        # Fallback to whole image\r\n        return 0, 0, height, width\r\n\r\n    def __call__(self, image, keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, **kwargs):\r\n        \"\"\"\r\n        Args:\r\n            image (PIL Image): Image to be cropped and resized.\r\n\r\n        Returns:\r\n            PIL Image: Randomly cropped and resized image.\r\n        \"\"\"\r\n        i, j, h, w = self.get_params(image, self.scale)\r\n        image, keypoint2d, intrinsic_matrix = resized_crop(image, i, j, h, w, self.size, self.interpolation, keypoint2d, intrinsic_matrix)\r\n        kwargs.update(keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\r\n        if 'depth' in kwargs:\r\n            kwargs['depth'] = F.resized_crop(kwargs['depth'], i, j, h, w, self.size, self.interpolation)\r\n        return image, kwargs\r\n\r\n\r\nclass RandomApply(T.RandomTransforms):\r\n    \"\"\"Apply randomly a list of transformations with a given probability.\r\n\r\n    Args:\r\n        transforms (list or tuple or torch.nn.Module): list of transformations\r\n        p (float): probability\r\n    \"\"\"\r\n\r\n    def __init__(self, transforms, p=0.5):\r\n        super(RandomApply, self).__init__(transforms)\r\n        self.p = p\r\n\r\n    def __call__(self, image, **kwargs):\r\n        if self.p < random.random():\r\n            return image, kwargs\r\n        for t in self.transforms:\r\n            image, kwargs = t(image, **kwargs)\r\n        return image, kwargs\r\n
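\r\n\r\n# Usage sketch (illustrative): every transform in this module maps\r\n# (image, **state) -> (image, state), so Compose threads the keypoint\r\n# annotations (and the camera intrinsics) through the whole pipeline.\r\n# The names below are assumptions for the sketch: keypoint2d is an (N, 2)\r\n# array of pixel coordinates and intrinsic_matrix a 3x3 pinhole matrix.\r\n# >>> transform = Compose([\r\n# ...     Resize(256),\r\n# ...     RandomRotation(10),\r\n# ... ])\r\n# >>> image, data = transform(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)\r\n# >>> keypoint2d = data['keypoint2d']\r\n"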
  },
  {
    "path": "tllib/vision/transforms/segmentation.py",
    "content": "\"\"\"\n@author: Junguang Jiang\n@contact: JiangJunguang1123@outlook.com\n\"\"\"\nfrom PIL import Image\nimport random\nimport math\nfrom typing import ClassVar, Sequence, List, Tuple\nfrom torch import Tensor\nimport torch\nimport torchvision.transforms.functional as F\nimport torchvision.transforms.transforms as T\nimport torch.nn as nn\nfrom . import MultipleApply as MultipleApplyBase, NormalizeAndTranspose as NormalizeAndTransposeBase\n\n\ndef wrapper(transform: ClassVar):\n    \"\"\" Wrap a transform for classification to a transform for segmentation.\n    Note that the segmentation label will keep the same before and after wrapper.\n\n    Args:\n        transform (class, callable): transform for classification\n\n    Returns:\n        transform for segmentation\n    \"\"\"\n    class WrapperTransform(transform):\n        def __call__(self, image, label):\n            image = super().__call__(image)\n            return image, label\n    return WrapperTransform\n\n\nColorJitter = wrapper(T.ColorJitter)\nNormalize = wrapper(T.Normalize)\nToTensor = wrapper(T.ToTensor)\nToPILImage = wrapper(T.ToPILImage)\nMultipleApply = wrapper(MultipleApplyBase)\nNormalizeAndTranspose = wrapper(NormalizeAndTransposeBase)\n\n\nclass Compose:\n    \"\"\"Composes several transforms together.\n\n    Args:\n        transforms (list): list of transforms to compose.\n\n    Example:\n        >>> Compose([\n        >>>     Resize((512, 512)),\n        >>>     RandomHorizontalFlip()\n        >>> ])\n    \"\"\"\n    def __init__(self, transforms):\n        super(Compose, self).__init__()\n        self.transforms = transforms\n\n    def __call__(self, image, target):\n        for t in self.transforms:\n            image, target = t(image, target)\n        return image, target\n\n\nclass Resize(nn.Module):\n    \"\"\"Resize the input image and the corresponding label to the given size.\n    The image should be a PIL Image.\n\n    Args:\n        image_size (sequence): The requested image size in pixels, as a 2-tuple:\n          (width, height).\n        label_size (sequence, optional): The requested segmentation label size in pixels, as a 2-tuple:\n          (width, height). The same as image_size if None. 
Default: None.\n    \"\"\"\n\n    def __init__(self, image_size, label_size=None):\n        super(Resize, self).__init__()\n        self.image_size = image_size\n        if label_size is None:\n            self.label_size = image_size\n        else:\n            self.label_size = label_size\n\n    def forward(self, image, label):\n        \"\"\"\n        Args:\n            image: (PIL Image): Image to be scaled.\n            label: (PIL Image): Segmentation label to be scaled.\n\n        Returns:\n            Rescaled image, rescaled segmentation label\n        \"\"\"\n        # resize\n        image = image.resize(self.image_size, Image.BICUBIC)\n        label = label.resize(self.label_size, Image.NEAREST)\n        return image, label\n\n\nclass RandomCrop(nn.Module):\n    \"\"\"Crop the given image at a random location.\n    The image can be a PIL Image.\n\n    Args:\n        size (sequence): Desired output size of the crop.\n    \"\"\"\n    def __init__(self, size):\n        super(RandomCrop, self).__init__()\n        self.size = size\n\n    def forward(self, image, label):\n        \"\"\"\n        Args:\n            image: (PIL Image): Image to be cropped.\n            label: (PIL Image): Segmentation label to be cropped.\n\n        Returns:\n            Cropped image, cropped segmentation label.\n        \"\"\"\n        # random crop; randint is inclusive on both ends, so the maximum offset\n        # keeps the crop inside the image and offset 0 is allowed when the image\n        # exactly matches the crop size\n        left = image.size[0] - self.size[0]\n        upper = image.size[1] - self.size[1]\n\n        left = random.randint(0, max(left, 0))\n        upper = random.randint(0, max(upper, 0))\n        right = left + self.size[0]\n        lower = upper + self.size[1]\n\n        image = image.crop((left, upper, right, lower))\n        label = label.crop((left, upper, right, lower))\n        return image, label\n\n\nclass RandomHorizontalFlip(nn.Module):\n    \"\"\"Horizontally flip the given PIL Image randomly with a given probability.\n\n    Args:\n        p (float): probability of the image being flipped. Default value is 0.5\n    \"\"\"\n\n    def __init__(self, p=0.5):\n        super(RandomHorizontalFlip, self).__init__()\n        self.p = p\n\n    def forward(self, image, label):\n        \"\"\"\n        Args:\n            image: (PIL Image): Image to be flipped.\n            label: (PIL Image): Segmentation label to be flipped.\n\n        Returns:\n            Randomly flipped image, randomly flipped segmentation label.\n        \"\"\"\n        if random.random() < self.p:\n            return F.hflip(image), F.hflip(label)\n        return image, label\n\n\nclass RandomResizedCrop(T.RandomResizedCrop):\n    \"\"\"Crop the given image to random size and aspect ratio.\n    The image can be a PIL Image.\n\n    A crop of random size (default: 0.5 to 1.0) of the original size and a random\n    aspect ratio (default: 3/4 to 4/3) of the original aspect ratio is made. This crop\n    is finally resized to the given size.\n\n    Args:\n        size (int or sequence): expected output size of each edge. If size is an\n          int instead of sequence like (h, w), a square output size ``(size, size)`` is\n          made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).\n        scale (tuple of float): range of size of the origin size cropped\n        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.\n        interpolation: Default: PIL.Image.BICUBIC\n    \"\"\"\n\n    def __init__(self, size, scale=(0.5, 1.0), ratio=(3. / 4., 4. 
/ 3.), interpolation=Image.BICUBIC):\n        super(RandomResizedCrop, self).__init__(size, scale, ratio, interpolation)\n\n    @staticmethod\n    def get_params(\n            img: Tensor, scale: List[float], ratio: List[float]\n    ) -> Tuple[int, int, int, int]:\n        \"\"\"Get parameters for ``crop`` for a random sized crop.\n\n        Args:\n            img (PIL Image): Input image.\n            scale (list): range of scale of the origin size cropped\n            ratio (list): range of aspect ratio of the origin aspect ratio cropped\n\n        Returns:\n            params (i, j, h, w) to be passed to ``crop`` for a random sized crop.\n        \"\"\"\n        width, height = F._get_image_size(img)\n        area = height * width\n\n        for _ in range(10):\n            target_area = area * random.uniform(scale[0], scale[1])\n            log_ratio = torch.log(torch.tensor(ratio))\n            aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1]))\n\n            w = int(round(math.sqrt(target_area * aspect_ratio)))\n            h = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if 0 < w <= width and 0 < h <= height:\n                i = random.randint(0, height - h)\n                j = random.randint(0, width - w)\n                return i, j, h, w\n\n        # Fallback to central crop\n        in_ratio = float(width) / float(height)\n        if in_ratio < min(ratio):\n            w = width\n            h = int(round(w / min(ratio)))\n        elif in_ratio > max(ratio):\n            h = height\n            w = int(round(h * max(ratio)))\n        else:  # whole image\n            w = width\n            h = height\n        i = (height - h) // 2\n        j = (width - w) // 2\n        return i, j, h, w\n\n    def forward(self, image, label):\n        \"\"\"\n        Args:\n            image: (PIL Image): Image to be cropped and resized.\n            label: (PIL Image): Segmentation label to be cropped and resized.\n\n        Returns:\n            Randomly cropped and resized image, randomly cropped and resized segmentation label.\n        \"\"\"\n        top, left, height, width = self.get_params(image, self.scale, self.ratio)\n        image = image.crop((left, top, left + width, top + height))\n        image = image.resize(self.size, self.interpolation)\n        label = label.crop((left, top, left + width, top + height))\n        label = label.resize(self.size, Image.NEAREST)\n        return image, label\n\n\nclass RandomChoice(T.RandomTransforms):\n    \"\"\"Apply single transformation randomly picked from a list.\n    \"\"\"\n    def __call__(self, image, label):\n        t = random.choice(self.transforms)\n        return t(image, label)\n\n\nclass RandomApply(T.RandomTransforms):\n    \"\"\"Apply randomly a list of transformations with a given probability.\n\n    Args:\n        transforms (list or tuple or torch.nn.Module): list of transformations\n        p (float): probability\n    \"\"\"\n\n    def __init__(self, transforms, p=0.5):\n        super(RandomApply, self).__init__(transforms)\n        self.p = p\n\n    def __call__(self, image, label):\n        # always return both the image and its segmentation label so that\n        # RandomApply composes with the other paired transforms\n        if self.p < random.random():\n            return image, label\n        for t in self.transforms:\n            image, label = t(image, label)\n        return image, label\n
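\n\n# Usage sketch (illustrative): every transform here maps (image, label) to\n# (image, label), so Compose keeps the segmentation label aligned with the\n# image through random augmentations.\n# >>> train_transform = Compose([\n# ...     RandomResizedCrop((512, 512)),\n# ...     RandomHorizontalFlip(),\n# ... ])\n# >>> image, label = train_transform(image, label)\n"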
  }
]