[
  {
    "path": "Dassl.ProGrad.pytorch/.flake8",
    "content": "[flake8]\nignore =\n    # At least two spaces before inline comment\n    E261,\n    # Line lengths are recommended to be no greater than 79 characters\n    E501,\n    # Missing whitespace around arithmetic operator \n    E226,\n    # Blank line contains whitespace\n    W293,\n    # Do not use bare 'except'\n    E722,\n    # Line break after binary operator\n    W504,\n    # Too many leading '#' for block comment\n    E266,\n    # line break before binary operator\n    W503,\n    # continuation line over-indented for hanging indent\n    E126\nmax-line-length = 79\nexclude = __init__.py, build"
  },
  {
    "path": "Dassl.ProGrad.pytorch/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# OS X\n.DS_Store\n.Spotlight-V100\n.Trashes\n._*\n\n# This project\noutput/\ndebug.sh\ndebug.py\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/.isort.cfg",
    "content": "[isort]\nline_length=79\nmulti_line_output=6\nlength_sort=true\nknown_standard_library=numpy,setuptools\nknown_myself=dassl\nknown_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown\nno_lines_before=STDLIB,THIRDPARTY\nsections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER\ndefault_section=FIRSTPARTY"
  },
  {
    "path": "Dassl.ProGrad.pytorch/.style.yapf",
    "content": "[style]\nBASED_ON_STYLE = pep8\nBLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true\nSPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true\nDEDENT_CLOSING_BRACKETS = true\nSPACES_BEFORE_COMMENT = 2\nARITHMETIC_PRECEDENCE_INDICATION = true"
  },
  {
    "path": "Dassl.ProGrad.pytorch/DATASETS.md",
    "content": "# How to Install Datasets\n\n`$DATA` denotes the location where datasets are installed, e.g.\n\n```\n$DATA/\n|–– office31/\n|–– office_home/\n|–– visda17/\n```\n\n[Domain Adaptation](#domain-adaptation)\n- [Office-31](#office-31)\n- [Office-Home](#office-home)\n- [VisDA17](#visda17)\n- [CIFAR10-STL10](#cifar10-stl10)\n- [Digit-5](#digit-5)\n- [DomainNet](#domainnet)\n- [miniDomainNet](#miniDomainNet)\n\n[Domain Generalization](#domain-generalization)\n- [PACS](#pacs)\n- [VLCS](#vlcs)\n- [Office-Home-DG](#office-home-dg)\n- [Digits-DG](#digits-dg)\n- [Digit-Single](#digit-single)\n- [CIFAR-10-C](#cifar-10-c)\n- [CIFAR-100-C](#cifar-100-c)\n\n[Semi-Supervised Learning](#semi-supervised-learning)\n- [CIFAR10/100 and SVHN](#cifar10100-and-svhn)\n- [STL10](#stl10)\n\n## Domain Adaptation\n\n### Office-31\n\nDownload link: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/#datasets_code.\n\nFile structure:\n\n```\noffice31/\n|–– amazon/\n|   |–– back_pack/\n|   |–– bike/\n|   |–– ...\n|–– dslr/\n|   |–– back_pack/\n|   |–– bike/\n|   |–– ...\n|–– webcam/\n|   |–– back_pack/\n|   |–– bike/\n|   |–– ...\n```\n\nNote that within each domain folder you need to move all class folders out of the `images/` folder and then delete the `images/` folder.\n\n### Office-Home\n\nDownload link: http://hemanthdv.org/OfficeHome-Dataset/.\n\nFile structure:\n\n```\noffice_home/\n|–– art/\n|–– clipart/\n|–– product/\n|–– real_world/\n```\n\n### VisDA17\n\nDownload link: http://ai.bu.edu/visda-2017/.\n\nThe dataset can also be downloaded using our script at `datasets/da/visda17.sh`. 
Run the following command in your terminal under `Dassl.pytorch/datasets/da`,\n\n```bash\nsh visda17.sh $DATA\n```\n\nOnce the download is finished, the file structure will look like\n\n```\nvisda17/\n|–– train/\n|–– test/\n|–– validation/\n```\n\n### CIFAR10-STL10\n\nRun the following command in your terminal under `Dassl.pytorch/datasets/da`,\n\n```bash\npython cifar_stl.py $DATA/cifar_stl\n```\n\nThis will create a folder named `cifar_stl` under `$DATA`. The file structure will look like\n\n```\ncifar_stl/\n|–– cifar/\n|   |–– train/\n|   |–– test/\n|–– stl/\n|   |–– train/\n|   |–– test/\n```\n\nNote that only the 9 classes shared by both datasets are kept.\n\n### Digit-5\n\nCreate a folder `$DATA/digit5` and download the dataset from [here](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download) into this folder. This should give you\n\n```\ndigit5/\n|–– Digit-Five/\n```\n\nThen, run the following command in your terminal under `Dassl.pytorch/datasets/da`,\n\n```bash\npython digit5.py $DATA/digit5\n```\n\nThis will extract the data and organize the file structure as\n\n```\ndigit5/\n|–– Digit-Five/\n|–– mnist/\n|–– mnist_m/\n|–– usps/\n|–– svhn/\n|–– syn/\n```\n\n### DomainNet\n\nDownload link: http://ai.bu.edu/M3SDA/ (please download the cleaned version of the split files).\n\nFile structure:\n\n```\ndomainnet/\n|–– clipart/\n|–– infograph/\n|–– painting/\n|–– quickdraw/\n|–– real/\n|–– sketch/\n|–– splits/\n|   |–– clipart_train.txt\n|   |–– clipart_test.txt\n|   |–– ...\n```\n\n### miniDomainNet\n\nYou need to download the DomainNet dataset first. The split files of miniDomainNet can be downloaded from this [google drive](https://drive.google.com/open?id=15rrLDCrzyi6ZY-1vJar3u7plgLe4COL7) link. 
After the zip file is extracted, you should have the folder `$DATA/domainnet/splits_mini/`.\n\n## Domain Generalization\n\n### PACS\n\nDownload link: [google drive](https://drive.google.com/open?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE).\n\nFile structure:\n\n```\npacs/\n|–– images/\n|–– splits/\n```\n\nYou do not necessarily have to download this dataset manually. Once you run `tools/train.py`, the code will detect whether the dataset exists and automatically download it to `$DATA` if it is missing. This also applies to VLCS, Office-Home-DG, and Digits-DG.\n\n### VLCS\n\nDownload link: [google drive](https://drive.google.com/file/d/1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd/view?usp=sharing) (credit to https://github.com/fmcarlucci/JigenDG#vlcs)\n\nFile structure:\n\n```\nVLCS/\n|–– CALTECH/\n|–– LABELME/\n|–– PASCAL/\n|–– SUN/\n```\n\n### Office-Home-DG\n\nDownload link: [google drive](https://drive.google.com/open?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa).\n\nFile structure:\n\n```\noffice_home_dg/\n|–– art/\n|–– clipart/\n|–– product/\n|–– real_world/\n```\n\n### Digits-DG\n\nDownload link: [google drive](https://drive.google.com/open?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7).\n\nFile structure:\n\n```\ndigits_dg/\n|–– mnist/\n|–– mnist_m/\n|–– svhn/\n|–– syn/\n```\n\n### Digit-Single\nFollow the steps for [Digit-5](#digit-5) to organize the dataset.\n\n### CIFAR-10-C\n\nFirst download the CIFAR-10-C dataset from https://zenodo.org/record/2535967#.YFxHEWQzb0o to, e.g., `$DATA`, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal\n```bash\npython cifar_c.py $DATA/CIFAR-10-C\n```\nwhere the first argument denotes the path to the (uncompressed) CIFAR-10-C dataset.\n\nThe script will extract images from the `.npy` files and save them to `cifar10_c/` created under `$DATA`. 
The file structure will look like\n```\ncifar10_c/\n|–– brightness/\n|   |–– 1/ # 5 intensity levels in total\n|   |–– 2/\n|   |–– 3/\n|   |–– 4/\n|   |–– 5/\n|–– ... # 19 corruption types in total\n```\n\nNote that `cifar10_c/` only contains the test images. The training images are the normal CIFAR-10 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-10 dataset.\n\n### CIFAR-100-C\n\nFirst download the CIFAR-100-C dataset from https://zenodo.org/record/3555552#.YFxpQmQzb0o to, e.g., $DATA, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal\n```bash\npython cifar_c.py $DATA/CIFAR-100-C\n```\nwhere the first argument denotes the path to the (uncompressed) CIFAR-100-C dataset.\n\nThe script will extract images from the `.npy` files and save them to `cifar100_c/` created under $DATA. The file structure will look like\n```\ncifar100_c/\n|–– brightness/\n|   |–– 1/ # 5 intensity levels in total\n|   |–– 2/\n|   |–– 3/\n|   |–– 4/\n|   |–– 5/\n|–– ... # 19 corruption types in total\n```\n\nNote that `cifar100_c/` only contains the test images. The training images are the normal CIFAR-100 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-100 dataset.\n\n## Semi-Supervised Learning\n\n### CIFAR10/100 and SVHN\n\nRun the following command in your terminal under `Dassl.pytorch/datasets/ssl`,\n\n```bash\npython cifar10_cifar100_svhn.py $DATA\n```\n\nThis will create three folders under `$DATA`, i.e.\n\n```\ncifar10/\n|–– train/\n|–– test/\ncifar100/\n|–– train/\n|–– test/\nsvhn/\n|–– train/\n|–– test/\n```\n\n### STL10\n\nRun the following command in your terminal under `Dassl.pytorch/datasets/ssl`,\n\n```bash\npython stl10.py $DATA/stl10\n```\n\nThis will create a folder named `stl10` under `$DATA` and extract the data into three folders, i.e. `train`, `test` and `unlabeled`. 
Then, download the \"Binary files\" from http://ai.stanford.edu/~acoates/stl10/ and extract them under `stl10`.\n\nThe file structure will look like\n\n```\nstl10/\n|–– train/\n|–– test/\n|–– unlabeled/\n|–– stl10_binary/\n```"
  },
  {
    "path": "Dassl.ProGrad.pytorch/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 Kaiyang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/README.md",
    "content": "# Dassl\n\n## Introduction\n\nDassl is a [PyTorch](https://pytorch.org) toolbox initially developed for our project [Domain Adaptive Ensemble Learning (DAEL)](https://arxiv.org/abs/2003.07325) to support research in domain adaptation and generalization---since in DAEL we study how to unify these two problems in a single learning framework. Given that domain adaptation is closely related to semi-supervised learning---both study how to exploit unlabeled data---we also incorporate components that support research for the latter.\n\nWhy the name \"Dassl\"? Dassl combines the initials of domain adaptation (DA) and semi-supervised learning (SSL), which sounds natural and informative.\n\nDassl has a modular design and unified interfaces, allowing fast prototyping and experimentation of new DA/DG/SSL methods. With Dassl, a new method can be implemented with only a few lines of code. Don't believe it? Take a look at the [engine](https://github.com/KaiyangZhou/Dassl.pytorch/tree/master/dassl/engine) folder, which contains the implementations of many existing methods (then you will come back and star this repo). :-)\n\nBasically, Dassl is perfect for doing research in the following areas:\n- Domain adaptation\n- Domain generalization\n- Semi-supervised learning\n\nBUT, thanks to the neat design, Dassl can also be used as a codebase to develop any deep learning project, like [this](https://github.com/KaiyangZhou/CoOp). :-)\n\nA drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU training (Dassl uses `DataParallel` to wrap a model, which is less efficient than `DistributedDataParallel`).\n\nWe don't provide detailed documentation for Dassl, unlike another [project](https://kaiyangzhou.github.io/deep-person-reid/) of ours. This is because Dassl is developed for research purposes, and as researchers, we think it's important to be able to read source code and we highly encourage you to do so---definitely not because we are lazy. 
:-)\n\n## What's new\n- Mar 2022: A new domain generalization method, [EFDM](https://arxiv.org/abs/2203.07740), developed by [Yabin Zhang (PolyU)](https://ybzh.github.io/) and to appear at CVPR'22, has been added to this repo. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/pull/36) for more details.\n- Feb 2022: In case you don't know, a class in the painting domain of DomainNet (the official splits) only has test images (no training images), which could affect performance. See section 4.a in our [paper](https://arxiv.org/abs/2003.07325) for more details.\n- Oct 2021: `v0.5.0`: **Important changes** made to `transforms.py`. 1) `center_crop` becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio). 2) For training, `Resize(cfg.INPUT.SIZE)` is deactivated when `random_crop` or `random_resized_crop` is used. These changes won't make any difference to the training transforms used in existing config files, nor to the testing transforms unless the raw images are not square (the only difference is that now the image aspect ratio is respected).\n- Oct 2021: `v0.4.3`: Copies the attributes in `self.dm` (data manager) to `SimpleTrainer` and makes `self.dm` optional, which means from now on, you can build data loaders from any source you like rather than being forced to use `DataManager`.\n- Sep 2021: `v0.4.2`: An important update is to set `drop_last=is_train and len(data_source)>=batch_size` when constructing a data loader to avoid a 0-length data loader.\n\n<details>\n    <summary>More</summary>\n\n- Aug 2021: `v0.4.0`: The most noteworthy update is adding the learning rate warmup scheduler. 
The implementation is detailed [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/optim/lr_scheduler.py#L10) and the config variables are specified [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L171).\n- Jul 2021: `v0.3.4`: Adds a new function `generate_fewshot_dataset()` to the base dataset class, which allows for the generation of a few-shot learning setting. One can customize a few-shot dataset by specifying `_C.DATASET.NUM_SHOTS` and passing it to `generate_fewshot_dataset()`.\n- Jul 2021: `v0.3.2`: Adds `_C.INPUT.INTERPOLATION` (default: `bilinear`). Available interpolation modes are `bilinear`, `nearest`, and `bicubic`.\n- Jul 2021 `v0.3.1`: Now you can use `*.register(force=True)` to replace previously registered modules.\n- Jul 2021 `v0.3.0`: Allows deploying the model with the best validation performance for the final test (for the purpose of model selection). Specifically, a new config variable named `_C.TEST.FINAL_MODEL` is introduced, which takes either `\"last_step\"` (default) or `\"best_val\"`. When set to `\"best_val\"`, the model will be evaluated on the `val` set after each epoch and the one with the best validation performance will be saved and used for the final test (see this [code](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/engine/trainer.py#L412)).\n- Jul 2021 `v0.2.7`: Adds attribute `classnames` to the base dataset class. Now you can get a list of class names ordered by numeric labels by calling `trainer.dm.dataset.classnames`.\n- Jun 2021 `v0.2.6`: Merges `MixStyle2` into `MixStyle`. A new variable `self.mix` is used to switch between random mixing and cross-domain mixing. 
Please see [this](https://github.com/KaiyangZhou/Dassl.pytorch/issues/23) for more details on the new features.\n- Jun 2021 `v0.2.5`: Fixes a [bug](https://github.com/KaiyangZhou/Dassl.pytorch/commit/29881c7faee7405f80f5f674de4bbbf80d5dc77a) in the calculation of per-class recognition accuracy.\n- Jun 2021 `v0.2.4`: Adds `extend_cfg(cfg)` to `train.py`. This function is particularly useful when you build your own methods on top of Dassl.pytorch and need to define some custom variables. Please see the repository [mixstyle-release](https://github.com/KaiyangZhou/mixstyle-release) or [ssdg-benchmark](https://github.com/KaiyangZhou/ssdg-benchmark) for examples.\n- Jun 2021 New benchmarks for semi-supervised domain generalization at https://github.com/KaiyangZhou/ssdg-benchmark.\n- Apr 2021 Did you know you can use `tools/parse_test_res.py` to read the log files and automatically calculate and print out the results, including mean and standard deviation? Check the instructions in `tools/parse_test_res.py` for more details.\n- Apr 2021 `v0.2.3`: A [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) layer can now be deactivated or activated by using `model.apply(deactivate_mixstyle)` or `model.apply(activate_mixstyle)` without modifying the source code. See [dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py) for the details.\n- Apr 2021 `v0.2.2`: Adds `RandomClassSampler`, which forms a minibatch by sampling a certain number of images from each of a certain number of classes (the code is modified from [Torchreid](https://github.com/KaiyangZhou/deep-person-reid)).\n- Apr 2021 `v0.2.1`: Slightly adjusts the ordering in `setup_cfg()` (see `tools/train.py`).\n- Apr 2021 `v0.2.0`: Adds `_C.DATASET.ALL_AS_UNLABELED` (for the SSL setting) to the config variable list. 
When this variable is set to `True`, all labeled data will be included in the unlabeled data set.\n- Apr 2021 `v0.1.9`: Adds [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf) to the benchmark datasets (see `dassl/data/datasets/dg/vlcs.py`).\n- Mar 2021 `v0.1.8`: Allows `optim` and `sched` to be `None` in `register_model()`.\n- Mar 2021 `v0.1.7`: Adds [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) models to [dassl/modeling/backbone/resnet.py](dassl/modeling/backbone/resnet.py). The training configs in `configs/trainers/dg/vanilla` can be used to train MixStyle models.\n- Mar 2021 `v0.1.6`: Adds [CIFAR-10/100-C](https://arxiv.org/abs/1807.01697) to the benchmark datasets for evaluating a model's robustness to image corruptions.\n- Mar 2021 We have just released a survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes the ten-year development in this topic with coverage on the history, related problems, datasets, methodologies, potential directions, and so on.\n- Jan 2021 Our recent work, [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) (mixing instance-level feature statistics of samples of different domains for improving domain generalization), is accepted to ICLR'21. The code is available at https://github.com/KaiyangZhou/mixstyle-release where the cross-domain image classification part is based on Dassl.pytorch.\n- May 2020 `v0.1.3`: Adds the `Digit-Single` dataset for benchmarking single-source DG methods. The corresponding CNN model is [dassl/modeling/backbone/cnn_digitsingle.py](dassl/modeling/backbone/cnn_digitsingle.py) and the dataset config file is [configs/datasets/dg/digit_single.yaml](configs/datasets/dg/digit_single.yaml). See [Volpi et al. NIPS'18](https://arxiv.org/abs/1805.12018) for how to do evaluation.\n- May 2020 `v0.1.2`: 1) Adds [EfficientNet](https://arxiv.org/abs/1905.11946) models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). 
To use EfficientNet, set `MODEL.BACKBONE.NAME` to `efficientnet_b{N}` where `N={0, ..., 7}`. 2) `dassl/modeling/models` is renamed to `dassl/modeling/network` (`build_model()` to `build_network()` and `MODEL_REGISTRY` to `NETWORK_REGISTRY`).\n\n</details>\n\n## Overview\n\nDassl has implemented the following methods:\n\n- Single-source domain adaptation\n    - [Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)](https://arxiv.org/abs/1904.06487) [[dassl/engine/da/mme.py](dassl/engine/da/mme.py)]\n    - [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)](https://arxiv.org/abs/1712.02560) [[dassl/engine/da/mcd.py](dassl/engine/da/mcd.py)]\n    - [Self-ensembling for visual domain adaptation (ICLR'18)](https://arxiv.org/abs/1706.05208) [[dassl/engine/da/self_ensembling.py](dassl/engine/da/self_ensembling.py)]\n    - [Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17)](https://arxiv.org/abs/1603.04779) [[dassl/engine/da/adabn.py](dassl/engine/da/adabn.py)]\n    - [Adversarial Discriminative Domain Adaptation (CVPR'17)](https://arxiv.org/abs/1702.05464) [[dassl/engine/da/adda.py](dassl/engine/da/adda.py)]\n    - [Domain-Adversarial Training of Neural Networks (JMLR'16)](https://arxiv.org/abs/1505.07818) [[dassl/engine/da/dann.py](dassl/engine/da/dann.py)]\n\n- Multi-source domain adaptation\n    - [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325) [[dassl/engine/da/dael.py](dassl/engine/da/dael.py)]\n    - [Moment Matching for Multi-Source Domain Adaptation (ICCV'19)](https://arxiv.org/abs/1812.01754) [[dassl/engine/da/m3sda.py](dassl/engine/da/m3sda.py)]\n\n- Domain generalization\n    - [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization (CVPR'22)](https://arxiv.org/abs/2203.07740) [[dassl/modeling/ops/efdmix.py](dassl/modeling/ops/efdmix.py)]\n    - [Domain Generalization with MixStyle 
(ICLR'21)](https://openreview.net/forum?id=6xHJ37MVxxp) [[dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py)]\n    - [Deep Domain-Adversarial Image Generation for Domain Generalisation (AAAI'20)](https://arxiv.org/abs/2003.06054) [[dassl/engine/dg/ddaig.py](dassl/engine/dg/ddaig.py)]\n    - [Generalizing Across Domains via Cross-Gradient Training (ICLR'18)](https://arxiv.org/abs/1804.10745) [[dassl/engine/dg/crossgrad.py](dassl/engine/dg/crossgrad.py)]\n\n- Semi-supervised learning\n    - [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence](https://arxiv.org/abs/2001.07685) [[dassl/engine/ssl/fixmatch.py](dassl/engine/ssl/fixmatch.py)]\n    - [MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19)](https://arxiv.org/abs/1905.02249) [[dassl/engine/ssl/mixmatch.py](dassl/engine/ssl/mixmatch.py)]\n    - [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17)](https://arxiv.org/abs/1703.01780) [[dassl/engine/ssl/mean_teacher.py](dassl/engine/ssl/mean_teacher.py)]\n    - [Semi-supervised Learning by Entropy Minimization (NeurIPS'04)](http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf) [[dassl/engine/ssl/entmin.py](dassl/engine/ssl/entmin.py)]\n\n*Feel free to make a [PR](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) to add your methods here to make it easier for others to benchmark!*\n\nDassl supports the following datasets:\n\n- Domain adaptation\n    - [Office-31](https://scalable.mpi-inf.mpg.de/files/2013/04/saenko_eccv_2010.pdf)\n    - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)\n    - [VisDA17](http://ai.bu.edu/visda-2017/)\n    - [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)-[STL10](https://cs.stanford.edu/~acoates/stl10/)\n    - 
    [Digit-5](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download)\n    - [DomainNet](http://ai.bu.edu/M3SDA/)\n    - [miniDomainNet](https://arxiv.org/abs/2003.07325)\n\n- Domain generalization\n    - [PACS](https://arxiv.org/abs/1710.03077)\n    - [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf)\n    - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)\n    - [Digits-DG](https://arxiv.org/abs/2003.06054)\n    - [Digit-Single](https://arxiv.org/abs/1805.12018)\n    - [CIFAR-10-C](https://arxiv.org/abs/1807.01697)\n    - [CIFAR-100-C](https://arxiv.org/abs/1807.01697)\n\n- Semi-supervised learning\n    - [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html)\n    - [SVHN](http://ufldl.stanford.edu/housenumbers/)\n    - [STL10](https://cs.stanford.edu/~acoates/stl10/)\n\n## Get started\n\n### Installation\n\nMake sure [conda](https://www.anaconda.com/distribution/) is installed properly.\n\n```bash\n# Clone this repo\ngit clone https://github.com/KaiyangZhou/Dassl.pytorch.git\ncd Dassl.pytorch/\n\n# Create a conda environment\nconda create -n dassl python=3.7\n\n# Activate the environment\nconda activate dassl\n\n# Install dependencies\npip install -r requirements.txt\n\n# Install torch (version >= 1.7.1) and torchvision\nconda install pytorch torchvision cudatoolkit=10.1 -c pytorch\n\n# Install this library (no need to re-build if the source code is modified)\npython setup.py develop\n```\n\nFollow the instructions in [DATASETS.md](./DATASETS.md) to preprocess the datasets.\n\n### Training\n\nThe main interface is implemented in `tools/train.py`, which basically does the following:\n\n1. initialize the config with `cfg = setup_cfg(args)` where `args` contains the command-line input (see `tools/train.py` for the list of input arguments);\n2. instantiate a `trainer` with `build_trainer(cfg)` which loads the dataset and builds a deep neural network model;\n3. 
call `trainer.train()` for training and evaluating the model.\n\nBelow we provide an example for training a source-only baseline on the popular domain adaptation dataset Office-31:\n\n```bash\nCUDA_VISIBLE_DEVICES=0 python tools/train.py \\\n--root $DATA \\\n--trainer SourceOnly \\\n--source-domains amazon \\\n--target-domains webcam \\\n--dataset-config-file configs/datasets/da/office31.yaml \\\n--config-file configs/trainers/da/source_only/office31.yaml \\\n--output-dir output/source_only_office31\n```\n\n`$DATA` denotes the location where datasets are installed. `--dataset-config-file` loads the common setting for the dataset (Office-31 in this case) such as image size and model architecture. `--config-file` loads the algorithm-specific setting such as hyper-parameters and optimization parameters.\n\nTo use multiple sources (i.e., multi-source domain adaptation), one just needs to add more domains to `--source-domains`. For instance, to train a source-only baseline on miniDomainNet, one can do\n\n```bash\nCUDA_VISIBLE_DEVICES=0 python tools/train.py \\\n--root $DATA \\\n--trainer SourceOnly \\\n--source-domains clipart painting real \\\n--target-domains sketch \\\n--dataset-config-file configs/datasets/da/mini_domainnet.yaml \\\n--config-file configs/trainers/da/source_only/mini_domainnet.yaml \\\n--output-dir output/source_only_minidn\n```\n\nAfter the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.\n\nTo print out the results saved in the log files (so you do not need to go through every log file and calculate the mean/std yourself), you can use `tools/parse_test_res.py`. Instructions can be found in the code.\n\nFor other trainers such as `MCD`, you can set `--trainer MCD` while keeping the config file unchanged, i.e. using the same training parameters as `SourceOnly` (in the simplest case). 
To modify the hyper-parameters in MCD, such as `N_STEP_F` (number of steps to update the feature extractor), you can append `TRAINER.MCD.N_STEP_F 4` to the existing input arguments (otherwise the default value will be used). Alternatively, you can create a new `.yaml` config file to store your custom setting. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L176) for a complete list of algorithm-specific hyper-parameters.\n\n### Test\nModel testing can be done by using `--eval-only`, which asks the code to run `trainer.test()`. You also need to provide the trained model and specify which model file (i.e., the epoch at which it was saved) to use. For example, to use `model.pth.tar-20` saved at `output/source_only_office31/model`, you can do\n\n```bash\nCUDA_VISIBLE_DEVICES=0 python tools/train.py \\\n--root $DATA \\\n--trainer SourceOnly \\\n--source-domains amazon \\\n--target-domains webcam \\\n--dataset-config-file configs/datasets/da/office31.yaml \\\n--config-file configs/trainers/da/source_only/office31.yaml \\\n--output-dir output/source_only_office31_test \\\n--eval-only \\\n--model-dir output/source_only_office31 \\\n--load-epoch 20\n```\n\nNote that `--model-dir` takes as input the directory path that was specified in `--output-dir` during training.\n\n### Write a new trainer\nA good practice is to go through `dassl/engine/trainer.py` to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass `TrainerXU`. For domain generalization, the new class can subclass `TrainerX`. In particular, `TrainerXU` and `TrainerX` mainly differ in whether they use a data loader for unlabeled data. With the base classes, a new trainer may only need to implement the `forward_backward()` method, which performs loss computation and model update. 
See `dassl/engine/da/source_only.py` for an example.\n\n### Add a new backbone/head/network\n`backbone` corresponds to a convolutional neural network model which performs feature extraction. `head` (which is an optional module) is mounted on top of `backbone` for further processing, which can be, for example, an MLP. `backbone` and `head` are basic building blocks for constructing a `SimpleNet()` (see `dassl/engine/trainer.py`) which serves as the primary model for a task. `network` contains custom neural network models, such as an image generator.\n\nTo add a new module, namely a backbone/head/network, you need to first register the module using the corresponding `registry`, i.e. `BACKBONE_REGISTRY` for `backbone`, `HEAD_REGISTRY` for `head` and `NETWORK_REGISTRY` for `network`. Note that for a new `backbone`, we require the model to subclass `Backbone` as defined in `dassl/modeling/backbone/backbone.py` and specify the `self._out_features` attribute.\n\nWe provide an example below for how to add a new `backbone`.\n```python\nimport torch.nn as nn\n\nfrom dassl.modeling import Backbone, BACKBONE_REGISTRY\n\nclass MyBackbone(Backbone):\n\n    def __init__(self):\n        super().__init__()\n        # Create layers (a toy example: one conv layer followed by global pooling)\n        self.conv = nn.Sequential(\n            nn.Conv2d(3, 2048, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.AdaptiveAvgPool2d(1),\n        )\n\n        # Dimension of the output feature vector\n        self._out_features = 2048\n\n    def forward(self, x):\n        # Extract features and flatten them to shape (batch_size, 2048)\n        return self.conv(x).flatten(1)\n\n@BACKBONE_REGISTRY.register()\ndef my_backbone(**kwargs):\n    return MyBackbone()\n```\nThen, you can set `MODEL.BACKBONE.NAME` to `my_backbone` to use your own architecture. For more details, please refer to the source code in `dassl/modeling`.\n\n### Add a dataset\nAn example code structure is shown below. Make sure you subclass `DatasetBase` and register the dataset with `@DATASET_REGISTRY.register()`. All you need is to load `train_x`, `train_u` (optional), `val` (optional) and `test`, among which `train_u` and `val` could be `None` or simply ignored. Each of these variables contains a list of `Datum` objects. 
A `Datum` object (implemented [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/data/datasets/base_dataset.py#L12)) contains information for a single image, like `impath` (string) and `label` (int).\n\n```python\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\n\n@DATASET_REGISTRY.register()\nclass NewDataset(DatasetBase):\n\n    dataset_dir = ''\n\n    def __init__(self, cfg):\n        # Each split is a list of Datum objects\n        train_x = ...\n        train_u = ...  # optional, can be None\n        val = ...  # optional, can be None\n        test = ...\n\n        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)\n```\n\nWe suggest taking a look at the dataset code in projects built on top of Dassl, such as [CoOp](https://github.com/KaiyangZhou/CoOp).\n\n## Relevant Research\n\nWe would like to share our research relevant to Dassl here.\n\n- [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325), TIP, 2021.\n- [MixStyle Neural Networks for Domain Generalization and Adaptation](https://arxiv.org/abs/2107.02053), arXiv preprint, 2021.\n- [Semi-Supervised Domain Generalization with Stochastic StyleMatch](https://arxiv.org/abs/2106.00592), arXiv preprint, 2021.\n- [Domain Generalization in Vision: A Survey](https://arxiv.org/abs/2103.02503), arXiv preprint, 2021.\n- [Domain Generalization with MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp), ICLR, 2021.\n- [Learning to Generate Novel Domains for Domain Generalization](https://arxiv.org/abs/2007.03304), ECCV, 2020.\n- [Deep Domain-Adversarial Image Generation for Domain Generalisation](https://arxiv.org/abs/2003.06054), AAAI, 2020.\n\n## Citation\n\nIf you find this code useful for your research, please cite the following paper:\n\n```\n@article{zhou2020domain,\n  title={Domain Adaptive Ensemble Learning},\n  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},\n  journal={IEEE Transactions on Image Processing (TIP)},\n  
year={2021}\n}\n```\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/README.md",
    "content": "The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings.\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:\n  NAME: \"CIFARSTL\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n  TRANSFORMS: [\"normalize\"]\n\nDATASET:\n  NAME: \"Digit5\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"cnn_digit5_m3sda\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml",
    "content": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"DomainNet\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"resnet101\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/mini_domainnet.yaml",
    "content": "INPUT:\n  SIZE: (96, 96)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"miniDomainNet\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"resnet18\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml",
    "content": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"Office31\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"resnet50\"\n  HEAD:\n    NAME: \"mlp\"\n    HIDDEN_LAYERS: [256]\n    DROPOUT: 0."
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/office_home.yaml",
    "content": "INPUT:\n  SIZE: (224, 224)\n\nDATASET:\n  NAME: \"OfficeHome\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml",
    "content": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"center_crop\", \"normalize\"]\n\nDATASET:\n  NAME: \"VisDA17\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"resnet101\"\n\nTEST:\n  PER_CLASS_RESULT: True"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:\n  NAME: \"CIFAR100C\"\n  CIFAR_C_TYPE: \"fog\"\n  CIFAR_C_LEVEL: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"wide_resnet_16_4\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:\n  NAME: \"CIFAR10C\"\n  CIFAR_C_TYPE: \"fog\"\n  CIFAR_C_LEVEL: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"wide_resnet_16_4\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:\n  NAME: \"DigitSingle\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"cnn_digitsingle\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:\n  NAME: \"DigitsDG\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"cnn_digitsdg\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/office_home_dg.yaml",
    "content": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"OfficeHomeDG\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"resnet18\"\n    PRETRAINED: True"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml",
    "content": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"PACS\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"resnet18\"\n    PRETRAINED: True"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml",
    "content": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"VLCS\"\n\nMODEL:\n  BACKBONE:\n    NAME: \"resnet18\"\n    PRETRAINED: True"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:\n  NAME: \"CIFAR10\"\n  NUM_LABELED: 4000\n  VAL_PERCENT: 0.\n\nMODEL:\n  BACKBONE:\n    NAME: \"wide_resnet_28_2\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n  CROP_PADDING: 4\n\nDATASET:\n  NAME: \"CIFAR100\"\n  NUM_LABELED: 10000\n  VAL_PERCENT: 0.\n\nMODEL:\n  BACKBONE:\n    NAME: \"wide_resnet_28_2\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml",
    "content": "INPUT:\n  SIZE: (96, 96)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n  CROP_PADDING: 4\n\nDATASET:\n  NAME: \"STL10\"\n  STL10_FOLD: 0\n\nMODEL:\n  BACKBONE:\n    NAME: \"wide_resnet_28_2\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml",
    "content": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n  CROP_PADDING: 4\n\nDATASET:\n  NAME: \"SVHN\"\n  NUM_LABELED: 1000\n  VAL_PERCENT: 0.\n\nMODEL:\n  BACKBONE:\n    NAME: \"wide_resnet_28_2\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 256\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 256\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [30]\n  MAX_EPOCH: 30\n  LR_SCHEDULER: \"cosine\"\n\nTRAINER:\n  DAEL:\n    STRONG_TRANSFORMS: [\"randaugment2\", \"normalize\"]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/dael/domainnet.yaml",
    "content": "DATALOADER:\n  NUM_WORKERS: 4\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 6\n  TEST:\n    BATCH_SIZE: 30\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 40\n  LR_SCHEDULER: \"cosine\"\n\nTRAINER:\n  DAEL:\n    STRONG_TRANSFORMS: [\"random_flip\", \"cutout\", \"randaugment2\", \"normalize\"]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/dael/mini_domainnet.yaml",
    "content": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 192\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 200\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.005\n  MAX_EPOCH: 60\n  LR_SCHEDULER: \"cosine\"\n\nTRAINER:\n  DAEL:\n    STRONG_TRANSFORMS: [\"random_flip\", \"cutout\", \"randaugment2\", \"normalize\"]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 256\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 256\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [30]\n  MAX_EPOCH: 30\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/domainnet.yaml",
    "content": "DATALOADER:\n  NUM_WORKERS: 4\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 6\n  TEST:\n    BATCH_SIZE: 30\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 40\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/mini_domainnet.yaml",
    "content": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 192\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 200\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.005\n  MAX_EPOCH: 60\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/digit5.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 256\n  TEST:\n    BATCH_SIZE: 256\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [30]\n  MAX_EPOCH: 30\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/mini_domainnet.yaml",
    "content": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATCH_SIZE: 128\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.005\n  MAX_EPOCH: 60\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/office31.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 32\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  STEPSIZE: [20]\n  MAX_EPOCH: 20"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/visda17.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 32\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.0001\n  STEPSIZE: [2]\n  MAX_EPOCH: 2\n\nTRAIN:\n  PRINT_FREQ: 50\n  COUNT_ITER: \"train_u\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/dael/digits_dg.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 120\n  TEST:\n    BATCH_SIZE: 100\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [20]\n  MAX_EPOCH: 50\n\nTRAINER:\n  DAEL:\n    STRONG_TRANSFORMS: [\"randaugment2\", \"normalize\"]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/dael/office_home_dg.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TEST:\n    BATCH_SIZE: 100\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 40\n  LR_SCHEDULER: \"cosine\"\n\nTRAINER:\n  DAEL:\n    STRONG_TRANSFORMS: [\"random_flip\", \"cutout\", \"randaugment2\", \"normalize\"]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TEST:\n    BATCH_SIZE: 100\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 40\n  LR_SCHEDULER: \"cosine\"\n\nTRAINER:\n  DAEL:\n    STRONG_TRANSFORMS: [\"random_flip\", \"cutout\", \"randaugment2\", \"normalize\"]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/digits_dg.yaml",
    "content": "INPUT:\n  PIXEL_MEAN: [0., 0., 0.]\n  PIXEL_STD: [1., 1., 1.]\n\nDATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATCH_SIZE: 128\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [20]\n  MAX_EPOCH: 50\n\nTRAINER:\n  DDAIG:\n    G_ARCH: \"fcn_3x32_gctx\"\n    LMDA: 0.3"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/office_home_dg.yaml",
    "content": "INPUT:\n  PIXEL_MEAN: [0., 0., 0.]\n  PIXEL_STD: [1., 1., 1.]\n\nDATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 16\n  TEST:\n    BATCH_SIZE: 16\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.0005\n  STEPSIZE: [20]\n  MAX_EPOCH: 25\n\nTRAINER:\n  DDAIG:\n    G_ARCH: \"fcn_3x64_gctx\"\n    WARMUP: 3\n    LMDA: 0.3"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml",
    "content": "INPUT:\n  PIXEL_MEAN: [0., 0., 0.]\n  PIXEL_STD: [1., 1., 1.]\n\nDATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 16\n  TEST:\n    BATCH_SIZE: 16\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.0005\n  STEPSIZE: [20]\n  MAX_EPOCH: 25\n\nTRAINER:\n  DDAIG:\n    G_ARCH: \"fcn_3x64_gctx\"\n    WARMUP: 3\n    LMDA: 0.3"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/digits_dg.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [20]\n  MAX_EPOCH: 50\n\nTRAIN:\n  PRINT_FREQ: 20"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/mini_domainnet.yaml",
    "content": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATCH_SIZE: 128\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.005\n  MAX_EPOCH: 60\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/office_home_dg.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.001\n  MAX_EPOCH: 50\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.001\n  MAX_EPOCH: 50\n  LR_SCHEDULER: \"cosine\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/ssl/fixmatch/cifar10.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 64\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 448\n  TEST:\n    BATCH_SIZE: 500\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [4000]\n  MAX_EPOCH: 4000\n  LR_SCHEDULER: \"cosine\"\n\nTRAIN:\n  COUNT_ITER: \"train_u\"\n  PRINT_FREQ: 10\n\nTRAINER:\n  FIXMATCH:\n    STRONG_TRANSFORMS: [\"random_flip\", \"randaugment_fixmatch\", \"normalize\", \"cutout\"]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/__init__.py",
    "content": "\"\"\"\nDassl\n------\nPyTorch toolbox for domain adaptation and semi-supervised learning.\n\nURL: https://github.com/KaiyangZhou/Dassl.pytorch\n\n@article{zhou2020domain,\n  title={Domain Adaptive Ensemble Learning},\n  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},\n  journal={arXiv preprint arXiv:2003.07325},\n  year={2020}\n}\n\"\"\"\n\n__version__ = \"0.5.0\"\n__author__ = \"Kaiyang Zhou\"\n__homepage__ = \"https://kaiyangzhou.github.io/\"\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/config/__init__.py",
    "content": "from .defaults import _C as cfg_default\n\n\ndef get_cfg_default():\n    return cfg_default.clone()\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/config/defaults.py",
    "content": "from yacs.config import CfgNode as CN\n\n###########################\n# Config definition\n###########################\n\n_C = CN()\n\n_C.VERSION = 1\n\n# Directory to save the output files (like log.txt and model weights)\n_C.OUTPUT_DIR = \"./output\"\n# Path to a directory where the files were saved previously\n_C.RESUME = \"\"\n# Set seed to negative value to randomize everything\n# Set seed to positive value to use a fixed seed\n_C.SEED = -1\n_C.USE_CUDA = True\n# Print detailed information\n# E.g. trainer, dataset, and backbone\n_C.VERBOSE = True\n\n###########################\n# Input\n###########################\n_C.INPUT = CN()\n_C.INPUT.SIZE = (224, 224)\n# Mode of interpolation in resize functions\n_C.INPUT.INTERPOLATION = \"bilinear\"\n# For available choices please refer to transforms.py\n_C.INPUT.TRANSFORMS = ()\n# If True, tfm_train and tfm_test will be None\n_C.INPUT.NO_TRANSFORM = False\n# Default mean and std come from ImageNet\n_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]\n_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]\n# Padding for random crop\n_C.INPUT.CROP_PADDING = 4\n# Cutout\n_C.INPUT.CUTOUT_N = 1\n_C.INPUT.CUTOUT_LEN = 16\n# Gaussian noise\n_C.INPUT.GN_MEAN = 0.0\n_C.INPUT.GN_STD = 0.15\n# RandAugment\n_C.INPUT.RANDAUGMENT_N = 2\n_C.INPUT.RANDAUGMENT_M = 10\n# ColorJitter (brightness, contrast, saturation, hue)\n_C.INPUT.COLORJITTER_B = 0.4\n_C.INPUT.COLORJITTER_C = 0.4\n_C.INPUT.COLORJITTER_S = 0.4\n_C.INPUT.COLORJITTER_H = 0.1\n# Random gray scale's probability\n_C.INPUT.RGS_P = 0.2\n# Gaussian blur\n_C.INPUT.GB_P = 0.5  # probability of applying this operation\n_C.INPUT.GB_K = 21  # kernel size (should be an odd number)\n\n###########################\n# Dataset\n###########################\n_C.DATASET = CN()\n# Directory where datasets are stored\n_C.DATASET.ROOT = \"\"\n_C.DATASET.NAME = \"\"\n# List of names of source domains\n_C.DATASET.SOURCE_DOMAINS = ()\n# List of names of target domains\n_C.DATASET.TARGET_DOMAINS = 
()\n# Number of labeled instances in total\n# Useful for semi-supervised learning\n_C.DATASET.NUM_LABELED = -1\n# Number of images per class\n_C.DATASET.NUM_SHOTS = -1\n# Percentage of validation data (only used for SSL datasets)\n# Set to 0 if you do not want to use val data\n# Using val data for hyperparameter tuning was done in Oliver et al. 2018\n_C.DATASET.VAL_PERCENT = 0.1\n# Fold index for STL-10 dataset (normal range is 0 - 9)\n# Negative number means None\n_C.DATASET.STL10_FOLD = -1\n# CIFAR-10/100-C's corruption type and intensity level\n_C.DATASET.CIFAR_C_TYPE = \"\"\n_C.DATASET.CIFAR_C_LEVEL = 1\n# Use all data in the unlabeled data set (e.g. FixMatch)\n_C.DATASET.ALL_AS_UNLABELED = False\n\n###########################\n# Dataloader\n###########################\n_C.DATALOADER = CN()\n_C.DATALOADER.NUM_WORKERS = 4\n# Apply transformations to an image K times (during training)\n_C.DATALOADER.K_TRANSFORMS = 1\n# img0 denotes image tensor without augmentation\n# Useful for consistency learning\n_C.DATALOADER.RETURN_IMG0 = False\n# Setting for the train_x data-loader\n_C.DATALOADER.TRAIN_X = CN()\n_C.DATALOADER.TRAIN_X.SAMPLER = \"RandomSampler\"\n_C.DATALOADER.TRAIN_X.BATCH_SIZE = 32\n# Parameter for RandomDomainSampler\n# 0 or -1 means sampling from all domains\n_C.DATALOADER.TRAIN_X.N_DOMAIN = 0\n# Parameter for RandomClassSampler\n# Number of instances per class\n_C.DATALOADER.TRAIN_X.N_INS = 16\n\n# Setting for the train_u data-loader\n_C.DATALOADER.TRAIN_U = CN()\n# Set to False if you want train_u to use its own\n# data loader params\n_C.DATALOADER.TRAIN_U.SAME_AS_X = True\n_C.DATALOADER.TRAIN_U.SAMPLER = \"RandomSampler\"\n_C.DATALOADER.TRAIN_U.BATCH_SIZE = 32\n_C.DATALOADER.TRAIN_U.N_DOMAIN = 0\n_C.DATALOADER.TRAIN_U.N_INS = 16\n\n# Setting for the test data-loader\n_C.DATALOADER.TEST = CN()\n_C.DATALOADER.TEST.SAMPLER = \"SequentialSampler\"\n_C.DATALOADER.TEST.BATCH_SIZE = 32\n\n###########################\n# 
Model\n###########################\n_C.MODEL = CN()\n# Path to model weights (for initialization)\n_C.MODEL.INIT_WEIGHTS = \"\"\n_C.MODEL.BACKBONE = CN()\n_C.MODEL.BACKBONE.NAME = \"\"\n_C.MODEL.BACKBONE.PRETRAINED = True\n# Definition of embedding layers\n_C.MODEL.HEAD = CN()\n# If none, do not construct embedding layers, the\n# backbone's output will be passed to the classifier\n_C.MODEL.HEAD.NAME = \"\"\n# Structure of hidden layers (a list), e.g. [512, 512]\n# If undefined, no embedding layer will be constructed\n_C.MODEL.HEAD.HIDDEN_LAYERS = ()\n_C.MODEL.HEAD.ACTIVATION = \"relu\"\n_C.MODEL.HEAD.BN = True\n_C.MODEL.HEAD.DROPOUT = 0.0\n\n###########################\n# Optimization\n###########################\n_C.OPTIM = CN()\n_C.OPTIM.NAME = \"adam\"\n_C.OPTIM.LR = 0.0003\n_C.OPTIM.WEIGHT_DECAY = 5e-4\n_C.OPTIM.MOMENTUM = 0.9\n_C.OPTIM.SGD_DAMPNING = 0\n_C.OPTIM.SGD_NESTEROV = False\n_C.OPTIM.RMSPROP_ALPHA = 0.99\n_C.OPTIM.ADAM_BETA1 = 0.9\n_C.OPTIM.ADAM_BETA2 = 0.999\n# STAGED_LR allows different layers to have\n# different lr, e.g. 
pre-trained base layers\n# can be assigned a smaller lr than the new\n# classification layer\n_C.OPTIM.STAGED_LR = False\n_C.OPTIM.NEW_LAYERS = ()\n_C.OPTIM.BASE_LR_MULT = 0.1\n# Learning rate scheduler\n_C.OPTIM.LR_SCHEDULER = \"single_step\"\n# -1 or 0 means the stepsize is equal to max_epoch\n_C.OPTIM.STEPSIZE = (-1, )\n_C.OPTIM.GAMMA = 0.1\n_C.OPTIM.MAX_EPOCH = 10\n# Set WARMUP_EPOCH larger than 0 to activate warmup training\n_C.OPTIM.WARMUP_EPOCH = -1\n# Either linear or constant\n_C.OPTIM.WARMUP_TYPE = \"linear\"\n# Constant learning rate when type=constant\n_C.OPTIM.WARMUP_CONS_LR = 1e-5\n# Minimum learning rate when type=linear\n_C.OPTIM.WARMUP_MIN_LR = 1e-5\n# Recount epoch for the next scheduler (last_epoch=-1)\n# Otherwise last_epoch=warmup_epoch\n_C.OPTIM.WARMUP_RECOUNT = True\n\n###########################\n# Train\n###########################\n_C.TRAIN = CN()\n# How often (epoch) to save model during training\n# Set to 0 or negative value to only save the last one\n_C.TRAIN.CHECKPOINT_FREQ = 0\n# How often (batch) to print training information\n_C.TRAIN.PRINT_FREQ = 10\n# Use 'train_x', 'train_u' or 'smaller_one' to count\n# the number of iterations in an epoch (for DA and SSL)\n_C.TRAIN.COUNT_ITER = \"train_x\"\n\n###########################\n# Test\n###########################\n_C.TEST = CN()\n_C.TEST.EVALUATOR = \"Classification\"\n_C.TEST.PER_CLASS_RESULT = False\n# Compute confusion matrix, which will be saved\n# to $OUTPUT_DIR/cmat.pt\n_C.TEST.COMPUTE_CMAT = False\n# If NO_TEST=True, no testing will be conducted\n_C.TEST.NO_TEST = False\n# Use test or val set for FINAL evaluation\n_C.TEST.SPLIT = \"test\"\n# Which model to test after training\n# Either last_step or best_val\n_C.TEST.FINAL_MODEL = \"last_step\"\n\n###########################\n# Trainer specifics\n###########################\n_C.TRAINER = CN()\n_C.TRAINER.NAME = \"\"\n\n# MCD\n_C.TRAINER.MCD = CN()\n_C.TRAINER.MCD.N_STEP_F = 4  # number of steps to train F\n# MME\n_C.TRAINER.MME = 
CN()\n_C.TRAINER.MME.LMDA = 0.1  # weight for the entropy loss\n# SelfEnsembling\n_C.TRAINER.SE = CN()\n_C.TRAINER.SE.EMA_ALPHA = 0.999\n_C.TRAINER.SE.CONF_THRE = 0.95\n_C.TRAINER.SE.RAMPUP = 300\n\n# M3SDA\n_C.TRAINER.M3SDA = CN()\n_C.TRAINER.M3SDA.LMDA = 0.5  # weight for the moment distance loss\n_C.TRAINER.M3SDA.N_STEP_F = 4  # follow MCD\n# DAEL\n_C.TRAINER.DAEL = CN()\n_C.TRAINER.DAEL.WEIGHT_U = 0.5  # weight on the unlabeled loss\n_C.TRAINER.DAEL.CONF_THRE = 0.95  # confidence threshold\n_C.TRAINER.DAEL.STRONG_TRANSFORMS = ()\n\n# CrossGrad\n_C.TRAINER.CG = CN()\n_C.TRAINER.CG.EPS_F = 1.0  # scaling parameter for D's gradients\n_C.TRAINER.CG.EPS_D = 1.0  # scaling parameter for F's gradients\n_C.TRAINER.CG.ALPHA_F = 0.5  # balancing weight for the label net's loss\n_C.TRAINER.CG.ALPHA_D = 0.5  # balancing weight for the domain net's loss\n# DDAIG\n_C.TRAINER.DDAIG = CN()\n_C.TRAINER.DDAIG.G_ARCH = \"\"  # generator's architecture\n_C.TRAINER.DDAIG.LMDA = 0.3  # perturbation weight\n_C.TRAINER.DDAIG.CLAMP = False  # clamp perturbation values\n_C.TRAINER.DDAIG.CLAMP_MIN = -1.0\n_C.TRAINER.DDAIG.CLAMP_MAX = 1.0\n_C.TRAINER.DDAIG.WARMUP = 0\n_C.TRAINER.DDAIG.ALPHA = 0.5  # balancing weight for the losses\n\n# EntMin\n_C.TRAINER.ENTMIN = CN()\n_C.TRAINER.ENTMIN.LMDA = 1e-3  # weight on the entropy loss\n# Mean Teacher\n_C.TRAINER.MEANTEA = CN()\n_C.TRAINER.MEANTEA.WEIGHT_U = 1.0  # weight on the unlabeled loss\n_C.TRAINER.MEANTEA.EMA_ALPHA = 0.999\n_C.TRAINER.MEANTEA.RAMPUP = 5  # epochs used to ramp up the loss_u weight\n# MixMatch\n_C.TRAINER.MIXMATCH = CN()\n_C.TRAINER.MIXMATCH.WEIGHT_U = 100.0  # weight on the unlabeled loss\n_C.TRAINER.MIXMATCH.TEMP = 2.0  # temperature for sharpening the probability\n_C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75\n_C.TRAINER.MIXMATCH.RAMPUP = 20000  # steps used to ramp up the loss_u weight\n# FixMatch\n_C.TRAINER.FIXMATCH = CN()\n_C.TRAINER.FIXMATCH.WEIGHT_U = 1.0  # weight on the unlabeled loss\n_C.TRAINER.FIXMATCH.CONF_THRE = 
0.95  # confidence threshold\n_C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ()\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/__init__.py",
    "content": "from .data_manager import DataManager, DatasetWrapper\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/data_manager.py",
    "content": "import torch\nimport torchvision.transforms as T\nfrom PIL import Image\nfrom torch.utils.data import Dataset as TorchDataset\n\nfrom dassl.utils import read_image\n\nfrom .datasets import build_dataset\nfrom .samplers import build_sampler\nfrom .transforms import build_transform\n\nINTERPOLATION_MODES = {\n    \"bilinear\": Image.BILINEAR,\n    \"bicubic\": Image.BICUBIC,\n    \"nearest\": Image.NEAREST,\n}\n\n\ndef build_data_loader(\n    cfg,\n    sampler_type=\"SequentialSampler\",\n    data_source=None,\n    batch_size=64,\n    n_domain=0,\n    n_ins=2,\n    tfm=None,\n    is_train=True,\n    dataset_wrapper=None,\n):\n    # Build sampler\n    sampler = build_sampler(\n        sampler_type,\n        cfg=cfg,\n        data_source=data_source,\n        batch_size=batch_size,\n        n_domain=n_domain,\n        n_ins=n_ins,\n    )\n\n    if dataset_wrapper is None:\n        dataset_wrapper = DatasetWrapper\n\n    # Build data loader\n    data_loader = torch.utils.data.DataLoader(\n        dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),\n        batch_size=batch_size,\n        sampler=sampler,\n        num_workers=cfg.DATALOADER.NUM_WORKERS,\n        drop_last=is_train and len(data_source) >= batch_size,\n        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),\n    )\n    assert len(data_loader) > 0\n\n    return data_loader\n\n\nclass DataManager:\n\n    def __init__(\n        self,\n        cfg,\n        custom_tfm_train=None,\n        custom_tfm_test=None,\n        dataset_wrapper=None\n    ):\n        # Load dataset\n        dataset = build_dataset(cfg)\n        # Build transform\n        if custom_tfm_train is None:\n            tfm_train = build_transform(cfg, is_train=True)\n        else:\n            print(\"* Using custom transform for training\")\n            tfm_train = custom_tfm_train\n\n        if custom_tfm_test is None:\n            tfm_test = build_transform(cfg, is_train=False)\n        else:\n       
     print(\"* Using custom transform for testing\")\n            tfm_test = custom_tfm_test\n\n        # Build train_loader_x\n        train_loader_x = build_data_loader(\n            cfg,\n            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,\n            data_source=dataset.train_x,\n            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,\n            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,\n            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,\n            tfm=tfm_train,\n            is_train=True,\n            dataset_wrapper=dataset_wrapper,\n        )\n\n        # Build train_loader_u\n        train_loader_u = None\n        if dataset.train_u:\n            sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER\n            batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE\n            n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN\n            n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS\n\n            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:\n                sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER\n                batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE\n                n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN\n                n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS\n\n            train_loader_u = build_data_loader(\n                cfg,\n                sampler_type=sampler_type_,\n                data_source=dataset.train_u,\n                batch_size=batch_size_,\n                n_domain=n_domain_,\n                n_ins=n_ins_,\n                tfm=tfm_train,\n                is_train=True,\n                dataset_wrapper=dataset_wrapper,\n            )\n\n        # Build val_loader\n        val_loader = None\n        if dataset.val:\n            val_loader = build_data_loader(\n                cfg,\n                sampler_type=cfg.DATALOADER.TEST.SAMPLER,\n                data_source=dataset.val,\n                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,\n                tfm=tfm_test,\n                is_train=False,\n                
dataset_wrapper=dataset_wrapper,\n            )\n\n        # Build test_loader\n        test_loader = build_data_loader(\n            cfg,\n            sampler_type=cfg.DATALOADER.TEST.SAMPLER,\n            data_source=dataset.test,\n            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,\n            tfm=tfm_test,\n            is_train=False,\n            dataset_wrapper=dataset_wrapper,\n        )\n\n        # Attributes\n        self._num_classes = dataset.num_classes\n        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)\n        self._lab2cname = dataset.lab2cname\n\n        # Dataset and data-loaders\n        self.dataset = dataset\n        self.train_loader_x = train_loader_x\n        self.train_loader_u = train_loader_u\n        self.val_loader = val_loader\n        self.test_loader = test_loader\n\n        if cfg.VERBOSE:\n            self.show_dataset_summary(cfg)\n\n    @property\n    def num_classes(self):\n        return self._num_classes\n\n    @property\n    def num_source_domains(self):\n        return self._num_source_domains\n\n    @property\n    def lab2cname(self):\n        return self._lab2cname\n\n    def show_dataset_summary(self, cfg):\n        print(\"***** Dataset statistics *****\")\n\n        print(\"  Dataset: {}\".format(cfg.DATASET.NAME))\n\n        if cfg.DATASET.SOURCE_DOMAINS:\n            print(\"  Source domains: {}\".format(cfg.DATASET.SOURCE_DOMAINS))\n        if cfg.DATASET.TARGET_DOMAINS:\n            print(\"  Target domains: {}\".format(cfg.DATASET.TARGET_DOMAINS))\n\n        print(\"  # classes: {:,}\".format(self.num_classes))\n\n        print(\"  # train_x: {:,}\".format(len(self.dataset.train_x)))\n\n        if self.dataset.train_u:\n            print(\"  # train_u: {:,}\".format(len(self.dataset.train_u)))\n\n        if self.dataset.val:\n            print(\"  # val: {:,}\".format(len(self.dataset.val)))\n\n        print(\"  # test: {:,}\".format(len(self.dataset.test)))\n\n\nclass 
DatasetWrapper(TorchDataset):\n\n    def __init__(self, cfg, data_source, transform=None, is_train=False):\n        self.cfg = cfg\n        self.data_source = data_source\n        self.transform = transform  # accept list (tuple) as input\n        self.is_train = is_train\n        # Augmenting an image K>1 times is only allowed during training\n        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1\n        self.return_img0 = cfg.DATALOADER.RETURN_IMG0\n\n        if self.k_tfm > 1 and transform is None:\n            raise ValueError(\n                \"Cannot augment the image {} times \"\n                \"because transform is None\".format(self.k_tfm)\n            )\n\n        # Build transform that doesn't apply any data augmentation\n        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]\n        to_tensor = []\n        to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]\n        to_tensor += [T.ToTensor()]\n        if \"normalize\" in cfg.INPUT.TRANSFORMS:\n            normalize = T.Normalize(\n                mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD\n            )\n            to_tensor += [normalize]\n        self.to_tensor = T.Compose(to_tensor)\n\n    def __len__(self):\n        return len(self.data_source)\n\n    def __getitem__(self, idx):\n        item = self.data_source[idx]\n\n        output = {\n            \"label\": item.label,\n            \"domain\": item.domain,\n            \"impath\": item.impath\n        }\n\n        img0 = read_image(item.impath)\n\n        if self.transform is not None:\n            if isinstance(self.transform, (list, tuple)):\n                for i, tfm in enumerate(self.transform):\n                    img = self._transform_image(tfm, img0)\n                    keyname = \"img\"\n                    if (i + 1) > 1:\n                        keyname += str(i + 1)\n                    output[keyname] = img\n            else:\n                img = 
self._transform_image(self.transform, img0)\n                output[\"img\"] = img\n\n        if self.return_img0:\n            output[\"img0\"] = self.to_tensor(img0)\n\n        return output\n\n    def _transform_image(self, tfm, img0):\n        img_list = []\n\n        for k in range(self.k_tfm):\n            img_list.append(tfm(img0))\n\n        img = img_list\n        if len(img) == 1:\n            img = img[0]\n\n        return img\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py",
    "content": "from .build import DATASET_REGISTRY, build_dataset  # isort:skip\nfrom .base_dataset import Datum, DatasetBase  # isort:skip\n\nfrom .da import *\nfrom .dg import *\nfrom .ssl import *\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py",
    "content": "import os\nimport random\nimport os.path as osp\nimport tarfile\nimport zipfile\nfrom collections import defaultdict\nimport gdown\n\nfrom dassl.utils import check_isfile\n\n\nclass Datum:\n    \"\"\"Data instance which defines the basic attributes.\n\n    Args:\n        impath (str): image path.\n        label (int): class label.\n        domain (int): domain label.\n        classname (str): class name.\n    \"\"\"\n\n    def __init__(self, impath=\"\", label=0, domain=0, classname=\"\"):\n        assert isinstance(impath, str)\n        assert check_isfile(impath)\n\n        self._impath = impath\n        self._label = label\n        self._domain = domain\n        self._classname = classname\n\n    @property\n    def impath(self):\n        return self._impath\n\n    @property\n    def label(self):\n        return self._label\n\n    @property\n    def domain(self):\n        return self._domain\n\n    @property\n    def classname(self):\n        return self._classname\n\n\nclass DatasetBase:\n    \"\"\"A unified dataset class for\n    1) domain adaptation\n    2) domain generalization\n    3) semi-supervised learning\n    \"\"\"\n\n    dataset_dir = \"\"  # the directory where the dataset is stored\n    domains = []  # string names of all domains\n\n    def __init__(self, train_x=None, train_u=None, val=None, test=None):\n        self._train_x = train_x  # labeled training data\n        self._train_u = train_u  # unlabeled training data (optional)\n        self._val = val  # validation data (optional)\n        self._test = test  # test data\n\n        self._num_classes = self.get_num_classes(train_x)\n        self._lab2cname, self._classnames = self.get_lab2cname(train_x)\n\n    @property\n    def train_x(self):\n        return self._train_x\n\n    @property\n    def train_u(self):\n        return self._train_u\n\n    @property\n    def val(self):\n        return self._val\n\n    @property\n    def test(self):\n        return self._test\n\n    
@property\n    def lab2cname(self):\n        return self._lab2cname\n\n    @property\n    def classnames(self):\n        return self._classnames\n\n    @property\n    def num_classes(self):\n        return self._num_classes\n\n    def get_num_classes(self, data_source):\n        \"\"\"Count number of classes.\n\n        Args:\n            data_source (list): a list of Datum objects.\n        \"\"\"\n        label_set = set()\n        for item in data_source:\n            label_set.add(item.label)\n        return max(label_set) + 1\n\n    def get_lab2cname(self, data_source):\n        \"\"\"Get a label-to-classname mapping (dict).\n\n        Args:\n            data_source (list): a list of Datum objects.\n        \"\"\"\n        container = set()\n        for item in data_source:\n            container.add((item.label, item.classname))\n        mapping = {label: classname for label, classname in container}\n        labels = list(mapping.keys())\n        labels.sort()\n        classnames = [mapping[label] for label in labels]\n        return mapping, classnames\n\n    def check_input_domains(self, source_domains, target_domains):\n        self.is_input_domain_valid(source_domains)\n        self.is_input_domain_valid(target_domains)\n\n    def is_input_domain_valid(self, input_domains):\n        for domain in input_domains:\n            if domain not in self.domains:\n                raise ValueError(\n                    \"Input domain must belong to {}, \"\n                    \"but got [{}]\".format(self.domains, domain)\n                )\n\n    def download_data(self, url, dst, from_gdrive=True):\n        if not osp.exists(osp.dirname(dst)):\n            os.makedirs(osp.dirname(dst))\n\n        if from_gdrive:\n            gdown.download(url, dst, quiet=False)\n        else:\n            raise NotImplementedError\n\n        print(\"Extracting file ...\")\n\n        try:\n            tar = tarfile.open(dst)\n            tar.extractall(path=osp.dirname(dst))\n      
      tar.close()\n        except:\n            zip_ref = zipfile.ZipFile(dst, \"r\")\n            zip_ref.extractall(osp.dirname(dst))\n            zip_ref.close()\n\n        print(\"File extracted to {}\".format(osp.dirname(dst)))\n\n    def generate_fewshot_dataset(\n        self, *data_sources, num_shots=-1, repeat=False\n    ):\n        \"\"\"Generate a few-shot dataset (typically for the training set).\n\n        This function is useful when one wants to evaluate a model\n        in a few-shot learning setting where each class only contains\n        a small number of images.\n\n        Args:\n            data_sources: each is a list of Datum objects.\n            num_shots (int): number of instances per class to sample.\n            repeat (bool): repeat images if needed (default: False).\n        \"\"\"\n        if num_shots < 1:\n            if len(data_sources) == 1:\n                return data_sources[0]\n            return data_sources\n\n        print(f\"Creating a {num_shots}-shot dataset\")\n\n        output = []\n\n        for data_source in data_sources:\n            tracker = self.split_dataset_by_label(data_source)\n            dataset = []\n\n            for label, items in tracker.items():\n                if len(items) >= num_shots:\n                    sampled_items = random.sample(items, num_shots)\n                else:\n                    if repeat:\n                        sampled_items = random.choices(items, k=num_shots)\n                    else:\n                        sampled_items = items\n                dataset.extend(sampled_items)\n\n            output.append(dataset)\n\n        if len(output) == 1:\n            return output[0]\n\n        return output\n\n    def split_dataset_by_label(self, data_source):\n        \"\"\"Split a dataset, i.e. 
a list of Datum objects,\n        into class-specific groups stored in a dictionary.\n\n        Args:\n            data_source (list): a list of Datum objects.\n        \"\"\"\n        output = defaultdict(list)\n\n        for item in data_source:\n            output[item.label].append(item)\n\n        return output\n\n    def split_dataset_by_domain(self, data_source):\n        \"\"\"Split a dataset, i.e. a list of Datum objects,\n        into domain-specific groups stored in a dictionary.\n\n        Args:\n            data_source (list): a list of Datum objects.\n        \"\"\"\n        output = defaultdict(list)\n\n        for item in data_source:\n            output[item.domain].append(item)\n\n        return output\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/build.py",
    "content": "from dassl.utils import Registry, check_availability\n\nDATASET_REGISTRY = Registry(\"DATASET\")\n\n\ndef build_dataset(cfg):\n    avai_datasets = DATASET_REGISTRY.registered_names()\n    check_availability(cfg.DATASET.NAME, avai_datasets)\n    if cfg.VERBOSE:\n        print(\"Loading dataset: {}\".format(cfg.DATASET.NAME))\n    return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py",
    "content": "from .digit5 import Digit5\nfrom .visda17 import VisDA17\nfrom .cifarstl import CIFARSTL\nfrom .office31 import Office31\nfrom .domainnet import DomainNet\nfrom .office_home import OfficeHome\nfrom .mini_domainnet import miniDomainNet\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py",
    "content": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass CIFARSTL(DatasetBase):\n    \"\"\"CIFAR-10 and STL-10.\n\n    CIFAR-10:\n        - 60,000 32x32 colour images.\n        - 10 classes, with 6,000 images per class.\n        - 50,000 training images and 10,000 test images.\n        - URL: https://www.cs.toronto.edu/~kriz/cifar.html.\n\n    STL-10:\n        - 10 classes: airplane, bird, car, cat, deer, dog, horse,\n        monkey, ship, truck.\n        - Images are 96x96 pixels, color.\n        - 500 training images (10 pre-defined folds), 800 test images\n        per class.\n        - URL: https://cs.stanford.edu/~acoates/stl10/.\n\n    Reference:\n        - Krizhevsky. Learning Multiple Layers of Features\n        from Tiny Images. Tech report.\n        - Coates et al. An Analysis of Single Layer Networks in\n        Unsupervised Feature Learning. 
AISTATS 2011.\n    \"\"\"\n\n    dataset_dir = \"cifar_stl\"\n    domains = [\"cifar\", \"stl\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split=\"train\")\n        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"train\")\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"test\")\n\n        super().__init__(train_x=train_x, train_u=train_u, test=test)\n\n    def _read_data(self, input_domains, split=\"train\"):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            data_dir = osp.join(self.dataset_dir, dname, split)\n            class_names = listdir_nohidden(data_dir)\n\n            for class_name in class_names:\n                class_dir = osp.join(data_dir, class_name)\n                imnames = listdir_nohidden(class_dir)\n                label = int(class_name.split(\"_\")[0])\n\n                for imname in imnames:\n                    impath = osp.join(class_dir, imname)\n                    item = Datum(impath=impath, label=label, domain=domain)\n                    items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py",
    "content": "import random\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n# Folder names for train and test sets\nMNIST = {\"train\": \"train_images\", \"test\": \"test_images\"}\nMNIST_M = {\"train\": \"train_images\", \"test\": \"test_images\"}\nSVHN = {\"train\": \"train_images\", \"test\": \"test_images\"}\nSYN = {\"train\": \"train_images\", \"test\": \"test_images\"}\nUSPS = {\"train\": \"train_images\", \"test\": \"test_images\"}\n\n\ndef read_image_list(im_dir, n_max=None, n_repeat=None):\n    items = []\n\n    for imname in listdir_nohidden(im_dir):\n        imname_noext = osp.splitext(imname)[0]\n        label = int(imname_noext.split(\"_\")[1])\n        impath = osp.join(im_dir, imname)\n        items.append((impath, label))\n\n    if n_max is not None:\n        items = random.sample(items, n_max)\n\n    if n_repeat is not None:\n        items *= n_repeat\n\n    return items\n\n\ndef load_mnist(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, MNIST[split])\n    n_max = 25000 if split == \"train\" else 9000\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_mnist_m(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, MNIST_M[split])\n    n_max = 25000 if split == \"train\" else 9000\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_svhn(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, SVHN[split])\n    n_max = 25000 if split == \"train\" else 9000\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_syn(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, SYN[split])\n    n_max = 25000 if split == \"train\" else 9000\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_usps(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, USPS[split])\n    n_repeat = 3 if split == \"train\" else None\n    return 
read_image_list(data_dir, n_repeat=n_repeat)\n\n\n@DATASET_REGISTRY.register()\nclass Digit5(DatasetBase):\n    \"\"\"Five digit datasets.\n\n    It contains:\n        - MNIST: hand-written digits.\n        - MNIST-M: variant of MNIST with blended background.\n        - SVHN: street view house number.\n        - SYN: synthetic digits.\n        - USPS: hand-written digits, slightly different from MNIST.\n\n    For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from\n    the training set and 9,000 images from the test set. For USPS, which has\n    only 9,298 images in total, we use the entire dataset but replicate its\n    training set 3 times to match the training set size of the other domains.\n\n    Reference:\n        - Lecun et al. Gradient-based learning applied to document\n        recognition. IEEE 1998.\n        - Ganin et al. Domain-adversarial training of neural networks.\n        JMLR 2016.\n        - Netzer et al. Reading digits in natural images with unsupervised\n        feature learning. 
NIPS-W 2011.\n    \"\"\"\n\n    dataset_dir = \"digit5\"\n    domains = [\"mnist\", \"mnist_m\", \"svhn\", \"syn\", \"usps\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split=\"train\")\n        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"train\")\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"test\")\n\n        super().__init__(train_x=train_x, train_u=train_u, test=test)\n\n    def _read_data(self, input_domains, split=\"train\"):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            func = \"load_\" + dname\n            domain_dir = osp.join(self.dataset_dir, dname)\n            items_d = eval(func)(domain_dir, split=split)\n\n            for impath, label in items_d:\n                item = Datum(\n                    impath=impath,\n                    label=label,\n                    domain=domain,\n                    classname=str(label)\n                )\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py",
    "content": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass DomainNet(DatasetBase):\n    \"\"\"DomainNet.\n\n    Statistics:\n        - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,\n        Real, Sketch.\n        - Around 0.6M images.\n        - 345 categories.\n        - URL: http://ai.bu.edu/M3SDA/.\n\n    Special note: the t-shirt class (327) is missing in painting_train.txt.\n\n    Reference:\n        - Peng et al. Moment Matching for Multi-Source Domain\n        Adaptation. ICCV 2019.\n    \"\"\"\n\n    dataset_dir = \"domainnet\"\n    domains = [\n        \"clipart\", \"infograph\", \"painting\", \"quickdraw\", \"real\", \"sketch\"\n    ]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.split_dir = osp.join(self.dataset_dir, \"splits\")\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split=\"train\")\n        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"train\")\n        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split=\"test\")\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"test\")\n\n        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)\n\n    def _read_data(self, input_domains, split=\"train\"):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            filename = dname + \"_\" + split + \".txt\"\n            split_file = osp.join(self.split_dir, filename)\n\n            with open(split_file, \"r\") as f:\n                lines = f.readlines()\n                for line in lines:\n                    line = line.strip()\n                    impath, label = line.split(\" \")\n               
     classname = impath.split(\"/\")[1]\n                    impath = osp.join(self.dataset_dir, impath)\n                    label = int(label)\n                    item = Datum(\n                        impath=impath,\n                        label=label,\n                        domain=domain,\n                        classname=classname\n                    )\n                    items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/mini_domainnet.py",
    "content": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass miniDomainNet(DatasetBase):\n    \"\"\"A subset of DomainNet.\n\n    Reference:\n        - Peng et al. Moment Matching for Multi-Source Domain\n        Adaptation. ICCV 2019.\n        - Zhou et al. Domain Adaptive Ensemble Learning.\n    \"\"\"\n\n    dataset_dir = \"domainnet\"\n    domains = [\"clipart\", \"painting\", \"real\", \"sketch\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.split_dir = osp.join(self.dataset_dir, \"splits_mini\")\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split=\"train\")\n        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"train\")\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"test\")\n\n        super().__init__(train_x=train_x, train_u=train_u, test=test)\n\n    def _read_data(self, input_domains, split=\"train\"):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            filename = dname + \"_\" + split + \".txt\"\n            split_file = osp.join(self.split_dir, filename)\n\n            with open(split_file, \"r\") as f:\n                lines = f.readlines()\n                for line in lines:\n                    line = line.strip()\n                    impath, label = line.split(\" \")\n                    classname = impath.split(\"/\")[1]\n                    impath = osp.join(self.dataset_dir, impath)\n                    label = int(label)\n                    item = Datum(\n                        impath=impath,\n                        label=label,\n                        domain=domain,\n                        
classname=classname\n                    )\n                    items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py",
    "content": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass Office31(DatasetBase):\n    \"\"\"Office-31.\n\n    Statistics:\n        - 4,110 images.\n        - 31 classes related to office objects.\n        - 3 domains: Amazon, Webcam, Dslr.\n        - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.\n\n    Reference:\n        - Saenko et al. Adapting visual category models to\n        new domains. ECCV 2010.\n    \"\"\"\n\n    dataset_dir = \"office31\"\n    domains = [\"amazon\", \"webcam\", \"dslr\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)\n        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)\n\n        super().__init__(train_x=train_x, train_u=train_u, test=test)\n\n    def _read_data(self, input_domains):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            domain_dir = osp.join(self.dataset_dir, dname)\n            class_names = listdir_nohidden(domain_dir)\n            class_names.sort()\n\n            for label, class_name in enumerate(class_names):\n                class_path = osp.join(domain_dir, class_name)\n                imnames = listdir_nohidden(class_path)\n\n                for imname in imnames:\n                    impath = osp.join(class_path, imname)\n                    item = Datum(\n                        impath=impath,\n                        label=label,\n                        domain=domain,\n                        classname=class_name\n                    )\n                   
 items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py",
    "content": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass OfficeHome(DatasetBase):\n    \"\"\"Office-Home.\n\n    Statistics:\n        - Around 15,500 images.\n        - 65 classes related to office and home objects.\n        - 4 domains: Art, Clipart, Product, Real World.\n        - URL: http://hemanthdv.org/OfficeHome-Dataset/.\n\n    Reference:\n        - Venkateswara et al. Deep Hashing Network for Unsupervised\n        Domain Adaptation. CVPR 2017.\n    \"\"\"\n\n    dataset_dir = \"office_home\"\n    domains = [\"art\", \"clipart\", \"product\", \"real_world\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)\n        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)\n\n        super().__init__(train_x=train_x, train_u=train_u, test=test)\n\n    def _read_data(self, input_domains):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            domain_dir = osp.join(self.dataset_dir, dname)\n            class_names = listdir_nohidden(domain_dir)\n            class_names.sort()\n\n            for label, class_name in enumerate(class_names):\n                class_path = osp.join(domain_dir, class_name)\n                imnames = listdir_nohidden(class_path)\n\n                for imname in imnames:\n                    impath = osp.join(class_path, imname)\n                    item = Datum(\n                        impath=impath,\n                        label=label,\n                        domain=domain,\n                        
classname=class_name.lower(),\n                    )\n                    items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py",
    "content": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass VisDA17(DatasetBase):\n    \"\"\"VisDA17.\n\n    Focusing on simulation-to-reality domain shift.\n\n    URL: http://ai.bu.edu/visda-2017/.\n\n    Reference:\n        - Peng et al. VisDA: The Visual Domain Adaptation\n        Challenge. ArXiv 2017.\n    \"\"\"\n\n    dataset_dir = \"visda17\"\n    domains = [\"synthetic\", \"real\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train_x = self._read_data(\"synthetic\")\n        train_u = self._read_data(\"real\")\n        test = self._read_data(\"real\")\n\n        super().__init__(train_x=train_x, train_u=train_u, test=test)\n\n    def _read_data(self, dname):\n        filedir = \"train\" if dname == \"synthetic\" else \"validation\"\n        image_list = osp.join(self.dataset_dir, filedir, \"image_list.txt\")\n        items = []\n        # There is only one source domain\n        domain = 0\n\n        with open(image_list, \"r\") as f:\n            lines = f.readlines()\n\n            for line in lines:\n                line = line.strip()\n                impath, label = line.split(\" \")\n                classname = impath.split(\"/\")[0]\n                impath = osp.join(self.dataset_dir, filedir, impath)\n                label = int(label)\n                item = Datum(\n                    impath=impath,\n                    label=label,\n                    domain=domain,\n                    classname=classname\n                )\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py",
    "content": "from .pacs import PACS\nfrom .vlcs import VLCS\nfrom .cifar_c import CIFAR10C, CIFAR100C\nfrom .digits_dg import DigitsDG\nfrom .digit_single import DigitSingle\nfrom .office_home_dg import OfficeHomeDG\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py",
    "content": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\nAVAI_C_TYPES = [\n    \"brightness\",\n    \"contrast\",\n    \"defocus_blur\",\n    \"elastic_transform\",\n    \"fog\",\n    \"frost\",\n    \"gaussian_blur\",\n    \"gaussian_noise\",\n    \"glass_blur\",\n    \"impulse_noise\",\n    \"jpeg_compression\",\n    \"motion_blur\",\n    \"pixelate\",\n    \"saturate\",\n    \"shot_noise\",\n    \"snow\",\n    \"spatter\",\n    \"speckle_noise\",\n    \"zoom_blur\",\n]\n\n\n@DATASET_REGISTRY.register()\nclass CIFAR10C(DatasetBase):\n    \"\"\"CIFAR-10 -> CIFAR-10-C.\n\n    Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o\n\n    Statistics:\n        - 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10\n        - 10 categories\n\n    Reference:\n        - Hendrycks et al. Benchmarking neural network robustness\n        to common corruptions and perturbations. 
ICLR 2019.\n    \"\"\"\n\n    dataset_dir = \"\"\n    domains = [\"cifar10\", \"cifar10_c\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = root\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n        source_domain = cfg.DATASET.SOURCE_DOMAINS[0]\n        target_domain = cfg.DATASET.TARGET_DOMAINS[0]\n        assert source_domain == self.domains[0]\n        assert target_domain == self.domains[1]\n\n        c_type = cfg.DATASET.CIFAR_C_TYPE\n        c_level = cfg.DATASET.CIFAR_C_LEVEL\n\n        if not c_type:\n            raise ValueError(\n                \"Please specify DATASET.CIFAR_C_TYPE in the config file\"\n            )\n\n        assert (\n            c_type in AVAI_C_TYPES\n        ), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got \"{c_type}\"'\n        assert 1 <= c_level <= 5\n\n        train_dir = osp.join(self.dataset_dir, source_domain, \"train\")\n        test_dir = osp.join(\n            self.dataset_dir, target_domain, c_type, str(c_level)\n        )\n\n        if not osp.exists(test_dir):\n            raise ValueError(\n                f\"Test directory not found: {test_dir}\"\n            )\n\n        train = self._read_data(train_dir)\n        test = self._read_data(test_dir)\n\n        super().__init__(train_x=train, test=test)\n\n    def _read_data(self, data_dir):\n        class_names = listdir_nohidden(data_dir)\n        class_names.sort()\n        items = []\n\n        for label, class_name in enumerate(class_names):\n            class_dir = osp.join(data_dir, class_name)\n            imnames = listdir_nohidden(class_dir)\n\n            for imname in imnames:\n                impath = osp.join(class_dir, imname)\n                item = Datum(impath=impath, label=label, domain=0)\n                items.append(item)\n\n        return items\n\n\n@DATASET_REGISTRY.register()\nclass CIFAR100C(CIFAR10C):\n    \"\"\"CIFAR-100 -> CIFAR-100-C.\n\n    Dataset link: 
https://zenodo.org/record/3555552#.YFxpQmQzb0o\n\n    Statistics:\n        - 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100\n        - 100 categories\n\n    Reference:\n        - Hendrycks et al. Benchmarking neural network robustness\n        to common corruptions and perturbations. ICLR 2019.\n    \"\"\"\n\n    dataset_dir = \"\"\n    domains = [\"cifar100\", \"cifar100_c\"]\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py",
    "content": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n# Folder names for train and test sets\nMNIST = {\"train\": \"train_images\", \"test\": \"test_images\"}\nMNIST_M = {\"train\": \"train_images\", \"test\": \"test_images\"}\nSVHN = {\"train\": \"train_images\", \"test\": \"test_images\"}\nSYN = {\"train\": \"train_images\", \"test\": \"test_images\"}\nUSPS = {\"train\": \"train_images\", \"test\": \"test_images\"}\n\n\ndef read_image_list(im_dir, n_max=None, n_repeat=None):\n    items = []\n\n    for imname in listdir_nohidden(im_dir):\n        imname_noext = osp.splitext(imname)[0]\n        label = int(imname_noext.split(\"_\")[1])\n        impath = osp.join(im_dir, imname)\n        items.append((impath, label))\n\n    if n_max is not None:\n        # Note that the sampling process is NOT random,\n        # which follows that in Volpi et al. NIPS'18.\n        items = items[:n_max]\n\n    if n_repeat is not None:\n        items *= n_repeat\n\n    return items\n\n\ndef load_mnist(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, MNIST[split])\n    n_max = 10000 if split == \"train\" else None\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_mnist_m(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, MNIST_M[split])\n    n_max = 10000 if split == \"train\" else None\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_svhn(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, SVHN[split])\n    n_max = 10000 if split == \"train\" else None\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_syn(dataset_dir, split=\"train\"):\n    data_dir = osp.join(dataset_dir, SYN[split])\n    n_max = 10000 if split == \"train\" else None\n    return read_image_list(data_dir, n_max=n_max)\n\n\ndef load_usps(dataset_dir, split=\"train\"):\n    data_dir = 
osp.join(dataset_dir, USPS[split])\n    return read_image_list(data_dir)\n\n\n@DATASET_REGISTRY.register()\nclass DigitSingle(DatasetBase):\n    \"\"\"Digit recognition datasets for single-source domain generalization.\n\n    There are five digit datasets:\n        - MNIST: hand-written digits.\n        - MNIST-M: variant of MNIST with blended background.\n        - SVHN: street view house number.\n        - SYN: synthetic digits.\n        - USPS: hand-written digits, slightly different from MNIST.\n\n    Protocol:\n        Volpi et al. train a model using 10,000 images from MNIST and\n        evaluate it on the test splits of the other four datasets. However,\n        the code does not restrict the source to MNIST: any of the five\n        datasets can be used as the source. Note that at most 10,000 images\n        are taken from the source dataset for training.\n\n    Reference:\n        - Lecun et al. Gradient-based learning applied to document\n        recognition. IEEE 1998.\n        - Ganin et al. Domain-adversarial training of neural networks.\n        JMLR 2016.\n        - Netzer et al. Reading digits in natural images with unsupervised\n        feature learning. NIPS-W 2011.\n        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data\n        Augmentation. 
NIPS 2018.\n    \"\"\"\n\n    # Reuse the digit-5 folder instead of creating a new folder\n    dataset_dir = \"digit5\"\n    domains = [\"mnist\", \"mnist_m\", \"svhn\", \"syn\", \"usps\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split=\"train\")\n        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split=\"test\")\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split=\"test\")\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def _read_data(self, input_domains, split=\"train\"):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            func = \"load_\" + dname\n            domain_dir = osp.join(self.dataset_dir, dname)\n            items_d = eval(func)(domain_dir, split=split)\n\n            for impath, label in items_d:\n                item = Datum(impath=impath, label=label, domain=domain)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py",
    "content": "import glob\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass DigitsDG(DatasetBase):\n    \"\"\"Digits-DG.\n\n    It contains 4 digit datasets:\n        - MNIST: hand-written digits.\n        - MNIST-M: variant of MNIST with blended background.\n        - SVHN: street view house number.\n        - SYN: synthetic digits.\n\n    Reference:\n        - Lecun et al. Gradient-based learning applied to document\n        recognition. IEEE 1998.\n        - Ganin et al. Domain-adversarial training of neural networks.\n        JMLR 2016.\n        - Netzer et al. Reading digits in natural images with unsupervised\n        feature learning. NIPS-W 2011.\n        - Zhou et al. Deep Domain-Adversarial Image Generation for Domain\n        Generalisation. AAAI 2020.\n    \"\"\"\n\n    dataset_dir = \"digits_dg\"\n    domains = [\"mnist\", \"mnist_m\", \"svhn\", \"syn\"]\n    data_url = \"https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7\"\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        if not osp.exists(self.dataset_dir):\n            dst = osp.join(root, \"digits_dg.zip\")\n            self.download_data(self.data_url, dst, from_gdrive=True)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train = self.read_data(\n            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, \"train\"\n        )\n        val = self.read_data(\n            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, \"val\"\n        )\n        test = self.read_data(\n            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, \"all\"\n        )\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    @staticmethod\n    def 
read_data(dataset_dir, input_domains, split):\n\n        def _load_data_from_directory(directory):\n            folders = listdir_nohidden(directory)\n            folders.sort()\n            items_ = []\n\n            for label, folder in enumerate(folders):\n                impaths = glob.glob(osp.join(directory, folder, \"*.jpg\"))\n\n                for impath in impaths:\n                    items_.append((impath, label))\n\n            return items_\n\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            if split == \"all\":\n                train_dir = osp.join(dataset_dir, dname, \"train\")\n                impath_label_list = _load_data_from_directory(train_dir)\n                val_dir = osp.join(dataset_dir, dname, \"val\")\n                impath_label_list += _load_data_from_directory(val_dir)\n            else:\n                split_dir = osp.join(dataset_dir, dname, split)\n                impath_label_list = _load_data_from_directory(split_dir)\n\n            for impath, label in impath_label_list:\n                class_name = impath.split(\"/\")[-2].lower()\n                item = Datum(\n                    impath=impath,\n                    label=label,\n                    domain=domain,\n                    classname=class_name\n                )\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/office_home_dg.py",
    "content": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom .digits_dg import DigitsDG\nfrom ..base_dataset import DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass OfficeHomeDG(DatasetBase):\n    \"\"\"Office-Home.\n\n    Statistics:\n        - Around 15,500 images.\n        - 65 classes related to office and home objects.\n        - 4 domains: Art, Clipart, Product, Real World.\n        - URL: http://hemanthdv.org/OfficeHome-Dataset/.\n\n    Reference:\n        - Venkateswara et al. Deep Hashing Network for Unsupervised\n        Domain Adaptation. CVPR 2017.\n    \"\"\"\n\n    dataset_dir = \"office_home_dg\"\n    domains = [\"art\", \"clipart\", \"product\", \"real_world\"]\n    data_url = \"https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa\"\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        if not osp.exists(self.dataset_dir):\n            dst = osp.join(root, \"office_home_dg.zip\")\n            self.download_data(self.data_url, dst, from_gdrive=True)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train = DigitsDG.read_data(\n            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, \"train\"\n        )\n        val = DigitsDG.read_data(\n            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, \"val\"\n        )\n        test = DigitsDG.read_data(\n            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, \"all\"\n        )\n\n        super().__init__(train_x=train, val=val, test=test)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py",
    "content": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass PACS(DatasetBase):\n    \"\"\"PACS.\n\n    Statistics:\n        - 4 domains: Photo (1,670), Art (2,048), Cartoon\n        (2,344), Sketch (3,929).\n        - 7 categories: dog, elephant, giraffe, guitar, horse,\n        house and person.\n\n    Reference:\n        - Li et al. Deeper, broader and artier domain generalization.\n        ICCV 2017.\n    \"\"\"\n\n    dataset_dir = \"pacs\"\n    domains = [\"art_painting\", \"cartoon\", \"photo\", \"sketch\"]\n    data_url = \"https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE\"\n    # the following images contain errors and should be ignored\n    _error_paths = [\"sketch/dog/n02103406_4068-1.png\"]\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.image_dir = osp.join(self.dataset_dir, \"images\")\n        self.split_dir = osp.join(self.dataset_dir, \"splits\")\n\n        if not osp.exists(self.dataset_dir):\n            dst = osp.join(root, \"pacs.zip\")\n            self.download_data(self.data_url, dst, from_gdrive=True)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, \"train\")\n        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, \"crossval\")\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, \"all\")\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def _read_data(self, input_domains, split):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            if split == \"all\":\n                file_train = osp.join(\n                    self.split_dir, dname + \"_train_kfold.txt\"\n                )\n                
impath_label_list = self._read_split_pacs(file_train)\n                file_val = osp.join(\n                    self.split_dir, dname + \"_crossval_kfold.txt\"\n                )\n                impath_label_list += self._read_split_pacs(file_val)\n            else:\n                file = osp.join(\n                    self.split_dir, dname + \"_\" + split + \"_kfold.txt\"\n                )\n                impath_label_list = self._read_split_pacs(file)\n\n            for impath, label in impath_label_list:\n                classname = impath.split(\"/\")[-2]\n                item = Datum(\n                    impath=impath,\n                    label=label,\n                    domain=domain,\n                    classname=classname\n                )\n                items.append(item)\n\n        return items\n\n    def _read_split_pacs(self, split_file):\n        items = []\n\n        with open(split_file, \"r\") as f:\n            lines = f.readlines()\n\n            for line in lines:\n                line = line.strip()\n                impath, label = line.split(\" \")\n                if impath in self._error_paths:\n                    continue\n                impath = osp.join(self.image_dir, impath)\n                label = int(label) - 1\n                items.append((impath, label))\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py",
    "content": "import glob\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass VLCS(DatasetBase):\n    \"\"\"VLCS.\n\n    Statistics:\n        - 4 domains: CALTECH, LABELME, PASCAL, SUN\n        - 5 categories: bird, car, chair, dog, and person.\n\n    Reference:\n        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.\n    \"\"\"\n\n    dataset_dir = \"VLCS\"\n    domains = [\"caltech\", \"labelme\", \"pascal\", \"sun\"]\n    data_url = \"https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd\"\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n\n        if not osp.exists(self.dataset_dir):\n            dst = osp.join(root, \"vlcs.zip\")\n            self.download_data(self.data_url, dst, from_gdrive=True)\n\n        self.check_input_domains(\n            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS\n        )\n\n        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, \"train\")\n        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, \"crossval\")\n        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, \"test\")\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def _read_data(self, input_domains, split):\n        items = []\n\n        for domain, dname in enumerate(input_domains):\n            dname = dname.upper()\n            path = osp.join(self.dataset_dir, dname, split)\n            folders = listdir_nohidden(path)\n            folders.sort()\n\n            for label, folder in enumerate(folders):\n                impaths = glob.glob(osp.join(path, folder, \"*.jpg\"))\n\n                for impath in impaths:\n                    item = Datum(impath=impath, label=label, domain=domain)\n                    items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/__init__.py",
    "content": "from .svhn import SVHN\nfrom .cifar import CIFAR10, CIFAR100\nfrom .stl10 import STL10\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py",
    "content": "import math\nimport random\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass CIFAR10(DatasetBase):\n    \"\"\"CIFAR10 for SSL.\n\n    Reference:\n        - Krizhevsky. Learning Multiple Layers of Features\n        from Tiny Images. Tech report.\n    \"\"\"\n\n    dataset_dir = \"cifar10\"\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        train_dir = osp.join(self.dataset_dir, \"train\")\n        test_dir = osp.join(self.dataset_dir, \"test\")\n\n        assert cfg.DATASET.NUM_LABELED > 0\n\n        train_x, train_u, val = self._read_data_train(\n            train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT\n        )\n        test = self._read_data_test(test_dir)\n\n        if cfg.DATASET.ALL_AS_UNLABELED:\n            train_u = train_u + train_x\n\n        if len(val) == 0:\n            val = None\n\n        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)\n\n    def _read_data_train(self, data_dir, num_labeled, val_percent):\n        class_names = listdir_nohidden(data_dir)\n        class_names.sort()\n        num_labeled_per_class = num_labeled / len(class_names)\n        items_x, items_u, items_v = [], [], []\n\n        for label, class_name in enumerate(class_names):\n            class_dir = osp.join(data_dir, class_name)\n            imnames = listdir_nohidden(class_dir)\n\n            # Split into train and val following Oliver et al. 
2018\n            # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data\n            num_val = math.floor(len(imnames) * val_percent)\n            imnames_train = imnames[num_val:]\n            imnames_val = imnames[:num_val]\n\n            # Note we do shuffle after split\n            random.shuffle(imnames_train)\n\n            for i, imname in enumerate(imnames_train):\n                impath = osp.join(class_dir, imname)\n                item = Datum(impath=impath, label=label)\n\n                if (i + 1) <= num_labeled_per_class:\n                    items_x.append(item)\n\n                else:\n                    items_u.append(item)\n\n            for imname in imnames_val:\n                impath = osp.join(class_dir, imname)\n                item = Datum(impath=impath, label=label)\n                items_v.append(item)\n\n        return items_x, items_u, items_v\n\n    def _read_data_test(self, data_dir):\n        class_names = listdir_nohidden(data_dir)\n        class_names.sort()\n        items = []\n\n        for label, class_name in enumerate(class_names):\n            class_dir = osp.join(data_dir, class_name)\n            imnames = listdir_nohidden(class_dir)\n\n            for imname in imnames:\n                impath = osp.join(class_dir, imname)\n                item = Datum(impath=impath, label=label)\n                items.append(item)\n\n        return items\n\n\n@DATASET_REGISTRY.register()\nclass CIFAR100(CIFAR10):\n    \"\"\"CIFAR100 for SSL.\n\n    Reference:\n        - Krizhevsky. Learning Multiple Layers of Features\n        from Tiny Images. Tech report.\n    \"\"\"\n\n    dataset_dir = \"cifar100\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py",
    "content": "import numpy as np\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REGISTRY.register()\nclass STL10(DatasetBase):\n    \"\"\"STL-10 dataset.\n\n    Description:\n    - 10 classes: airplane, bird, car, cat, deer, dog, horse,\n    monkey, ship, truck.\n    - Images are 96x96 pixels, color.\n    - 500 training images per class, 800 test images per class.\n    - 100,000 unlabeled images for unsupervised learning.\n\n    Reference:\n        - Coates et al. An Analysis of Single Layer Networks in\n        Unsupervised Feature Learning. AISTATS 2011.\n    \"\"\"\n\n    dataset_dir = \"stl10\"\n\n    def __init__(self, cfg):\n        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        train_dir = osp.join(self.dataset_dir, \"train\")\n        test_dir = osp.join(self.dataset_dir, \"test\")\n        unlabeled_dir = osp.join(self.dataset_dir, \"unlabeled\")\n        fold_file = osp.join(\n            self.dataset_dir, \"stl10_binary\", \"fold_indices.txt\"\n        )\n\n        # Only use the first five splits\n        assert 0 <= cfg.DATASET.STL10_FOLD <= 4\n\n        train_x = self._read_data_train(\n            train_dir, cfg.DATASET.STL10_FOLD, fold_file\n        )\n        train_u = self._read_data_all(unlabeled_dir)\n        test = self._read_data_all(test_dir)\n\n        if cfg.DATASET.ALL_AS_UNLABELED:\n            train_u = train_u + train_x\n\n        super().__init__(train_x=train_x, train_u=train_u, test=test)\n\n    def _read_data_train(self, data_dir, fold, fold_file):\n        imnames = listdir_nohidden(data_dir)\n        imnames.sort()\n        items = []\n\n        list_idx = list(range(len(imnames)))\n        if fold >= 0:\n            with open(fold_file, \"r\") as f:\n                str_idx = f.read().splitlines()[fold]\n                list_idx = 
np.fromstring(str_idx, dtype=np.int64, sep=\" \")\n\n        for i in list_idx:\n            imname = imnames[i]\n            impath = osp.join(data_dir, imname)\n            label = osp.splitext(imname)[0].split(\"_\")[1]\n            label = int(label)\n            item = Datum(impath=impath, label=label)\n            items.append(item)\n\n        return items\n\n    def _read_data_all(self, data_dir):\n        imnames = listdir_nohidden(data_dir)\n        items = []\n\n        for imname in imnames:\n            impath = osp.join(data_dir, imname)\n            label = osp.splitext(imname)[0].split(\"_\")[1]\n            if label == \"none\":\n                label = -1\n            else:\n                label = int(label)\n            item = Datum(impath=impath, label=label)\n            items.append(item)\n\n        return items\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py",
    "content": "from .cifar import CIFAR10\nfrom ..build import DATASET_REGISTRY\n\n\n@DATASET_REGISTRY.register()\nclass SVHN(CIFAR10):\n    \"\"\"SVHN for SSL.\n\n    Reference:\n        - Netzer et al. Reading Digits in Natural Images with\n        Unsupervised Feature Learning. NIPS-W 2011.\n    \"\"\"\n\n    dataset_dir = \"svhn\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/samplers.py",
    "content": "import copy\nimport numpy as np\nimport random\nfrom collections import defaultdict\nfrom torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler\n\n\nclass RandomDomainSampler(Sampler):\n    \"\"\"Randomly samples N domains each with K images\n    to form a minibatch of size N*K.\n\n    Args:\n        data_source (list): list of Datums.\n        batch_size (int): batch size.\n        n_domain (int): number of domains to sample in a minibatch.\n    \"\"\"\n\n    def __init__(self, data_source, batch_size, n_domain):\n        self.data_source = data_source\n\n        # Keep track of image indices for each domain\n        self.domain_dict = defaultdict(list)\n        for i, item in enumerate(data_source):\n            self.domain_dict[item.domain].append(i)\n        self.domains = list(self.domain_dict.keys())\n\n        # Make sure each domain has equal number of images\n        if n_domain is None or n_domain <= 0:\n            n_domain = len(self.domains)\n        assert batch_size % n_domain == 0\n        self.n_img_per_domain = batch_size // n_domain\n\n        self.batch_size = batch_size\n        # n_domain denotes number of domains sampled in a minibatch\n        self.n_domain = n_domain\n        self.length = len(list(self.__iter__()))\n\n    def __iter__(self):\n        domain_dict = copy.deepcopy(self.domain_dict)\n        final_idxs = []\n        stop_sampling = False\n\n        while not stop_sampling:\n            selected_domains = random.sample(self.domains, self.n_domain)\n\n            for domain in selected_domains:\n                idxs = domain_dict[domain]\n                selected_idxs = random.sample(idxs, self.n_img_per_domain)\n                final_idxs.extend(selected_idxs)\n\n                for idx in selected_idxs:\n                    domain_dict[domain].remove(idx)\n\n                remaining = len(domain_dict[domain])\n                if remaining < self.n_img_per_domain:\n                    
stop_sampling = True\n\n        return iter(final_idxs)\n\n    def __len__(self):\n        return self.length\n\n\nclass SeqDomainSampler(Sampler):\n    \"\"\"Sequential domain sampler, which randomly samples K\n    images from each domain to form a minibatch.\n\n    Args:\n        data_source (list): list of Datums.\n        batch_size (int): batch size.\n    \"\"\"\n\n    def __init__(self, data_source, batch_size):\n        self.data_source = data_source\n\n        # Keep track of image indices for each domain\n        self.domain_dict = defaultdict(list)\n        for i, item in enumerate(data_source):\n            self.domain_dict[item.domain].append(i)\n        self.domains = list(self.domain_dict.keys())\n        self.domains.sort()\n\n        # Make sure each domain has equal number of images\n        n_domain = len(self.domains)\n        assert batch_size % n_domain == 0\n        self.n_img_per_domain = batch_size // n_domain\n\n        self.batch_size = batch_size\n        # n_domain denotes number of domains sampled in a minibatch\n        self.n_domain = n_domain\n        self.length = len(list(self.__iter__()))\n\n    def __iter__(self):\n        domain_dict = copy.deepcopy(self.domain_dict)\n        final_idxs = []\n        stop_sampling = False\n\n        while not stop_sampling:\n            for domain in self.domains:\n                idxs = domain_dict[domain]\n                selected_idxs = random.sample(idxs, self.n_img_per_domain)\n                final_idxs.extend(selected_idxs)\n\n                for idx in selected_idxs:\n                    domain_dict[domain].remove(idx)\n\n                remaining = len(domain_dict[domain])\n                if remaining < self.n_img_per_domain:\n                    stop_sampling = True\n\n        return iter(final_idxs)\n\n    def __len__(self):\n        return self.length\n\n\nclass RandomClassSampler(Sampler):\n    \"\"\"Randomly samples N classes each with K instances to\n    form a minibatch of size 
N*K.\n\n    Modified from https://github.com/KaiyangZhou/deep-person-reid.\n\n    Args:\n        data_source (list): list of Datums.\n        batch_size (int): batch size.\n        n_ins (int): number of instances per class to sample in a minibatch.\n    \"\"\"\n\n    def __init__(self, data_source, batch_size, n_ins):\n        if batch_size < n_ins:\n            raise ValueError(\n                \"batch_size={} must be no less \"\n                \"than n_ins={}\".format(batch_size, n_ins)\n            )\n\n        self.data_source = data_source\n        self.batch_size = batch_size\n        self.n_ins = n_ins\n        self.ncls_per_batch = self.batch_size // self.n_ins\n        self.index_dic = defaultdict(list)\n        for index, item in enumerate(data_source):\n            self.index_dic[item.label].append(index)\n        self.labels = list(self.index_dic.keys())\n        assert len(self.labels) >= self.ncls_per_batch\n\n        # estimate number of images in an epoch\n        self.length = len(list(self.__iter__()))\n\n    def __iter__(self):\n        batch_idxs_dict = defaultdict(list)\n\n        for label in self.labels:\n            idxs = copy.deepcopy(self.index_dic[label])\n            if len(idxs) < self.n_ins:\n                idxs = np.random.choice(idxs, size=self.n_ins, replace=True)\n            random.shuffle(idxs)\n            batch_idxs = []\n            for idx in idxs:\n                batch_idxs.append(idx)\n                if len(batch_idxs) == self.n_ins:\n                    batch_idxs_dict[label].append(batch_idxs)\n                    batch_idxs = []\n\n        avai_labels = copy.deepcopy(self.labels)\n        final_idxs = []\n\n        while len(avai_labels) >= self.ncls_per_batch:\n            selected_labels = random.sample(avai_labels, self.ncls_per_batch)\n            for label in selected_labels:\n                batch_idxs = batch_idxs_dict[label].pop(0)\n                final_idxs.extend(batch_idxs)\n                if 
len(batch_idxs_dict[label]) == 0:\n                    avai_labels.remove(label)\n\n        return iter(final_idxs)\n\n    def __len__(self):\n        return self.length\n\n\ndef build_sampler(\n    sampler_type,\n    cfg=None,\n    data_source=None,\n    batch_size=32,\n    n_domain=0,\n    n_ins=16\n):\n    if sampler_type == \"RandomSampler\":\n        return RandomSampler(data_source)\n\n    elif sampler_type == \"SequentialSampler\":\n        return SequentialSampler(data_source)\n\n    elif sampler_type == \"RandomDomainSampler\":\n        return RandomDomainSampler(data_source, batch_size, n_domain)\n\n    elif sampler_type == \"SeqDomainSampler\":\n        return SeqDomainSampler(data_source, batch_size)\n\n    elif sampler_type == \"RandomClassSampler\":\n        return RandomClassSampler(data_source, batch_size, n_ins)\n\n    else:\n        raise ValueError(\"Unknown sampler type: {}\".format(sampler_type))\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py",
    "content": "from .transforms import build_transform\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py",
    "content": "\"\"\"\nSource: https://github.com/DeepVoltaire/AutoAugment\n\"\"\"\nimport numpy as np\nimport random\nfrom PIL import Image, ImageOps, ImageEnhance\n\n\nclass ImageNetPolicy:\n    \"\"\"Randomly choose one of the best 25 Sub-policies on ImageNet.\n\n    Example:\n        >>> policy = ImageNetPolicy()\n        >>> transformed = policy(image)\n\n    Example as a PyTorch Transform:\n        >>> transform=transforms.Compose([\n        >>>     transforms.Resize(256),\n        >>>     ImageNetPolicy(),\n        >>>     transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.4, \"posterize\", 8, 0.6, \"rotate\", 9, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.6, \"posterize\", 7, 0.6, \"posterize\", 6, fillcolor),\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n            SubPolicy(0.4, \"equalize\", 4, 0.8, \"rotate\", 8, fillcolor),\n            SubPolicy(0.6, \"solarize\", 3, 0.6, \"equalize\", 7, fillcolor),\n            SubPolicy(0.8, \"posterize\", 5, 1.0, \"equalize\", 2, fillcolor),\n            SubPolicy(0.2, \"rotate\", 3, 0.6, \"solarize\", 8, fillcolor),\n            SubPolicy(0.6, \"equalize\", 8, 0.4, \"posterize\", 6, fillcolor),\n            SubPolicy(0.8, \"rotate\", 8, 0.4, \"color\", 0, fillcolor),\n            SubPolicy(0.4, \"rotate\", 9, 0.6, \"equalize\", 2, fillcolor),\n            SubPolicy(0.0, \"equalize\", 7, 0.8, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n            SubPolicy(0.8, \"rotate\", 8, 1.0, \"color\", 2, fillcolor),\n            SubPolicy(0.8, \"color\", 8, 0.8, \"solarize\", 7, fillcolor),\n            
SubPolicy(0.4, \"sharpness\", 7, 0.6, \"invert\", 8, fillcolor),\n            SubPolicy(0.6, \"shearX\", 5, 1.0, \"equalize\", 9, fillcolor),\n            SubPolicy(0.4, \"color\", 0, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.4, \"equalize\", 7, 0.2, \"solarize\", 4, fillcolor),\n            SubPolicy(0.6, \"solarize\", 5, 0.6, \"autocontrast\", 5, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 1.0, \"equalize\", 8, fillcolor),\n            SubPolicy(0.6, \"color\", 4, 1.0, \"contrast\", 8, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.6, \"equalize\", 3, fillcolor),\n        ]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment ImageNet Policy\"\n\n\nclass CIFAR10Policy:\n    \"\"\"Randomly choose one of the best 25 Sub-policies on CIFAR10.\n\n    Example:\n        >>> policy = CIFAR10Policy()\n        >>> transformed = policy(image)\n\n    Example as a PyTorch Transform:\n        >>> transform=transforms.Compose([\n        >>>     transforms.Resize(256),\n        >>>     CIFAR10Policy(),\n        >>>     transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 128, 128)):\n        self.policies = [\n            SubPolicy(0.1, \"invert\", 7, 0.2, \"contrast\", 6, fillcolor),\n            SubPolicy(0.7, \"rotate\", 2, 0.3, \"translateX\", 9, fillcolor),\n            SubPolicy(0.8, \"sharpness\", 1, 0.9, \"sharpness\", 3, fillcolor),\n            SubPolicy(0.5, \"shearY\", 8, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.5, \"autocontrast\", 8, 0.9, \"equalize\", 2, fillcolor),\n            SubPolicy(0.2, \"shearY\", 7, 0.3, \"posterize\", 7, fillcolor),\n            SubPolicy(0.4, \"color\", 3, 0.6, \"brightness\", 7, fillcolor),\n            SubPolicy(0.3, \"sharpness\", 9, 0.7, \"brightness\", 9, fillcolor),\n            SubPolicy(0.6, 
\"equalize\", 5, 0.5, \"equalize\", 1, fillcolor),\n            SubPolicy(0.6, \"contrast\", 7, 0.6, \"sharpness\", 5, fillcolor),\n            SubPolicy(0.7, \"color\", 7, 0.5, \"translateX\", 8, fillcolor),\n            SubPolicy(0.3, \"equalize\", 7, 0.4, \"autocontrast\", 8, fillcolor),\n            SubPolicy(0.4, \"translateY\", 3, 0.2, \"sharpness\", 6, fillcolor),\n            SubPolicy(0.9, \"brightness\", 6, 0.2, \"color\", 8, fillcolor),\n            SubPolicy(0.5, \"solarize\", 2, 0.0, \"invert\", 3, fillcolor),\n            SubPolicy(0.2, \"equalize\", 0, 0.6, \"autocontrast\", 0, fillcolor),\n            SubPolicy(0.2, \"equalize\", 8, 0.6, \"equalize\", 4, fillcolor),\n            SubPolicy(0.9, \"color\", 9, 0.6, \"equalize\", 6, fillcolor),\n            SubPolicy(0.8, \"autocontrast\", 4, 0.2, \"solarize\", 8, fillcolor),\n            SubPolicy(0.1, \"brightness\", 3, 0.7, \"color\", 0, fillcolor),\n            SubPolicy(0.4, \"solarize\", 5, 0.9, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.9, \"translateY\", 9, 0.7, \"translateY\", 9, fillcolor),\n            SubPolicy(0.9, \"autocontrast\", 2, 0.8, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"equalize\", 8, 0.1, \"invert\", 3, fillcolor),\n            SubPolicy(0.7, \"translateY\", 9, 0.9, \"autocontrast\", 1, fillcolor),\n        ]\n\n    def __call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment CIFAR10 Policy\"\n\n\nclass SVHNPolicy:\n    \"\"\"Randomly choose one of the best 25 Sub-policies on SVHN.\n\n    Example:\n        >>> policy = SVHNPolicy()\n        >>> transformed = policy(image)\n\n    Example as a PyTorch Transform:\n        >>> transform=transforms.Compose([\n        >>>     transforms.Resize(256),\n        >>>     SVHNPolicy(),\n        >>>     transforms.ToTensor()])\n    \"\"\"\n\n    def __init__(self, fillcolor=(128, 
128, 128)):\n        self.policies = [\n            SubPolicy(0.9, \"shearX\", 4, 0.2, \"invert\", 3, fillcolor),\n            SubPolicy(0.9, \"shearY\", 8, 0.7, \"invert\", 5, fillcolor),\n            SubPolicy(0.6, \"equalize\", 5, 0.6, \"solarize\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 3, 0.6, \"equalize\", 3, fillcolor),\n            SubPolicy(0.6, \"equalize\", 1, 0.9, \"rotate\", 3, fillcolor),\n            SubPolicy(0.9, \"shearX\", 4, 0.8, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.9, \"shearY\", 8, 0.4, \"invert\", 5, fillcolor),\n            SubPolicy(0.9, \"shearY\", 5, 0.2, \"solarize\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 6, 0.8, \"autocontrast\", 1, fillcolor),\n            SubPolicy(0.6, \"equalize\", 3, 0.9, \"rotate\", 3, fillcolor),\n            SubPolicy(0.9, \"shearX\", 4, 0.3, \"solarize\", 3, fillcolor),\n            SubPolicy(0.8, \"shearY\", 8, 0.7, \"invert\", 4, fillcolor),\n            SubPolicy(0.9, \"equalize\", 5, 0.6, \"translateY\", 6, fillcolor),\n            SubPolicy(0.9, \"invert\", 4, 0.6, \"equalize\", 7, fillcolor),\n            SubPolicy(0.3, \"contrast\", 3, 0.8, \"rotate\", 4, fillcolor),\n            SubPolicy(0.8, \"invert\", 5, 0.0, \"translateY\", 2, fillcolor),\n            SubPolicy(0.7, \"shearY\", 6, 0.4, \"solarize\", 8, fillcolor),\n            SubPolicy(0.6, \"invert\", 4, 0.8, \"rotate\", 4, fillcolor),\n            SubPolicy(0.3, \"shearY\", 7, 0.9, \"translateX\", 3, fillcolor),\n            SubPolicy(0.1, \"shearX\", 6, 0.6, \"invert\", 5, fillcolor),\n            SubPolicy(0.7, \"solarize\", 2, 0.6, \"translateY\", 7, fillcolor),\n            SubPolicy(0.8, \"shearY\", 4, 0.8, \"invert\", 8, fillcolor),\n            SubPolicy(0.7, \"shearX\", 9, 0.8, \"translateY\", 3, fillcolor),\n            SubPolicy(0.8, \"shearY\", 5, 0.7, \"autocontrast\", 3, fillcolor),\n            SubPolicy(0.7, \"shearX\", 2, 0.1, \"invert\", 5, fillcolor),\n        ]\n\n    def 
__call__(self, img):\n        policy_idx = random.randint(0, len(self.policies) - 1)\n        return self.policies[policy_idx](img)\n\n    def __repr__(self):\n        return \"AutoAugment SVHN Policy\"\n\n\nclass SubPolicy(object):\n\n    def __init__(\n        self,\n        p1,\n        operation1,\n        magnitude_idx1,\n        p2,\n        operation2,\n        magnitude_idx2,\n        fillcolor=(128, 128, 128),\n    ):\n        # A magnitude index (0-9) selects one value from each range below\n        ranges = {\n            \"shearX\": np.linspace(0, 0.3, 10),\n            \"shearY\": np.linspace(0, 0.3, 10),\n            \"translateX\": np.linspace(0, 150 / 331, 10),\n            \"translateY\": np.linspace(0, 150 / 331, 10),\n            \"rotate\": np.linspace(0, 30, 10),\n            \"color\": np.linspace(0.0, 0.9, 10),\n            \"posterize\": np.round(np.linspace(8, 4, 10), 0).astype(int),\n            \"solarize\": np.linspace(256, 0, 10),\n            \"contrast\": np.linspace(0.0, 0.9, 10),\n            \"sharpness\": np.linspace(0.0, 0.9, 10),\n            \"brightness\": np.linspace(0.0, 0.9, 10),\n            \"autocontrast\": [0] * 10,\n            \"equalize\": [0] * 10,\n            \"invert\": [0] * 10,\n        }\n\n        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand\n        def rotate_with_fill(img, magnitude):\n            rot = img.convert(\"RGBA\").rotate(magnitude)\n            return Image.composite(\n                rot, Image.new(\"RGBA\", rot.size, (128, ) * 4), rot\n            ).convert(img.mode)\n\n        func = {\n            \"shearX\":\n            lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),\n                Image.BICUBIC,\n                fillcolor=fillcolor,\n            ),\n            \"shearY\":\n            lambda img, magnitude: img.transform(\n                img.size,\n       
         Image.AFFINE,\n                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),\n                Image.BICUBIC,\n                fillcolor=fillcolor,\n            ),\n            \"translateX\":\n            lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (\n                    1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0,\n                    1, 0\n                ),\n                fillcolor=fillcolor,\n            ),\n            \"translateY\":\n            lambda img, magnitude: img.transform(\n                img.size,\n                Image.AFFINE,\n                (\n                    1, 0, 0, 0, 1, magnitude * img.size[1] * random.\n                    choice([-1, 1])\n                ),\n                fillcolor=fillcolor,\n            ),\n            \"rotate\":\n            lambda img, magnitude: rotate_with_fill(img, magnitude),\n            \"color\":\n            lambda img, magnitude: ImageEnhance.Color(img).\n            enhance(1 + magnitude * random.choice([-1, 1])),\n            \"posterize\":\n            lambda img, magnitude: ImageOps.posterize(img, magnitude),\n            \"solarize\":\n            lambda img, magnitude: ImageOps.solarize(img, magnitude),\n            \"contrast\":\n            lambda img, magnitude: ImageEnhance.Contrast(img).\n            enhance(1 + magnitude * random.choice([-1, 1])),\n            \"sharpness\":\n            lambda img, magnitude: ImageEnhance.Sharpness(img).\n            enhance(1 + magnitude * random.choice([-1, 1])),\n            \"brightness\":\n            lambda img, magnitude: ImageEnhance.Brightness(img).\n            enhance(1 + magnitude * random.choice([-1, 1])),\n            \"autocontrast\":\n            lambda img, magnitude: ImageOps.autocontrast(img),\n            \"equalize\":\n            lambda img, magnitude: ImageOps.equalize(img),\n            \"invert\":\n            lambda img, 
magnitude: ImageOps.invert(img),\n        }\n\n        self.p1 = p1\n        self.operation1 = func[operation1]\n        self.magnitude1 = ranges[operation1][magnitude_idx1]\n        self.p2 = p2\n        self.operation2 = func[operation2]\n        self.magnitude2 = ranges[operation2][magnitude_idx2]\n\n    def __call__(self, img):\n        if random.random() < self.p1:\n            img = self.operation1(img, self.magnitude1)\n        if random.random() < self.p2:\n            img = self.operation2(img, self.magnitude2)\n        return img\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py",
    "content": "\"\"\"\nCredit to\n1) https://github.com/ildoonet/pytorch-randaugment\n2) https://github.com/kakaobrain/fast-autoaugment\n\"\"\"\nimport numpy as np\nimport random\nimport PIL\nimport torch\nimport PIL.ImageOps\nimport PIL.ImageDraw\nimport PIL.ImageEnhance\nfrom PIL import Image\n\n\ndef ShearX(img, v):\n    assert -0.3 <= v <= 0.3\n    if random.random() > 0.5:\n        v = -v\n    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))\n\n\ndef ShearY(img, v):\n    assert -0.3 <= v <= 0.3\n    if random.random() > 0.5:\n        v = -v\n    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))\n\n\ndef TranslateX(img, v):\n    # [-150, 150] => percentage: [-0.45, 0.45]\n    assert -0.45 <= v <= 0.45\n    if random.random() > 0.5:\n        v = -v\n    v = v * img.size[0]\n    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))\n\n\ndef TranslateXabs(img, v):\n    # [-150, 150] => percentage: [-0.45, 0.45]\n    assert 0 <= v\n    if random.random() > 0.5:\n        v = -v\n    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))\n\n\ndef TranslateY(img, v):\n    # [-150, 150] => percentage: [-0.45, 0.45]\n    assert -0.45 <= v <= 0.45\n    if random.random() > 0.5:\n        v = -v\n    v = v * img.size[1]\n    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))\n\n\ndef TranslateYabs(img, v):\n    # [-150, 150] => percentage: [-0.45, 0.45]\n    assert 0 <= v\n    if random.random() > 0.5:\n        v = -v\n    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))\n\n\ndef Rotate(img, v):\n    assert -30 <= v <= 30\n    if random.random() > 0.5:\n        v = -v\n    return img.rotate(v)\n\n\ndef AutoContrast(img, _):\n    return PIL.ImageOps.autocontrast(img)\n\n\ndef Invert(img, _):\n    return PIL.ImageOps.invert(img)\n\n\ndef Equalize(img, _):\n    return PIL.ImageOps.equalize(img)\n\n\ndef Flip(img, _):\n    return PIL.ImageOps.mirror(img)\n\n\ndef 
Solarize(img, v):\n    assert 0 <= v <= 256\n    return PIL.ImageOps.solarize(img, v)\n\n\ndef SolarizeAdd(img, addition=0, threshold=128):\n    img_np = np.array(img).astype(int)\n    img_np = img_np + addition\n    img_np = np.clip(img_np, 0, 255)\n    img_np = img_np.astype(np.uint8)\n    img = Image.fromarray(img_np)\n    return PIL.ImageOps.solarize(img, threshold)\n\n\ndef Posterize(img, v):\n    assert 4 <= v <= 8\n    v = int(v)\n    return PIL.ImageOps.posterize(img, v)\n\n\ndef Contrast(img, v):\n    assert 0.0 <= v <= 2.0\n    return PIL.ImageEnhance.Contrast(img).enhance(v)\n\n\ndef Color(img, v):\n    assert 0.0 <= v <= 2.0\n    return PIL.ImageEnhance.Color(img).enhance(v)\n\n\ndef Brightness(img, v):\n    assert 0.0 <= v <= 2.0\n    return PIL.ImageEnhance.Brightness(img).enhance(v)\n\n\ndef Sharpness(img, v):\n    assert 0.0 <= v <= 2.0\n    return PIL.ImageEnhance.Sharpness(img).enhance(v)\n\n\ndef Cutout(img, v):\n    # [0, 60] => percentage: [0, 0.2]\n    assert 0.0 <= v <= 0.2\n    if v <= 0.0:\n        return img\n\n    v = v * img.size[0]\n    return CutoutAbs(img, v)\n\n\ndef CutoutAbs(img, v):\n    # [0, 60] => percentage: [0, 0.2]\n    # assert 0 <= v <= 20\n    if v < 0:\n        return img\n    w, h = img.size\n    # Sample the patch center uniformly over the image\n    x0 = np.random.uniform(0, w)\n    y0 = np.random.uniform(0, h)\n\n    x0 = int(max(0, x0 - v/2.0))\n    y0 = int(max(0, y0 - v/2.0))\n    x1 = min(w, x0 + v)\n    y1 = min(h, y0 + v)\n\n    xy = (x0, y0, x1, y1)\n    color = (125, 123, 114)\n    # color = (0, 0, 0)\n    img = img.copy()\n    PIL.ImageDraw.Draw(img).rectangle(xy, color)\n    return img\n\n\ndef SamplePairing(imgs):\n    # [0, 0.4]\n    def f(img1, v):\n        i = np.random.choice(len(imgs))\n        img2 = PIL.Image.fromarray(imgs[i])\n        return PIL.Image.blend(img1, img2, v)\n\n    return f\n\n\ndef Identity(img, v):\n    return img\n\n\nclass Lighting:\n    \"\"\"Lighting noise (AlexNet-style PCA-based noise).\"\"\"\n\n    def __init__(self, alphastd, 
eigval, eigvec):\n        self.alphastd = alphastd\n        self.eigval = torch.Tensor(eigval)\n        self.eigvec = torch.Tensor(eigvec)\n\n    def __call__(self, img):\n        if self.alphastd == 0:\n            return img\n\n        alpha = img.new().resize_(3).normal_(0, self.alphastd)\n        rgb = (\n            self.eigvec.type_as(img).clone().mul(\n                alpha.view(1, 3).expand(3, 3)\n            ).mul(self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()\n        )\n\n        return img.add(rgb.view(3, 1, 1).expand_as(img))\n\n\nclass CutoutDefault:\n    \"\"\"\n    Reference: https://github.com/quark0/darts/blob/master/cnn/utils.py\n    \"\"\"\n\n    def __init__(self, length):\n        self.length = length\n\n    def __call__(self, img):\n        h, w = img.size(1), img.size(2)\n        mask = np.ones((h, w), np.float32)\n        y = np.random.randint(h)\n        x = np.random.randint(w)\n\n        y1 = np.clip(y - self.length // 2, 0, h)\n        y2 = np.clip(y + self.length // 2, 0, h)\n        x1 = np.clip(x - self.length // 2, 0, w)\n        x2 = np.clip(x + self.length // 2, 0, w)\n\n        mask[y1:y2, x1:x2] = 0.0\n        mask = torch.from_numpy(mask)\n        mask = mask.expand_as(img)\n        img *= mask\n        return img\n\n\ndef randaugment_list():\n    # 16 operations and their ranges\n    # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57\n    # augs = [\n    #     (Identity, 0., 1.0),\n    #     (ShearX, 0., 0.3),  # 0\n    #     (ShearY, 0., 0.3),  # 1\n    #     (TranslateX, 0., 0.33),  # 2\n    #     (TranslateY, 0., 0.33),  # 3\n    #     (Rotate, 0, 30),  # 4\n    #     (AutoContrast, 0, 1),  # 5\n    #     (Invert, 0, 1),  # 6\n    #     (Equalize, 0, 1),  # 7\n    #     (Solarize, 0, 110),  # 8\n    #     (Posterize, 4, 8),  # 9\n    #     # (Contrast, 0.1, 1.9),  # 10\n    #     (Color, 0.1, 1.9),  # 11\n    #     (Brightness, 0.1, 1.9),  # 12\n    #     (Sharpness, 0.1, 1.9),  # 
13\n    #     # (Cutout, 0, 0.2),  # 14\n    #     # (SamplePairing(imgs), 0, 0.4)  # 15\n    # ]\n\n    # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505\n    augs = [\n        (AutoContrast, 0, 1),\n        (Equalize, 0, 1),\n        (Invert, 0, 1),\n        (Rotate, 0, 30),\n        (Posterize, 4, 8),\n        (Solarize, 0, 256),\n        (SolarizeAdd, 0, 110),\n        (Color, 0.1, 1.9),\n        (Contrast, 0.1, 1.9),\n        (Brightness, 0.1, 1.9),\n        (Sharpness, 0.1, 1.9),\n        (ShearX, 0.0, 0.3),\n        (ShearY, 0.0, 0.3),\n        (CutoutAbs, 0, 40),\n        (TranslateXabs, 0.0, 100),\n        (TranslateYabs, 0.0, 100),\n    ]\n\n    return augs\n\n\ndef randaugment_list2():\n    augs = [\n        (AutoContrast, 0, 1),\n        (Brightness, 0.1, 1.9),\n        (Color, 0.1, 1.9),\n        (Contrast, 0.1, 1.9),\n        (Equalize, 0, 1),\n        (Identity, 0, 1),\n        (Invert, 0, 1),\n        (Posterize, 4, 8),\n        (Rotate, -30, 30),\n        (Sharpness, 0.1, 1.9),\n        (ShearX, -0.3, 0.3),\n        (ShearY, -0.3, 0.3),\n        (Solarize, 0, 256),\n        (TranslateX, -0.3, 0.3),\n        (TranslateY, -0.3, 0.3),\n    ]\n\n    return augs\n\n\ndef fixmatch_list():\n    # https://arxiv.org/abs/2001.07685\n    augs = [\n        (AutoContrast, 0, 1),\n        (Brightness, 0.05, 0.95),\n        (Color, 0.05, 0.95),\n        (Contrast, 0.05, 0.95),\n        (Equalize, 0, 1),\n        (Identity, 0, 1),\n        (Posterize, 4, 8),\n        (Rotate, -30, 30),\n        (Sharpness, 0.05, 0.95),\n        (ShearX, -0.3, 0.3),\n        (ShearY, -0.3, 0.3),\n        (Solarize, 0, 256),\n        (TranslateX, -0.3, 0.3),\n        (TranslateY, -0.3, 0.3),\n    ]\n\n    return augs\n\n\nclass RandAugment:\n\n    def __init__(self, n=2, m=10):\n        assert 0 <= m <= 30\n        self.n = n\n        self.m = m\n        self.augment_list = randaugment_list()\n\n    def 
__call__(self, img):\n        # Sample n ops with replacement\n        ops = random.choices(self.augment_list, k=self.n)\n\n        for op, minval, maxval in ops:\n            # Map the global magnitude m (0-30) linearly onto [minval, maxval]\n            val = (self.m / 30) * (maxval-minval) + minval\n            img = op(img, val)\n\n        return img\n\n\nclass RandAugment2:\n\n    def __init__(self, n=2, p=0.6):\n        self.n = n\n        self.p = p\n        self.augment_list = randaugment_list2()\n\n    def __call__(self, img):\n        ops = random.choices(self.augment_list, k=self.n)\n\n        for op, minval, maxval in ops:\n            # Skip this op with probability 1 - p\n            if random.random() > self.p:\n                continue\n            # Sample a magnitude uniformly within the op's range\n            m = random.random()\n            val = m * (maxval-minval) + minval\n            img = op(img, val)\n\n        return img\n\n\nclass RandAugmentFixMatch:\n\n    def __init__(self, n=2):\n        self.n = n\n        self.augment_list = fixmatch_list()\n\n    def __call__(self, img):\n        ops = random.choices(self.augment_list, k=self.n)\n\n        for op, minval, maxval in ops:\n            # Sample a magnitude uniformly within the op's range\n            m = random.random()\n            val = m * (maxval-minval) + minval\n            img = op(img, val)\n\n        return img\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py",
    "content": "import numpy as np\nimport random\nimport torch\nfrom PIL import Image\nfrom torchvision.transforms import (\n    Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,\n    RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,\n    RandomHorizontalFlip\n)\n\nfrom .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy\nfrom .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch\n\nAVAI_CHOICES = [\n    \"random_flip\",\n    \"random_resized_crop\",\n    \"normalize\",\n    \"instance_norm\",\n    \"random_crop\",\n    \"random_translation\",\n    \"center_crop\",  # This has become a default operation for test\n    \"cutout\",\n    \"imagenet_policy\",\n    \"cifar10_policy\",\n    \"svhn_policy\",\n    \"randaugment\",\n    \"randaugment_fixmatch\",\n    \"randaugment2\",\n    \"gaussian_noise\",\n    \"colorjitter\",\n    \"randomgrayscale\",\n    \"gaussian_blur\",\n]\n\nINTERPOLATION_MODES = {\n    \"bilinear\": Image.BILINEAR,\n    \"bicubic\": Image.BICUBIC,\n    \"nearest\": Image.NEAREST,\n}\n\n\nclass Random2DTranslation:\n    \"\"\"Given an image of (height, width), we resize it to\n    (height*1.125, width*1.125), and then perform random cropping.\n\n    Args:\n        height (int): target image height.\n        width (int): target image width.\n        p (float, optional): probability that this operation takes place.\n            Default is 0.5.\n        interpolation (int, optional): desired interpolation. 
Default is\n            ``PIL.Image.BILINEAR``\n    \"\"\"\n\n    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):\n        self.height = height\n        self.width = width\n        self.p = p\n        self.interpolation = interpolation\n\n    def __call__(self, img):\n        if random.uniform(0, 1) > self.p:\n            return img.resize((self.width, self.height), self.interpolation)\n\n        new_width = int(round(self.width * 1.125))\n        new_height = int(round(self.height * 1.125))\n        resized_img = img.resize((new_width, new_height), self.interpolation)\n\n        x_maxrange = new_width - self.width\n        y_maxrange = new_height - self.height\n        x1 = int(round(random.uniform(0, x_maxrange)))\n        y1 = int(round(random.uniform(0, y_maxrange)))\n        cropped_img = resized_img.crop(\n            (x1, y1, x1 + self.width, y1 + self.height)\n        )\n\n        return cropped_img\n\n\nclass InstanceNormalization:\n    \"\"\"Normalize data using per-channel mean and standard deviation.\n\n    Reference:\n        - Ulyanov et al. Instance Normalization: The Missing Ingredient\n          for Fast Stylization. ArXiv 2016.\n        - Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.\n          ICLR 2018.\n    \"\"\"\n\n    def __init__(self, eps=1e-8):\n        self.eps = eps\n\n    def __call__(self, img):\n        C, H, W = img.shape\n        img_re = img.reshape(C, H * W)\n        mean = img_re.mean(1).view(C, 1, 1)\n        std = img_re.std(1).view(C, 1, 1)\n        return (img-mean) / (std + self.eps)\n\n\nclass Cutout:\n    \"\"\"Randomly mask out one or more patches from an image.\n\n    https://github.com/uoguelph-mlrg/Cutout\n\n    Args:\n        n_holes (int, optional): number of patches to cut out\n            of each image. Default is 1.\n        length (int, optional): length (in pixels) of each square\n            patch. 
Default is 16.\n    \"\"\"\n\n    def __init__(self, n_holes=1, length=16):\n        self.n_holes = n_holes\n        self.length = length\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (Tensor): tensor image of size (C, H, W).\n\n        Returns:\n            Tensor: image with n_holes of dimension\n                length x length cut out of it.\n        \"\"\"\n        h = img.size(1)\n        w = img.size(2)\n\n        mask = np.ones((h, w), np.float32)\n\n        for n in range(self.n_holes):\n            y = np.random.randint(h)\n            x = np.random.randint(w)\n\n            y1 = np.clip(y - self.length // 2, 0, h)\n            y2 = np.clip(y + self.length // 2, 0, h)\n            x1 = np.clip(x - self.length // 2, 0, w)\n            x2 = np.clip(x + self.length // 2, 0, w)\n\n            mask[y1:y2, x1:x2] = 0.0\n\n        mask = torch.from_numpy(mask)\n        mask = mask.expand_as(img)\n        return img * mask\n\n\nclass GaussianNoise:\n    \"\"\"Add gaussian noise.\"\"\"\n\n    def __init__(self, mean=0, std=0.15, p=0.5):\n        self.mean = mean\n        self.std = std\n        self.p = p\n\n    def __call__(self, img):\n        if random.uniform(0, 1) > self.p:\n            return img\n        noise = torch.randn(img.size()) * self.std + self.mean\n        return img + noise\n\n\ndef build_transform(cfg, is_train=True, choices=None):\n    \"\"\"Build transformation function.\n\n    Args:\n        cfg (CfgNode): config.\n        is_train (bool, optional): for training (True) or test (False).\n            Default is True.\n        choices (list, optional): list of strings which will overwrite\n            cfg.INPUT.TRANSFORMS if given. 
Default is None.\n    \"\"\"\n    if cfg.INPUT.NO_TRANSFORM:\n        print(\"Note: no transform is applied!\")\n        return None\n\n    if choices is None:\n        choices = cfg.INPUT.TRANSFORMS\n\n    for choice in choices:\n        assert choice in AVAI_CHOICES\n\n    target_size = f\"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}\"\n\n    normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)\n\n    if is_train:\n        return _build_transform_train(cfg, choices, target_size, normalize)\n    else:\n        return _build_transform_test(cfg, choices, target_size, normalize)\n\n\ndef _build_transform_train(cfg, choices, target_size, normalize):\n    print(\"Building transform_train\")\n    tfm_train = []\n\n    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]\n\n    # Make sure the image size matches the target size\n    conditions = []\n    conditions += [\"random_crop\" not in choices]\n    conditions += [\"random_resized_crop\" not in choices]\n    if all(conditions):\n        print(f\"+ resize to {target_size}\")\n        tfm_train += [Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]\n\n    if \"random_translation\" in choices:\n        print(\"+ random translation\")\n        tfm_train += [\n            Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1])\n        ]\n\n    if \"random_crop\" in choices:\n        crop_padding = cfg.INPUT.CROP_PADDING\n        print(\"+ random crop (padding = {})\".format(crop_padding))\n        tfm_train += [RandomCrop(cfg.INPUT.SIZE, padding=crop_padding)]\n\n    if \"random_resized_crop\" in choices:\n        print(f\"+ random resized crop (size={cfg.INPUT.SIZE})\")\n        tfm_train += [\n            RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode)\n        ]\n\n    if \"center_crop\" in choices:\n        print(f\"+ center crop (size={cfg.INPUT.SIZE})\")\n        tfm_train += [CenterCrop(cfg.INPUT.SIZE)]\n\n    if \"random_flip\" in choices:\n        print(\"+ random flip\")\n  
      tfm_train += [RandomHorizontalFlip()]\n\n    if \"imagenet_policy\" in choices:\n        print(\"+ imagenet policy\")\n        tfm_train += [ImageNetPolicy()]\n\n    if \"cifar10_policy\" in choices:\n        print(\"+ cifar10 policy\")\n        tfm_train += [CIFAR10Policy()]\n\n    if \"svhn_policy\" in choices:\n        print(\"+ svhn policy\")\n        tfm_train += [SVHNPolicy()]\n\n    if \"randaugment\" in choices:\n        n_ = cfg.INPUT.RANDAUGMENT_N\n        m_ = cfg.INPUT.RANDAUGMENT_M\n        print(\"+ randaugment (n={}, m={})\".format(n_, m_))\n        tfm_train += [RandAugment(n_, m_)]\n\n    if \"randaugment_fixmatch\" in choices:\n        n_ = cfg.INPUT.RANDAUGMENT_N\n        print(\"+ randaugment_fixmatch (n={})\".format(n_))\n        tfm_train += [RandAugmentFixMatch(n_)]\n\n    if \"randaugment2\" in choices:\n        n_ = cfg.INPUT.RANDAUGMENT_N\n        print(\"+ randaugment2 (n={})\".format(n_))\n        tfm_train += [RandAugment2(n_)]\n\n    if \"colorjitter\" in choices:\n        print(\"+ color jitter\")\n        tfm_train += [\n            ColorJitter(\n                brightness=cfg.INPUT.COLORJITTER_B,\n                contrast=cfg.INPUT.COLORJITTER_C,\n                saturation=cfg.INPUT.COLORJITTER_S,\n                hue=cfg.INPUT.COLORJITTER_H,\n            )\n        ]\n\n    if \"randomgrayscale\" in choices:\n        print(\"+ random gray scale\")\n        tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]\n\n    if \"gaussian_blur\" in choices:\n        print(f\"+ gaussian blur (kernel={cfg.INPUT.GB_K})\")\n        tfm_train += [\n            RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P)\n        ]\n\n    print(\"+ to torch tensor of range [0, 1]\")\n    tfm_train += [ToTensor()]\n\n    if \"cutout\" in choices:\n        cutout_n = cfg.INPUT.CUTOUT_N\n        cutout_len = cfg.INPUT.CUTOUT_LEN\n        print(\"+ cutout (n_holes={}, length={})\".format(cutout_n, cutout_len))\n        tfm_train += 
[Cutout(cutout_n, cutout_len)]\n\n    if \"normalize\" in choices:\n        print(\n            \"+ normalization (mean={}, \"\n            \"std={})\".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)\n        )\n        tfm_train += [normalize]\n\n    if \"gaussian_noise\" in choices:\n        print(\n            \"+ gaussian noise (mean={}, std={})\".format(\n                cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD\n            )\n        )\n        tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]\n\n    if \"instance_norm\" in choices:\n        print(\"+ instance normalization\")\n        tfm_train += [InstanceNormalization()]\n\n    tfm_train = Compose(tfm_train)\n\n    return tfm_train\n\n\ndef _build_transform_test(cfg, choices, target_size, normalize):\n    print(\"Building transform_test\")\n    tfm_test = []\n\n    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]\n\n    print(f\"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}\")\n    tfm_test += [Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode)]\n\n    print(f\"+ {target_size} center crop\")\n    tfm_test += [CenterCrop(cfg.INPUT.SIZE)]\n\n    print(\"+ to torch tensor of range [0, 1]\")\n    tfm_test += [ToTensor()]\n\n    if \"normalize\" in choices:\n        print(\n            \"+ normalization (mean={}, \"\n            \"std={})\".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)\n        )\n        tfm_test += [normalize]\n\n    if \"instance_norm\" in choices:\n        print(\"+ instance normalization\")\n        tfm_test += [InstanceNormalization()]\n\n    tfm_test = Compose(tfm_test)\n\n    return tfm_test\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/__init__.py",
    "content": "from .build import TRAINER_REGISTRY, build_trainer  # isort:skip\nfrom .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet  # isort:skip\n\nfrom .da import *\nfrom .dg import *\nfrom .ssl import *\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/build.py",
    "content": "from dassl.utils import Registry, check_availability\n\nTRAINER_REGISTRY = Registry(\"TRAINER\")\n\n\ndef build_trainer(cfg):\n    avai_trainers = TRAINER_REGISTRY.registered_names()\n    check_availability(cfg.TRAINER.NAME, avai_trainers)\n    if cfg.VERBOSE:\n        print(\"Loading trainer: {}\".format(cfg.TRAINER.NAME))\n    return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py",
    "content": "from .mcd import MCD\nfrom .mme import MME\nfrom .adda import ADDA\nfrom .dael import DAEL\nfrom .dann import DANN\nfrom .adabn import AdaBN\nfrom .m3sda import M3SDA\nfrom .source_only import SourceOnly\nfrom .self_ensembling import SelfEnsembling\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py",
    "content": "import torch\n\nfrom dassl.utils import check_isfile\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\n\n\n@TRAINER_REGISTRY.register()\nclass AdaBN(TrainerXU):\n    \"\"\"Adaptive Batch Normalization.\n\n    https://arxiv.org/abs/1603.04779.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.done_reset_bn_stats = False\n\n    def check_cfg(self, cfg):\n        assert check_isfile(\n            cfg.MODEL.INIT_WEIGHTS\n        ), \"The weights of source model must be provided\"\n\n    def before_epoch(self):\n        if not self.done_reset_bn_stats:\n            for m in self.model.modules():\n                classname = m.__class__.__name__\n                if classname.find(\"BatchNorm\") != -1:\n                    m.reset_running_stats()\n\n            self.done_reset_bn_stats = True\n\n    def forward_backward(self, batch_x, batch_u):\n        input_u = batch_u[\"img\"].to(self.device)\n\n        with torch.no_grad():\n            self.model(input_u)\n\n        return None\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/adda.py",
    "content": "import copy\nimport torch\nimport torch.nn as nn\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import check_isfile, count_num_param, open_specified_layers\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.modeling import build_head\n\n\n@TRAINER_REGISTRY.register()\nclass ADDA(TrainerXU):\n    \"\"\"Adversarial Discriminative Domain Adaptation.\n\n    https://arxiv.org/abs/1702.05464.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.open_layers = [\"backbone\"]\n        if isinstance(self.model.head, nn.Module):\n            self.open_layers.append(\"head\")\n\n        self.source_model = copy.deepcopy(self.model)\n        self.source_model.eval()\n        for param in self.source_model.parameters():\n            param.requires_grad_(False)\n\n        self.build_critic()\n\n        self.bce = nn.BCEWithLogitsLoss()\n\n    def check_cfg(self, cfg):\n        assert check_isfile(\n            cfg.MODEL.INIT_WEIGHTS\n        ), \"The weights of source model must be provided\"\n\n    def build_critic(self):\n        cfg = self.cfg\n\n        print(\"Building critic network\")\n        fdim = self.model.fdim\n        critic_body = build_head(\n            \"mlp\",\n            verbose=cfg.VERBOSE,\n            in_features=fdim,\n            hidden_layers=[fdim, fdim // 2],\n            activation=\"leaky_relu\",\n        )\n        self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))\n        print(\"# params: {:,}\".format(count_num_param(self.critic)))\n        self.critic.to(self.device)\n        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)\n        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)\n        self.register_model(\"critic\", self.critic, self.optim_c, self.sched_c)\n\n    def forward_backward(self, batch_x, batch_u):\n        open_specified_layers(self.model, self.open_layers)\n        input_x, _, input_u = 
self.parse_batch_train(batch_x, batch_u)\n        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)\n        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)\n\n        _, feat_x = self.source_model(input_x, return_feature=True)\n        _, feat_u = self.model(input_u, return_feature=True)\n\n        logit_xd = self.critic(feat_x)\n        logit_ud = self.critic(feat_u.detach())\n\n        loss_critic = self.bce(logit_xd, domain_x)\n        loss_critic += self.bce(logit_ud, domain_u)\n        self.model_backward_and_update(loss_critic, \"critic\")\n\n        logit_ud = self.critic(feat_u)\n        loss_model = self.bce(logit_ud, 1 - domain_u)\n        self.model_backward_and_update(loss_model, \"model\")\n\n        loss_summary = {\n            \"loss_critic\": loss_critic.item(),\n            \"loss_model\": loss_model.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/dael.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom dassl.data import DataManager\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\nfrom dassl.engine.trainer import SimpleNet\nfrom dassl.data.transforms import build_transform\nfrom dassl.modeling.ops.utils import create_onehot\n\n\nclass Experts(nn.Module):\n\n    def __init__(self, n_source, fdim, num_classes):\n        super().__init__()\n        self.linears = nn.ModuleList(\n            [nn.Linear(fdim, num_classes) for _ in range(n_source)]\n        )\n        self.softmax = nn.Softmax(dim=1)\n\n    def forward(self, i, x):\n        x = self.linears[i](x)\n        x = self.softmax(x)\n        return x\n\n\n@TRAINER_REGISTRY.register()\nclass DAEL(TrainerXU):\n    \"\"\"Domain Adaptive Ensemble Learning.\n\n    https://arxiv.org/abs/2003.07325.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN\n        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE\n        if n_domain <= 0:\n            n_domain = self.num_source_domains\n        self.split_batch = batch_size // n_domain\n        self.n_domain = n_domain\n\n        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U\n        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE\n\n    def check_cfg(self, cfg):\n        assert cfg.DATALOADER.TRAIN_X.SAMPLER == \"RandomDomainSampler\"\n        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X\n        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0\n\n    def build_data_loader(self):\n        cfg = self.cfg\n        tfm_train = build_transform(cfg, is_train=True)\n        custom_tfm_train = [tfm_train]\n        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS\n        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)\n        custom_tfm_train += [tfm_train_strong]\n        dm = 
DataManager(self.cfg, custom_tfm_train=custom_tfm_train)\n        self.train_loader_x = dm.train_loader_x\n        self.train_loader_u = dm.train_loader_u\n        self.val_loader = dm.val_loader\n        self.test_loader = dm.test_loader\n        self.num_classes = dm.num_classes\n        self.num_source_domains = dm.num_source_domains\n        self.lab2cname = dm.lab2cname\n\n    def build_model(self):\n        cfg = self.cfg\n\n        print(\"Building F\")\n        self.F = SimpleNet(cfg, cfg.MODEL, 0)\n        self.F.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.F)))\n        self.optim_F = build_optimizer(self.F, cfg.OPTIM)\n        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)\n        self.register_model(\"F\", self.F, self.optim_F, self.sched_F)\n        fdim = self.F.fdim\n\n        print(\"Building E\")\n        self.E = Experts(self.num_source_domains, fdim, self.num_classes)\n        self.E.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.E)))\n        self.optim_E = build_optimizer(self.E, cfg.OPTIM)\n        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)\n        self.register_model(\"E\", self.E, self.optim_E, self.sched_E)\n\n    def forward_backward(self, batch_x, batch_u):\n        parsed_data = self.parse_batch_train(batch_x, batch_u)\n        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data\n\n        input_x = torch.split(input_x, self.split_batch, 0)\n        input_x2 = torch.split(input_x2, self.split_batch, 0)\n        label_x = torch.split(label_x, self.split_batch, 0)\n        domain_x = torch.split(domain_x, self.split_batch, 0)\n        domain_x = [d[0].item() for d in domain_x]\n\n        # Generate pseudo label\n        with torch.no_grad():\n            feat_u = self.F(input_u)\n            pred_u = []\n            for k in range(self.num_source_domains):\n                pred_uk = self.E(k, feat_u)\n                pred_uk = 
pred_uk.unsqueeze(1)\n                pred_u.append(pred_uk)\n            pred_u = torch.cat(pred_u, 1)  # (B, K, C)\n            # Get the highest probability and index (label) for each expert\n            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)\n            # Get the most confident expert\n            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)\n            pseudo_label_u = []\n            for i, experts_label in zip(max_expert_idx, experts_max_idx):\n                pseudo_label_u.append(experts_label[i])\n            pseudo_label_u = torch.stack(pseudo_label_u, 0)\n            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)\n            pseudo_label_u = pseudo_label_u.to(self.device)\n            label_u_mask = (max_expert_p >= self.conf_thre).float()\n\n        loss_x = 0\n        loss_cr = 0\n        acc_x = 0\n\n        feat_x = [self.F(x) for x in input_x]\n        feat_x2 = [self.F(x) for x in input_x2]\n        feat_u2 = self.F(input_u2)\n\n        for feat_xi, feat_x2i, label_xi, i in zip(\n            feat_x, feat_x2, label_x, domain_x\n        ):\n            cr_s = [j for j in domain_x if j != i]\n\n            # Learning expert\n            pred_xi = self.E(i, feat_xi)\n            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()\n            expert_label_xi = pred_xi.detach()\n            acc_x += compute_accuracy(pred_xi.detach(),\n                                      label_xi.max(1)[1])[0].item()\n\n            # Consistency regularization\n            cr_pred = []\n            for j in cr_s:\n                pred_j = self.E(j, feat_x2i)\n                pred_j = pred_j.unsqueeze(1)\n                cr_pred.append(pred_j)\n            cr_pred = torch.cat(cr_pred, 1)\n            cr_pred = cr_pred.mean(1)\n            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()\n\n        loss_x /= self.n_domain\n        loss_cr /= self.n_domain\n        acc_x /= self.n_domain\n\n        
# Unsupervised loss\n        pred_u = []\n        for k in range(self.num_source_domains):\n            pred_uk = self.E(k, feat_u2)\n            pred_uk = pred_uk.unsqueeze(1)\n            pred_u.append(pred_uk)\n        pred_u = torch.cat(pred_u, 1)\n        pred_u = pred_u.mean(1)\n        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)\n        loss_u = (l_u * label_u_mask).mean()\n\n        loss = 0\n        loss += loss_x\n        loss += loss_cr\n        loss += loss_u * self.weight_u\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc_x\": acc_x,\n            \"loss_cr\": loss_cr.item(),\n            \"loss_u\": loss_u.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input_x = batch_x[\"img\"]\n        input_x2 = batch_x[\"img2\"]\n        label_x = batch_x[\"label\"]\n        domain_x = batch_x[\"domain\"]\n        input_u = batch_u[\"img\"]\n        input_u2 = batch_u[\"img2\"]\n\n        label_x = create_onehot(label_x, self.num_classes)\n\n        input_x = input_x.to(self.device)\n        input_x2 = input_x2.to(self.device)\n        label_x = label_x.to(self.device)\n        input_u = input_u.to(self.device)\n        input_u2 = input_u2.to(self.device)\n\n        return input_x, input_x2, label_x, domain_x, input_u, input_u2\n\n    def model_inference(self, input):\n        f = self.F(input)\n        p = []\n        for k in range(self.num_source_domains):\n            p_k = self.E(k, f)\n            p_k = p_k.unsqueeze(1)\n            p.append(p_k)\n        p = torch.cat(p, 1)\n        p = p.mean(1)\n        return p\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/dann.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\nfrom dassl.modeling import build_head\nfrom dassl.modeling.ops import ReverseGrad\n\n\n@TRAINER_REGISTRY.register()\nclass DANN(TrainerXU):\n    \"\"\"Domain-Adversarial Neural Networks.\n\n    https://arxiv.org/abs/1505.07818.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.build_critic()\n        self.ce = nn.CrossEntropyLoss()\n        self.bce = nn.BCEWithLogitsLoss()\n\n    def build_critic(self):\n        cfg = self.cfg\n\n        print(\"Building critic network\")\n        fdim = self.model.fdim\n        critic_body = build_head(\n            \"mlp\",\n            verbose=cfg.VERBOSE,\n            in_features=fdim,\n            hidden_layers=[fdim, fdim],\n            activation=\"leaky_relu\",\n        )\n        self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))\n        print(\"# params: {:,}\".format(count_num_param(self.critic)))\n        self.critic.to(self.device)\n        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)\n        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)\n        self.register_model(\"critic\", self.critic, self.optim_c, self.sched_c)\n        self.revgrad = ReverseGrad()\n\n    def forward_backward(self, batch_x, batch_u):\n        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)\n        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)\n        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)\n\n        global_step = self.batch_idx + self.epoch * self.num_batches\n        progress = global_step / (self.max_epoch * self.num_batches)\n        lmda = 2 / (1 + np.exp(-10 * progress)) - 1\n\n        logit_x, feat_x = self.model(input_x, 
return_feature=True)\n        _, feat_u = self.model(input_u, return_feature=True)\n\n        loss_x = self.ce(logit_x, label_x)\n\n        feat_x = self.revgrad(feat_x, grad_scaling=lmda)\n        feat_u = self.revgrad(feat_u, grad_scaling=lmda)\n        output_xd = self.critic(feat_x)\n        output_ud = self.critic(feat_u)\n        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)\n\n        loss = loss_x + loss_d\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc_x\": compute_accuracy(logit_x, label_x)[0].item(),\n            \"loss_d\": loss_d.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.engine.trainer import SimpleNet\n\n\nclass PairClassifiers(nn.Module):\n\n    def __init__(self, fdim, num_classes):\n        super().__init__()\n        self.c1 = nn.Linear(fdim, num_classes)\n        self.c2 = nn.Linear(fdim, num_classes)\n\n    def forward(self, x):\n        z1 = self.c1(x)\n        if not self.training:\n            return z1\n        z2 = self.c2(x)\n        return z1, z2\n\n\n@TRAINER_REGISTRY.register()\nclass M3SDA(TrainerXU):\n    \"\"\"Moment Matching for Multi-Source Domain Adaptation.\n\n    https://arxiv.org/abs/1812.01754.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN\n        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE\n        if n_domain <= 0:\n            n_domain = self.num_source_domains\n        self.split_batch = batch_size // n_domain\n        self.n_domain = n_domain\n\n        self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F\n        self.lmda = cfg.TRAINER.M3SDA.LMDA\n\n    def check_cfg(self, cfg):\n        assert cfg.DATALOADER.TRAIN_X.SAMPLER == \"RandomDomainSampler\"\n        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X\n\n    def build_model(self):\n        cfg = self.cfg\n\n        print(\"Building F\")\n        self.F = SimpleNet(cfg, cfg.MODEL, 0)\n        self.F.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.F)))\n        self.optim_F = build_optimizer(self.F, cfg.OPTIM)\n        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)\n        self.register_model(\"F\", self.F, self.optim_F, self.sched_F)\n        fdim = self.F.fdim\n\n        print(\"Building C\")\n        self.C = nn.ModuleList(\n            [\n                
PairClassifiers(fdim, self.num_classes)\n                for _ in range(self.num_source_domains)\n            ]\n        )\n        self.C.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.C)))\n        self.optim_C = build_optimizer(self.C, cfg.OPTIM)\n        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)\n        self.register_model(\"C\", self.C, self.optim_C, self.sched_C)\n\n    def forward_backward(self, batch_x, batch_u):\n        parsed = self.parse_batch_train(batch_x, batch_u)\n        input_x, label_x, domain_x, input_u = parsed\n\n        input_x = torch.split(input_x, self.split_batch, 0)\n        label_x = torch.split(label_x, self.split_batch, 0)\n        domain_x = torch.split(domain_x, self.split_batch, 0)\n        domain_x = [d[0].item() for d in domain_x]\n\n        # Step A\n        loss_x = 0\n        feat_x = []\n\n        for x, y, d in zip(input_x, label_x, domain_x):\n            f = self.F(x)\n            z1, z2 = self.C[d](f)\n            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)\n\n            feat_x.append(f)\n\n        loss_x /= self.n_domain\n\n        feat_u = self.F(input_u)\n        loss_msda = self.moment_distance(feat_x, feat_u)\n\n        loss_step_A = loss_x + loss_msda * self.lmda\n        self.model_backward_and_update(loss_step_A)\n\n        # Step B\n        with torch.no_grad():\n            feat_u = self.F(input_u)\n\n        loss_x, loss_dis = 0, 0\n\n        for x, y, d in zip(input_x, label_x, domain_x):\n            with torch.no_grad():\n                f = self.F(x)\n            z1, z2 = self.C[d](f)\n            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)\n\n            z1, z2 = self.C[d](feat_u)\n            p1 = F.softmax(z1, 1)\n            p2 = F.softmax(z2, 1)\n            loss_dis += self.discrepancy(p1, p2)\n\n        loss_x /= self.n_domain\n        loss_dis /= self.n_domain\n\n        loss_step_B = loss_x - loss_dis\n        
self.model_backward_and_update(loss_step_B, \"C\")\n\n        # Step C\n        for _ in range(self.n_step_F):\n            feat_u = self.F(input_u)\n\n            loss_dis = 0\n\n            for d in domain_x:\n                z1, z2 = self.C[d](feat_u)\n                p1 = F.softmax(z1, 1)\n                p2 = F.softmax(z2, 1)\n                loss_dis += self.discrepancy(p1, p2)\n\n            loss_dis /= self.n_domain\n            loss_step_C = loss_dis\n\n            self.model_backward_and_update(loss_step_C, \"F\")\n\n        loss_summary = {\n            \"loss_step_A\": loss_step_A.item(),\n            \"loss_step_B\": loss_step_B.item(),\n            \"loss_step_C\": loss_step_C.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def moment_distance(self, x, u):\n        # x (list): a list of feature matrices.\n        # u (torch.Tensor): feature matrix.\n        x_mean = [xi.mean(0) for xi in x]\n        u_mean = u.mean(0)\n        dist1 = self.pairwise_distance(x_mean, u_mean)\n\n        x_var = [xi.var(0) for xi in x]\n        u_var = u.var(0)\n        dist2 = self.pairwise_distance(x_var, u_var)\n\n        return (dist1+dist2) / 2\n\n    def pairwise_distance(self, x, u):\n        # x (list): a list of feature vectors.\n        # u (torch.Tensor): feature vector.\n        dist = 0\n        count = 0\n\n        for xi in x:\n            dist += self.euclidean(xi, u)\n            count += 1\n\n        for i in range(len(x) - 1):\n            for j in range(i + 1, len(x)):\n                dist += self.euclidean(x[i], x[j])\n                count += 1\n\n        return dist / count\n\n    def euclidean(self, input1, input2):\n        return ((input1 - input2)**2).sum().sqrt()\n\n    def discrepancy(self, y1, y2):\n        return (y1 - y2).abs().mean()\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input_x = batch_x[\"img\"]\n        label_x = 
batch_x[\"label\"]\n        domain_x = batch_x[\"domain\"]\n        input_u = batch_u[\"img\"]\n\n        input_x = input_x.to(self.device)\n        label_x = label_x.to(self.device)\n        input_u = input_u.to(self.device)\n\n        return input_x, label_x, domain_x, input_u\n\n    def model_inference(self, input):\n        f = self.F(input)\n        p = 0\n        for C_i in self.C:\n            z = C_i(f)\n            p += F.softmax(z, 1)\n        p = p / len(self.C)\n        return p\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.engine.trainer import SimpleNet\n\n\n@TRAINER_REGISTRY.register()\nclass MCD(TrainerXU):\n    \"\"\"Maximum Classifier Discrepancy.\n\n    https://arxiv.org/abs/1712.02560.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F\n\n    def build_model(self):\n        cfg = self.cfg\n\n        print(\"Building F\")\n        self.F = SimpleNet(cfg, cfg.MODEL, 0)\n        self.F.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.F)))\n        self.optim_F = build_optimizer(self.F, cfg.OPTIM)\n        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)\n        self.register_model(\"F\", self.F, self.optim_F, self.sched_F)\n        fdim = self.F.fdim\n\n        print(\"Building C1\")\n        self.C1 = nn.Linear(fdim, self.num_classes)\n        self.C1.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.C1)))\n        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)\n        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)\n        self.register_model(\"C1\", self.C1, self.optim_C1, self.sched_C1)\n\n        print(\"Building C2\")\n        self.C2 = nn.Linear(fdim, self.num_classes)\n        self.C2.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.C2)))\n        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)\n        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)\n        self.register_model(\"C2\", self.C2, self.optim_C2, self.sched_C2)\n\n    def forward_backward(self, batch_x, batch_u):\n        parsed = self.parse_batch_train(batch_x, batch_u)\n        input_x, label_x, input_u = parsed\n\n        # Step A\n        feat_x 
= self.F(input_x)\n        logit_x1 = self.C1(feat_x)\n        logit_x2 = self.C2(feat_x)\n        loss_x1 = F.cross_entropy(logit_x1, label_x)\n        loss_x2 = F.cross_entropy(logit_x2, label_x)\n        loss_step_A = loss_x1 + loss_x2\n        self.model_backward_and_update(loss_step_A)\n\n        # Step B\n        with torch.no_grad():\n            feat_x = self.F(input_x)\n        logit_x1 = self.C1(feat_x)\n        logit_x2 = self.C2(feat_x)\n        loss_x1 = F.cross_entropy(logit_x1, label_x)\n        loss_x2 = F.cross_entropy(logit_x2, label_x)\n        loss_x = loss_x1 + loss_x2\n\n        with torch.no_grad():\n            feat_u = self.F(input_u)\n        pred_u1 = F.softmax(self.C1(feat_u), 1)\n        pred_u2 = F.softmax(self.C2(feat_u), 1)\n        loss_dis = self.discrepancy(pred_u1, pred_u2)\n\n        loss_step_B = loss_x - loss_dis\n        self.model_backward_and_update(loss_step_B, [\"C1\", \"C2\"])\n\n        # Step C\n        for _ in range(self.n_step_F):\n            feat_u = self.F(input_u)\n            pred_u1 = F.softmax(self.C1(feat_u), 1)\n            pred_u2 = F.softmax(self.C2(feat_u), 1)\n            loss_step_C = self.discrepancy(pred_u1, pred_u2)\n            self.model_backward_and_update(loss_step_C, \"F\")\n\n        loss_summary = {\n            \"loss_step_A\": loss_step_A.item(),\n            \"loss_step_B\": loss_step_B.item(),\n            \"loss_step_C\": loss_step_C.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def discrepancy(self, y1, y2):\n        return (y1 - y2).abs().mean()\n\n    def model_inference(self, input):\n        feat = self.F(input)\n        return self.C1(feat)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/mme.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\nfrom dassl.modeling.ops import ReverseGrad\nfrom dassl.engine.trainer import SimpleNet\n\n\nclass Prototypes(nn.Module):\n\n    def __init__(self, fdim, num_classes, temp=0.05):\n        super().__init__()\n        self.prototypes = nn.Linear(fdim, num_classes, bias=False)\n        self.temp = temp\n\n    def forward(self, x):\n        x = F.normalize(x, p=2, dim=1)\n        out = self.prototypes(x)\n        out = out / self.temp\n        return out\n\n\n@TRAINER_REGISTRY.register()\nclass MME(TrainerXU):\n    \"\"\"Minimax Entropy.\n\n    https://arxiv.org/abs/1904.06487.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.lmda = cfg.TRAINER.MME.LMDA\n\n    def build_model(self):\n        cfg = self.cfg\n\n        print(\"Building F\")\n        self.F = SimpleNet(cfg, cfg.MODEL, 0)\n        self.F.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.F)))\n        self.optim_F = build_optimizer(self.F, cfg.OPTIM)\n        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)\n        self.register_model(\"F\", self.F, self.optim_F, self.sched_F)\n\n        print(\"Building C\")\n        self.C = Prototypes(self.F.fdim, self.num_classes)\n        self.C.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.C)))\n        self.optim_C = build_optimizer(self.C, cfg.OPTIM)\n        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)\n        self.register_model(\"C\", self.C, self.optim_C, self.sched_C)\n\n        self.revgrad = ReverseGrad()\n\n    def forward_backward(self, batch_x, batch_u):\n        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)\n\n        
feat_x = self.F(input_x)\n        logit_x = self.C(feat_x)\n        loss_x = F.cross_entropy(logit_x, label_x)\n        self.model_backward_and_update(loss_x)\n\n        feat_u = self.F(input_u)\n        feat_u = self.revgrad(feat_u)\n        logit_u = self.C(feat_u)\n        prob_u = F.softmax(logit_u, 1)\n        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()\n        self.model_backward_and_update(loss_u * self.lmda)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc_x\": compute_accuracy(logit_x, label_x)[0].item(),\n            \"loss_u\": loss_u.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def model_inference(self, input):\n        return self.C(self.F(input))\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py",
    "content": "import copy\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\nfrom dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update\n\n\n@TRAINER_REGISTRY.register()\nclass SelfEnsembling(TrainerXU):\n    \"\"\"Self-ensembling for visual domain adaptation.\n\n    https://arxiv.org/abs/1706.05208.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA\n        self.conf_thre = cfg.TRAINER.SE.CONF_THRE\n        self.rampup = cfg.TRAINER.SE.RAMPUP\n\n        self.teacher = copy.deepcopy(self.model)\n        self.teacher.train()\n        for param in self.teacher.parameters():\n            param.requires_grad_(False)\n\n    def check_cfg(self, cfg):\n        assert cfg.DATALOADER.K_TRANSFORMS == 2\n\n    def forward_backward(self, batch_x, batch_u):\n        global_step = self.batch_idx + self.epoch * self.num_batches\n        parsed = self.parse_batch_train(batch_x, batch_u)\n        input_x, label_x, input_u1, input_u2 = parsed\n\n        logit_x = self.model(input_x)\n        loss_x = F.cross_entropy(logit_x, label_x)\n\n        prob_u = F.softmax(self.model(input_u1), 1)\n        t_prob_u = F.softmax(self.teacher(input_u2), 1)\n        loss_u = ((prob_u - t_prob_u)**2).sum(1)\n\n        if self.conf_thre:\n            max_prob = t_prob_u.max(1)[0]\n            mask = (max_prob > self.conf_thre).float()\n            loss_u = (loss_u * mask).mean()\n        else:\n            weight_u = sigmoid_rampup(global_step, self.rampup)\n            loss_u = loss_u.mean() * weight_u\n\n        loss = loss_x + loss_u\n        self.model_backward_and_update(loss)\n\n        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)\n        ema_model_update(self.model, self.teacher, ema_alpha)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc_x\": 
compute_accuracy(logit_x, label_x)[0].item(),\n            \"loss_u\": loss_u.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input_x = batch_x[\"img\"][0]\n        label_x = batch_x[\"label\"]\n        input_u = batch_u[\"img\"]\n        input_u1, input_u2 = input_u\n\n        input_x = input_x.to(self.device)\n        label_x = label_x.to(self.device)\n        input_u1 = input_u1.to(self.device)\n        input_u2 = input_u2.to(self.device)\n\n        return input_x, label_x, input_u1, input_u2\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py",
    "content": "from torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\n\n\n@TRAINER_REGISTRY.register()\nclass SourceOnly(TrainerXU):\n    \"\"\"Baseline model for domain adaptation, which is\n    trained using source data only.\n    \"\"\"\n\n    def forward_backward(self, batch_x, batch_u):\n        input, label = self.parse_batch_train(batch_x, batch_u)\n        output = self.model(input)\n        loss = F.cross_entropy(output, label)\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss\": loss.item(),\n            \"acc\": compute_accuracy(output, label)[0].item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input = batch_x[\"img\"]\n        label = batch_x[\"label\"]\n        input = input.to(self.device)\n        label = label.to(self.device)\n        return input, label\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py",
    "content": "from .ddaig import DDAIG\nfrom .daeldg import DAELDG\nfrom .vanilla import Vanilla\nfrom .crossgrad import CrossGrad\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.engine.trainer import SimpleNet\n\n\n@TRAINER_REGISTRY.register()\nclass CrossGrad(TrainerX):\n    \"\"\"Cross-gradient training.\n\n    https://arxiv.org/abs/1804.10745.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.eps_f = cfg.TRAINER.CG.EPS_F\n        self.eps_d = cfg.TRAINER.CG.EPS_D\n        self.alpha_f = cfg.TRAINER.CG.ALPHA_F\n        self.alpha_d = cfg.TRAINER.CG.ALPHA_D\n\n    def build_model(self):\n        cfg = self.cfg\n\n        print(\"Building F\")\n        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)\n        self.F.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.F)))\n        self.optim_F = build_optimizer(self.F, cfg.OPTIM)\n        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)\n        self.register_model(\"F\", self.F, self.optim_F, self.sched_F)\n\n        print(\"Building D\")\n        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)\n        self.D.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.D)))\n        self.optim_D = build_optimizer(self.D, cfg.OPTIM)\n        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)\n        self.register_model(\"D\", self.D, self.optim_D, self.sched_D)\n\n    def forward_backward(self, batch):\n        input, label, domain = self.parse_batch_train(batch)\n\n        input.requires_grad = True\n\n        # Compute domain perturbation\n        loss_d = F.cross_entropy(self.D(input), domain)\n        loss_d.backward()\n        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)\n        input_d = input.data + self.eps_f * grad_d\n\n        # Compute label perturbation\n        input.grad.data.zero_()\n        loss_f = 
F.cross_entropy(self.F(input), label)\n        loss_f.backward()\n        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)\n        input_f = input.data + self.eps_d * grad_f\n\n        input = input.detach()\n\n        # Update label net\n        loss_f1 = F.cross_entropy(self.F(input), label)\n        loss_f2 = F.cross_entropy(self.F(input_d), label)\n        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2\n        self.model_backward_and_update(loss_f, \"F\")\n\n        # Update domain net\n        loss_d1 = F.cross_entropy(self.D(input), domain)\n        loss_d2 = F.cross_entropy(self.D(input_f), domain)\n        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2\n        self.model_backward_and_update(loss_d, \"D\")\n\n        loss_summary = {\"loss_f\": loss_f.item(), \"loss_d\": loss_d.item()}\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def model_inference(self, input):\n        return self.F(input)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom dassl.data import DataManager\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.metrics import compute_accuracy\nfrom dassl.engine.trainer import SimpleNet\nfrom dassl.data.transforms import build_transform\nfrom dassl.modeling.ops.utils import create_onehot\n\n\nclass Experts(nn.Module):\n\n    def __init__(self, n_source, fdim, num_classes):\n        super().__init__()\n        self.linears = nn.ModuleList(\n            [nn.Linear(fdim, num_classes) for _ in range(n_source)]\n        )\n        self.softmax = nn.Softmax(dim=1)\n\n    def forward(self, i, x):\n        x = self.linears[i](x)\n        x = self.softmax(x)\n        return x\n\n\n@TRAINER_REGISTRY.register()\nclass DAELDG(TrainerX):\n    \"\"\"Domain Adaptive Ensemble Learning.\n\n    DG version: only use labeled source data.\n\n    https://arxiv.org/abs/2003.07325.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN\n        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE\n        if n_domain <= 0:\n            n_domain = self.num_source_domains\n        self.split_batch = batch_size // n_domain\n        self.n_domain = n_domain\n\n        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE\n\n    def check_cfg(self, cfg):\n        assert cfg.DATALOADER.TRAIN_X.SAMPLER == \"RandomDomainSampler\"\n        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0\n\n    def build_data_loader(self):\n        cfg = self.cfg\n        tfm_train = build_transform(cfg, is_train=True)\n        custom_tfm_train = [tfm_train]\n        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS\n        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)\n        custom_tfm_train += [tfm_train_strong]\n        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)\n     
   self.train_loader_x = dm.train_loader_x\n        self.train_loader_u = dm.train_loader_u\n        self.val_loader = dm.val_loader\n        self.test_loader = dm.test_loader\n        self.num_classes = dm.num_classes\n        self.num_source_domains = dm.num_source_domains\n        self.lab2cname = dm.lab2cname\n\n    def build_model(self):\n        cfg = self.cfg\n\n        print(\"Building F\")\n        self.F = SimpleNet(cfg, cfg.MODEL, 0)\n        self.F.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.F)))\n        self.optim_F = build_optimizer(self.F, cfg.OPTIM)\n        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)\n        self.register_model(\"F\", self.F, self.optim_F, self.sched_F)\n        fdim = self.F.fdim\n\n        print(\"Building E\")\n        self.E = Experts(self.num_source_domains, fdim, self.num_classes)\n        self.E.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.E)))\n        self.optim_E = build_optimizer(self.E, cfg.OPTIM)\n        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)\n        self.register_model(\"E\", self.E, self.optim_E, self.sched_E)\n\n    def forward_backward(self, batch):\n        parsed_data = self.parse_batch_train(batch)\n        input, input2, label, domain = parsed_data\n\n        input = torch.split(input, self.split_batch, 0)\n        input2 = torch.split(input2, self.split_batch, 0)\n        label = torch.split(label, self.split_batch, 0)\n        domain = torch.split(domain, self.split_batch, 0)\n        domain = [d[0].item() for d in domain]\n\n        loss_x = 0\n        loss_cr = 0\n        acc = 0\n\n        feat = [self.F(x) for x in input]\n        feat2 = [self.F(x) for x in input2]\n\n        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):\n            cr_s = [j for j in domain if j != i]\n\n            # Learning expert\n            pred_i = self.E(i, feat_i)\n            loss_x += (-label_i * 
torch.log(pred_i + 1e-5)).sum(1).mean()\n            expert_label_i = pred_i.detach()\n            acc += compute_accuracy(pred_i.detach(),\n                                    label_i.max(1)[1])[0].item()\n\n            # Consistency regularization\n            cr_pred = []\n            for j in cr_s:\n                pred_j = self.E(j, feat2_i)\n                pred_j = pred_j.unsqueeze(1)\n                cr_pred.append(pred_j)\n            cr_pred = torch.cat(cr_pred, 1)\n            cr_pred = cr_pred.mean(1)\n            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()\n\n        loss_x /= self.n_domain\n        loss_cr /= self.n_domain\n        acc /= self.n_domain\n\n        loss = 0\n        loss += loss_x\n        loss += loss_cr\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc\": acc,\n            \"loss_cr\": loss_cr.item()\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch):\n        input = batch[\"img\"]\n        input2 = batch[\"img2\"]\n        label = batch[\"label\"]\n        domain = batch[\"domain\"]\n\n        label = create_onehot(label, self.num_classes)\n\n        input = input.to(self.device)\n        input2 = input2.to(self.device)\n        label = label.to(self.device)\n\n        return input, input2, label, domain\n\n    def model_inference(self, input):\n        f = self.F(input)\n        p = []\n        for k in range(self.num_source_domains):\n            p_k = self.E(k, f)\n            p_k = p_k.unsqueeze(1)\n            p.append(p_k)\n        p = torch.cat(p, 1)\n        p = p.mean(1)\n        return p\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import count_num_param\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.modeling import build_network\nfrom dassl.engine.trainer import SimpleNet\n\n\n@TRAINER_REGISTRY.register()\nclass DDAIG(TrainerX):\n    \"\"\"Deep Domain-Adversarial Image Generation.\n\n    https://arxiv.org/abs/2003.06054.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.lmda = cfg.TRAINER.DDAIG.LMDA\n        self.clamp = cfg.TRAINER.DDAIG.CLAMP\n        self.clamp_min = cfg.TRAINER.DDAIG.CLAMP_MIN\n        self.clamp_max = cfg.TRAINER.DDAIG.CLAMP_MAX\n        self.warmup = cfg.TRAINER.DDAIG.WARMUP\n        self.alpha = cfg.TRAINER.DDAIG.ALPHA\n\n    def build_model(self):\n        cfg = self.cfg\n\n        print(\"Building F\")\n        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)\n        self.F.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.F)))\n        self.optim_F = build_optimizer(self.F, cfg.OPTIM)\n        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)\n        self.register_model(\"F\", self.F, self.optim_F, self.sched_F)\n\n        print(\"Building D\")\n        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)\n        self.D.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.D)))\n        self.optim_D = build_optimizer(self.D, cfg.OPTIM)\n        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)\n        self.register_model(\"D\", self.D, self.optim_D, self.sched_D)\n\n        print(\"Building G\")\n        self.G = build_network(cfg.TRAINER.DDAIG.G_ARCH, verbose=cfg.VERBOSE)\n        self.G.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.G)))\n        self.optim_G = build_optimizer(self.G, cfg.OPTIM)\n        self.sched_G = 
build_lr_scheduler(self.optim_G, cfg.OPTIM)\n        self.register_model(\"G\", self.G, self.optim_G, self.sched_G)\n\n    def forward_backward(self, batch):\n        input, label, domain = self.parse_batch_train(batch)\n\n        #############\n        # Update G\n        #############\n        input_p = self.G(input, lmda=self.lmda)\n        if self.clamp:\n            input_p = torch.clamp(\n                input_p, min=self.clamp_min, max=self.clamp_max\n            )\n        loss_g = 0\n        # Minimize label loss\n        loss_g += F.cross_entropy(self.F(input_p), label)\n        # Maximize domain loss\n        loss_g -= F.cross_entropy(self.D(input_p), domain)\n        self.model_backward_and_update(loss_g, \"G\")\n\n        # Perturb data with new G\n        with torch.no_grad():\n            input_p = self.G(input, lmda=self.lmda)\n            if self.clamp:\n                input_p = torch.clamp(\n                    input_p, min=self.clamp_min, max=self.clamp_max\n                )\n\n        #############\n        # Update F\n        #############\n        loss_f = F.cross_entropy(self.F(input), label)\n        if (self.epoch + 1) > self.warmup:\n            loss_fp = F.cross_entropy(self.F(input_p), label)\n            loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp\n        self.model_backward_and_update(loss_f, \"F\")\n\n        #############\n        # Update D\n        #############\n        loss_d = F.cross_entropy(self.D(input), domain)\n        self.model_backward_and_update(loss_d, \"D\")\n\n        loss_summary = {\n            \"loss_g\": loss_g.item(),\n            \"loss_f\": loss_f.item(),\n            \"loss_d\": loss_d.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def model_inference(self, input):\n        return self.F(input)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py",
    "content": "from torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.metrics import compute_accuracy\n\n\n@TRAINER_REGISTRY.register()\nclass Vanilla(TrainerX):\n    \"\"\"Vanilla baseline.\"\"\"\n\n    def forward_backward(self, batch):\n        input, label = self.parse_batch_train(batch)\n        output = self.model(input)\n        loss = F.cross_entropy(output, label)\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss\": loss.item(),\n            \"acc\": compute_accuracy(output, label)[0].item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch):\n        input = batch[\"img\"]\n        label = batch[\"label\"]\n        input = input.to(self.device)\n        label = label.to(self.device)\n        return input, label\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py",
    "content": "from .entmin import EntMin\nfrom .fixmatch import FixMatch\nfrom .mixmatch import MixMatch\nfrom .mean_teacher import MeanTeacher\nfrom .sup_baseline import SupBaseline\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\n\n\n@TRAINER_REGISTRY.register()\nclass EntMin(TrainerXU):\n    \"\"\"Entropy Minimization.\n\n    http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.lmda = cfg.TRAINER.ENTMIN.LMDA\n\n    def forward_backward(self, batch_x, batch_u):\n        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)\n\n        output_x = self.model(input_x)\n        loss_x = F.cross_entropy(output_x, label_x)\n\n        output_u = F.softmax(self.model(input_u), 1)\n        loss_u = (-output_u * torch.log(output_u + 1e-5)).sum(1).mean()\n\n        loss = loss_x + loss_u * self.lmda\n\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc_x\": compute_accuracy(output_x, label_x)[0].item(),\n            \"loss_u\": loss_u.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.data import DataManager\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\nfrom dassl.data.transforms import build_transform\n\n\n@TRAINER_REGISTRY.register()\nclass FixMatch(TrainerXU):\n    \"\"\"FixMatch: Simplifying Semi-Supervised Learning with\n    Consistency and Confidence.\n\n    https://arxiv.org/abs/2001.07685.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U\n        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE\n\n    def check_cfg(self, cfg):\n        assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0\n\n    def build_data_loader(self):\n        cfg = self.cfg\n        tfm_train = build_transform(cfg, is_train=True)\n        custom_tfm_train = [tfm_train]\n        choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS\n        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)\n        custom_tfm_train += [tfm_train_strong]\n        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)\n        self.train_loader_x = self.dm.train_loader_x\n        self.train_loader_u = self.dm.train_loader_u\n        self.val_loader = self.dm.val_loader\n        self.test_loader = self.dm.test_loader\n        self.num_classes = self.dm.num_classes\n\n    def assess_y_pred_quality(self, y_pred, y_true, mask):\n        n_masked_correct = (y_pred.eq(y_true).float() * mask).sum()\n        acc_thre = n_masked_correct / (mask.sum() + 1e-5)\n        acc_raw = y_pred.eq(y_true).sum() / y_pred.numel()  # raw accuracy\n        keep_rate = mask.sum() / mask.numel()\n        output = {\n            \"acc_thre\": acc_thre,\n            \"acc_raw\": acc_raw,\n            \"keep_rate\": keep_rate\n        }\n        return output\n\n    def forward_backward(self, batch_x, batch_u):\n        parsed_data = self.parse_batch_train(batch_x, 
batch_u)\n        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data\n        input_u = torch.cat([input_x, input_u], 0)\n        input_u2 = torch.cat([input_x2, input_u2], 0)\n        n_x = input_x.size(0)\n\n        # Generate pseudo labels\n        with torch.no_grad():\n            output_u = F.softmax(self.model(input_u), 1)\n            max_prob, label_u_pred = output_u.max(1)\n            mask_u = (max_prob >= self.conf_thre).float()\n\n            # Evaluate pseudo labels' accuracy\n            y_u_pred_stats = self.assess_y_pred_quality(\n                label_u_pred[n_x:], label_u, mask_u[n_x:]\n            )\n\n        # Supervised loss\n        output_x = self.model(input_x)\n        loss_x = F.cross_entropy(output_x, label_x)\n\n        # Unsupervised loss\n        output_u = self.model(input_u2)\n        loss_u = F.cross_entropy(output_u, label_u_pred, reduction=\"none\")\n        loss_u = (loss_u * mask_u).mean()\n\n        loss = loss_x + loss_u * self.weight_u\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc_x\": compute_accuracy(output_x, label_x)[0].item(),\n            \"loss_u\": loss_u.item(),\n            \"y_u_pred_acc_raw\": y_u_pred_stats[\"acc_raw\"],\n            \"y_u_pred_acc_thre\": y_u_pred_stats[\"acc_thre\"],\n            \"y_u_pred_keep\": y_u_pred_stats[\"keep_rate\"],\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input_x = batch_x[\"img\"]\n        input_x2 = batch_x[\"img2\"]\n        label_x = batch_x[\"label\"]\n        input_u = batch_u[\"img\"]\n        input_u2 = batch_u[\"img2\"]\n        # label_u is used only for evaluating pseudo labels' accuracy\n        label_u = batch_u[\"label\"]\n\n        input_x = input_x.to(self.device)\n        input_x2 = input_x2.to(self.device)\n    
    label_x = label_x.to(self.device)\n        input_u = input_u.to(self.device)\n        input_u2 = input_u2.to(self.device)\n        label_u = label_u.to(self.device)\n\n        return input_x, input_x2, label_x, input_u, input_u2, label_u\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py",
    "content": "import copy\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\nfrom dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update\n\n\n@TRAINER_REGISTRY.register()\nclass MeanTeacher(TrainerXU):\n    \"\"\"Mean teacher.\n\n    https://arxiv.org/abs/1703.01780.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.weight_u = cfg.TRAINER.MEANTEA.WEIGHT_U\n        self.ema_alpha = cfg.TRAINER.MEANTEA.EMA_ALPHA\n        self.rampup = cfg.TRAINER.MEANTEA.RAMPUP\n\n        self.teacher = copy.deepcopy(self.model)\n        self.teacher.train()\n        for param in self.teacher.parameters():\n            param.requires_grad_(False)\n\n    def forward_backward(self, batch_x, batch_u):\n        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)\n\n        logit_x = self.model(input_x)\n        loss_x = F.cross_entropy(logit_x, label_x)\n\n        target_u = F.softmax(self.teacher(input_u), 1)\n        prob_u = F.softmax(self.model(input_u), 1)\n        loss_u = ((prob_u - target_u)**2).sum(1).mean()\n\n        weight_u = self.weight_u * sigmoid_rampup(self.epoch, self.rampup)\n        loss = loss_x + loss_u*weight_u\n        self.model_backward_and_update(loss)\n\n        global_step = self.batch_idx + self.epoch * self.num_batches\n        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)\n        ema_model_update(self.model, self.teacher, ema_alpha)\n\n        loss_summary = {\n            \"loss_x\": loss_x.item(),\n            \"acc_x\": compute_accuracy(logit_x, label_x)[0].item(),\n            \"loss_u\": loss_u.item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.modeling.ops import mixup\nfrom dassl.modeling.ops.utils import (\n    sharpen_prob, create_onehot, linear_rampup, shuffle_index\n)\n\n\n@TRAINER_REGISTRY.register()\nclass MixMatch(TrainerXU):\n    \"\"\"MixMatch: A Holistic Approach to Semi-Supervised Learning.\n\n    https://arxiv.org/abs/1905.02249.\n    \"\"\"\n\n    def __init__(self, cfg):\n        super().__init__(cfg)\n        self.weight_u = cfg.TRAINER.MIXMATCH.WEIGHT_U\n        self.temp = cfg.TRAINER.MIXMATCH.TEMP\n        self.beta = cfg.TRAINER.MIXMATCH.MIXUP_BETA\n        self.rampup = cfg.TRAINER.MIXMATCH.RAMPUP\n\n    def check_cfg(self, cfg):\n        assert cfg.DATALOADER.K_TRANSFORMS > 1\n\n    def forward_backward(self, batch_x, batch_u):\n        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)\n        num_x = input_x.shape[0]\n\n        global_step = self.batch_idx + self.epoch * self.num_batches\n        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)\n\n        # Generate pseudo-label for unlabeled data\n        with torch.no_grad():\n            output_u = 0\n            for input_ui in input_u:\n                output_ui = F.softmax(self.model(input_ui), 1)\n                output_u += output_ui\n            output_u /= len(input_u)\n            label_u = sharpen_prob(output_u, self.temp)\n            label_u = [label_u] * len(input_u)\n            label_u = torch.cat(label_u, 0)\n            input_u = torch.cat(input_u, 0)\n\n        # Combine and shuffle labeled and unlabeled data\n        input_xu = torch.cat([input_x, input_u], 0)\n        label_xu = torch.cat([label_x, label_u], 0)\n        input_xu, label_xu = shuffle_index(input_xu, label_xu)\n\n        # Mixup\n        input_x, label_x = mixup(\n            input_x,\n            input_xu[:num_x],\n            label_x,\n            label_xu[:num_x],\n         
   self.beta,\n            preserve_order=True,\n        )\n\n        input_u, label_u = mixup(\n            input_u,\n            input_xu[num_x:],\n            label_u,\n            label_xu[num_x:],\n            self.beta,\n            preserve_order=True,\n        )\n\n        # Compute losses\n        output_x = F.softmax(self.model(input_x), 1)\n        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()\n\n        output_u = F.softmax(self.model(input_u), 1)\n        loss_u = ((label_u - output_u)**2).mean()\n\n        loss = loss_x + loss_u*weight_u\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\"loss_x\": loss_x.item(), \"loss_u\": loss_u.item()}\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input_x = batch_x[\"img\"][0]\n        label_x = batch_x[\"label\"]\n        label_x = create_onehot(label_x, self.num_classes)\n        input_u = batch_u[\"img\"]\n\n        input_x = input_x.to(self.device)\n        label_x = label_x.to(self.device)\n        input_u = [input_ui.to(self.device) for input_ui in input_u]\n\n        return input_x, label_x, input_u\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py",
    "content": "from torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import compute_accuracy\n\n\n@TRAINER_REGISTRY.register()\nclass SupBaseline(TrainerXU):\n    \"\"\"Supervised Baseline.\"\"\"\n\n    def forward_backward(self, batch_x, batch_u):\n        input, label = self.parse_batch_train(batch_x, batch_u)\n        output = self.model(input)\n        loss = F.cross_entropy(output, label)\n        self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss\": loss.item(),\n            \"acc\": compute_accuracy(output, label)[0].item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input = batch_x[\"img\"]\n        label = batch_x[\"label\"]\n        input = input.to(self.device)\n        label = label.to(self.device)\n        return input, label\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/trainer.py",
    "content": "import json\nimport time\nimport numpy as np\nimport os.path as osp\nimport datetime\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\nfrom tqdm import tqdm\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom dassl.data import DataManager\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.utils import (\n    MetricMeter, AverageMeter, tolist_if_not, count_num_param, load_checkpoint,\n    save_checkpoint, mkdir_if_missing, resume_from_checkpoint,\n    load_pretrained_weights\n)\nfrom dassl.modeling import build_head, build_backbone\nfrom dassl.evaluation import build_evaluator\n\n\nclass SimpleNet(nn.Module):\n    \"\"\"A simple neural network composed of a CNN backbone\n    and optionally a head such as mlp for classification.\n    \"\"\"\n\n    def __init__(self, cfg, model_cfg, num_classes, **kwargs):\n        super().__init__()\n        self.backbone = build_backbone(\n            model_cfg.BACKBONE.NAME,\n            verbose=cfg.VERBOSE,\n            pretrained=model_cfg.BACKBONE.PRETRAINED,\n            **kwargs,\n        )\n        fdim = self.backbone.out_features\n\n        self.head = None\n        if model_cfg.HEAD.NAME and model_cfg.HEAD.HIDDEN_LAYERS:\n            self.head = build_head(\n                model_cfg.HEAD.NAME,\n                verbose=cfg.VERBOSE,\n                in_features=fdim,\n                hidden_layers=model_cfg.HEAD.HIDDEN_LAYERS,\n                activation=model_cfg.HEAD.ACTIVATION,\n                bn=model_cfg.HEAD.BN,\n                dropout=model_cfg.HEAD.DROPOUT,\n                **kwargs,\n            )\n            fdim = self.head.out_features\n\n        self.classifier = None\n        if num_classes > 0:\n            self.classifier = nn.Linear(fdim, num_classes)\n\n        self._fdim = fdim\n\n    @property\n    def fdim(self):\n        return self._fdim\n\n    def forward(self, x, return_feature=False):\n        f = self.backbone(x)\n        if 
self.head is not None:\n            f = self.head(f)\n\n        if self.classifier is None:\n            return f\n\n        y = self.classifier(f)\n\n        if return_feature:\n            return y, f\n\n        return y\n\n\nclass TrainerBase:\n    \"\"\"Base class for iterative trainer.\"\"\"\n\n    def __init__(self):\n        self._models = OrderedDict()\n        self._optims = OrderedDict()\n        self._scheds = OrderedDict()\n        self._writer = None\n\n    def register_model(self, name=\"model\", model=None, optim=None, sched=None):\n        if self.__dict__.get(\"_models\") is None:\n            raise AttributeError(\n                \"Cannot assign model before super().__init__() call\"\n            )\n\n        if self.__dict__.get(\"_optims\") is None:\n            raise AttributeError(\n                \"Cannot assign optim before super().__init__() call\"\n            )\n\n        if self.__dict__.get(\"_scheds\") is None:\n            raise AttributeError(\n                \"Cannot assign sched before super().__init__() call\"\n            )\n\n        assert name not in self._models, \"Found duplicate model names\"\n\n        self._models[name] = model\n        self._optims[name] = optim\n        self._scheds[name] = sched\n\n    def get_model_names(self, names=None):\n        names_real = list(self._models.keys())\n        if names is not None:\n            names = tolist_if_not(names)\n            for name in names:\n                assert name in names_real\n            return names\n        else:\n            return names_real\n\n    def save_model(self, epoch, directory, is_best=False, model_name=\"\"):\n        names = self.get_model_names()\n\n        for name in names:\n            model_dict = self._models[name].state_dict()\n\n            optim_dict = None\n            if self._optims[name] is not None:\n                optim_dict = self._optims[name].state_dict()\n\n            sched_dict = None\n            if self._scheds[name] is 
not None:\n                sched_dict = self._scheds[name].state_dict()\n\n            save_checkpoint(\n                {\n                    \"state_dict\": model_dict,\n                    \"epoch\": epoch + 1,\n                    \"optimizer\": optim_dict,\n                    \"scheduler\": sched_dict,\n                },\n                osp.join(directory, name),\n                is_best=is_best,\n                model_name=model_name,\n            )\n\n    def resume_model_if_exist(self, directory):\n        names = self.get_model_names()\n        file_missing = False\n\n        for name in names:\n            path = osp.join(directory, name)\n            if not osp.exists(path):\n                file_missing = True\n                break\n\n        if file_missing:\n            print(\"No checkpoint found, train from scratch\")\n            return 0\n\n        print(\n            'Found checkpoint in \"{}\". Will resume training'.format(directory)\n        )\n\n        for name in names:\n            path = osp.join(directory, name)\n            start_epoch = resume_from_checkpoint(\n                path, self._models[name], self._optims[name],\n                self._scheds[name]\n            )\n\n        return start_epoch\n\n    def load_model(self, directory, epoch=None):\n        if not directory:\n            print(\n                \"Note that load_model() is skipped as no pretrained \"\n                \"model is given (ignore this if it's done on purpose)\"\n            )\n            return\n\n        names = self.get_model_names()\n\n        # By default, the best model is loaded\n        model_file = \"model-best.pth.tar\"\n\n        if epoch is not None:\n            model_file = \"model.pth.tar-\" + str(epoch)\n\n        for name in names:\n            model_path = osp.join(directory, name, model_file)\n\n            if not osp.exists(model_path):\n                raise FileNotFoundError(\n                    'Model not found at 
\"{}\"'.format(model_path)\n                )\n\n            checkpoint = load_checkpoint(model_path)\n            state_dict = checkpoint[\"state_dict\"]\n            epoch = checkpoint[\"epoch\"]\n\n            print(\n                \"Loading weights to {} \"\n                'from \"{}\" (epoch = {})'.format(name, model_path, epoch)\n            )\n            self._models[name].load_state_dict(state_dict)\n\n    def set_model_mode(self, mode=\"train\", names=None):\n        names = self.get_model_names(names)\n\n        for name in names:\n            if mode == \"train\":\n                self._models[name].train()\n            elif mode in [\"test\", \"eval\"]:\n                self._models[name].eval()\n            else:\n                raise KeyError\n\n    def update_lr(self, names=None):\n        names = self.get_model_names(names)\n\n        for name in names:\n            if self._scheds[name] is not None:\n                self._scheds[name].step()\n\n    def detect_anomaly(self, loss):\n        if not torch.isfinite(loss).all():\n            raise FloatingPointError(\"Loss is infinite or NaN!\")\n\n    def init_writer(self, log_dir):\n        if self.__dict__.get(\"_writer\") is None or self._writer is None:\n            print(\n                \"Initializing summary writer for tensorboard \"\n                \"with log_dir={}\".format(log_dir)\n            )\n            self._writer = SummaryWriter(log_dir=log_dir)\n\n    def close_writer(self):\n        if self._writer is not None:\n            self._writer.close()\n\n    def write_scalar(self, tag, scalar_value, global_step=None):\n        if self._writer is None:\n            # Do nothing if writer is not initialized\n            # Note that writer is only used when training is needed\n            pass\n        else:\n            self._writer.add_scalar(tag, scalar_value, global_step)\n\n    def train(self, start_epoch, max_epoch):\n        \"\"\"Generic training loops.\"\"\"\n        
self.start_epoch = start_epoch\n        self.max_epoch = max_epoch\n\n        self.before_train()\n        for self.epoch in range(self.start_epoch, self.max_epoch):\n            self.before_epoch()\n            self.run_epoch()\n            self.after_epoch()\n        self.after_train()\n\n    def before_train(self):\n        pass\n\n    def after_train(self):\n        pass\n\n    def before_epoch(self):\n        pass\n\n    def after_epoch(self):\n        pass\n\n    def run_epoch(self):\n        raise NotImplementedError\n\n    def test(self):\n        raise NotImplementedError\n\n    def parse_batch_train(self, batch):\n        raise NotImplementedError\n\n    def parse_batch_test(self, batch):\n        raise NotImplementedError\n\n    def forward_backward(self, batch):\n        raise NotImplementedError\n\n    def model_inference(self, input):\n        raise NotImplementedError\n\n    def model_zero_grad(self, names=None):\n        names = self.get_model_names(names)\n        for name in names:\n            if self._optims[name] is not None:\n                self._optims[name].zero_grad()\n\n    def model_backward(self, loss):\n        self.detect_anomaly(loss)\n        loss.backward()\n\n    def model_update(self, names=None):\n        names = self.get_model_names(names)\n        for name in names:\n            if self._optims[name] is not None:\n                self._optims[name].step()\n\n    def model_backward_and_update(self, loss, names=None):\n        self.model_zero_grad(names)\n        self.model_backward(loss)\n        self.model_update(names)\n\n    def prograd_backward_and_update(\n        self, loss_a, loss_b, lambda_=1, names=None\n    ):\n        # Update with loss_a's gradient. For each parameter whose loss_a\n        # gradient conflicts with loss_b's (negative cosine similarity),\n        # subtract lambda_ times its projection onto loss_b's gradient\n        # direction; with lambda_=1 the conflict is removed entirely\n        self.model_zero_grad(names)\n        # get the names of the registered models to update\n        names = self.get_model_names(names)\n        # backward loss_b first; retain the graph for loss_a's backward\n        self.detect_anomaly(loss_b)\n        loss_b.backward(retain_graph=True)\n       
 # cache loss_b's gradient for every parameter (None if unused)\n        b_grads = []\n        for name in names:\n            for p in self._models[name].parameters():\n                b_grads.append(None if p.grad is None else p.grad.clone())\n\n        # clear the gradients so that loss_a's gradient is computed alone\n        self.model_zero_grad(names)\n\n        # backward loss_a\n        self.detect_anomaly(loss_a)\n        loss_a.backward()\n        # b_grads is one flat list over all models, so walk it with a\n        # single iterator instead of restarting it for every model\n        b_grads_iter = iter(b_grads)\n        for name in names:\n            for p in self._models[name].parameters():\n                b_grad = next(b_grads_iter)\n                if p.grad is None or b_grad is None:\n                    continue\n                a_grad = p.grad.clone()\n                # cosine similarity between the two gradients\n                b_grad_norm = b_grad / torch.linalg.norm(b_grad)\n                a_grad_norm = a_grad / torch.linalg.norm(a_grad)\n\n                if torch.dot(a_grad_norm.flatten(), b_grad_norm.flatten()) < 0:\n                    # conflict: remove lambda_ times the component of\n                    # a_grad along the direction of b_grad\n                    p.grad = a_grad - lambda_ * torch.dot(\n                        a_grad.flatten(), b_grad_norm.flatten()\n                    ) * b_grad_norm\n\n        # update parameters with the (possibly projected) gradients\n        self.model_update(names)\n\n\nclass SimpleTrainer(TrainerBase):\n    \"\"\"A simple trainer class implementing generic functions.\"\"\"\n\n    def __init__(self, cfg):\n        super().__init__()\n        self.check_cfg(cfg)\n\n        if torch.cuda.is_available() and cfg.USE_CUDA:\n            self.device = torch.device(\"cuda\")\n        else:\n            self.device = torch.device(\"cpu\")\n\n        # Save as attributes some frequently used variables\n        self.start_epoch = self.epoch = 0\n        self.max_epoch = cfg.OPTIM.MAX_EPOCH\n        self.output_dir = cfg.OUTPUT_DIR\n\n        self.cfg = cfg\n        self.build_data_loader()\n        self.build_model()\n        self.evaluator = build_evaluator(cfg, lab2cname=self.lab2cname)\n        self.best_result = -np.inf\n\n    def check_cfg(self, cfg):\n        \"\"\"Check whether some variables are set correctly for\n        the trainer (optional).\n\n        For example, a trainer might require a particular sampler\n   
     for training such as 'RandomDomainSampler', so it is good\n        to check it here:\n\n        assert cfg.DATALOADER.TRAIN_X.SAMPLER == 'RandomDomainSampler'\n        \"\"\"\n        pass\n\n    def build_data_loader(self):\n        \"\"\"Create essential data-related attributes.\n\n        A re-implementation of this method must create the\n        same attributes (except self.dm).\n        \"\"\"\n        dm = DataManager(self.cfg)\n\n        self.train_loader_x = dm.train_loader_x\n        self.train_loader_u = dm.train_loader_u  # optional, can be None\n        self.val_loader = dm.val_loader  # optional, can be None\n        self.test_loader = dm.test_loader\n        self.num_classes = dm.num_classes\n        self.num_source_domains = dm.num_source_domains\n        self.lab2cname = dm.lab2cname  # dict {label: classname}\n\n        self.dm = dm\n\n    def build_model(self):\n        \"\"\"Build and register model.\n\n        The default builds a classification model along with its\n        optimizer and scheduler.\n\n        Custom trainers can re-implement this method if necessary.\n        \"\"\"\n        cfg = self.cfg\n\n        print(\"Building model\")\n        self.model = SimpleNet(cfg, cfg.MODEL, self.num_classes)\n        if cfg.MODEL.INIT_WEIGHTS:\n            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)\n        self.model.to(self.device)\n        print(\"# params: {:,}\".format(count_num_param(self.model)))\n        self.optim = build_optimizer(self.model, cfg.OPTIM)\n        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)\n        self.register_model(\"model\", self.model, self.optim, self.sched)\n\n        device_count = torch.cuda.device_count()\n        if device_count > 1:\n            print(\n                f\"Detected {device_count} GPUs. 
Wrap the model with nn.DataParallel\"\n            )\n            self.model = nn.DataParallel(self.model)\n\n    def train(self):\n        super().train(self.start_epoch, self.max_epoch)\n\n    def before_train(self):\n        directory = self.cfg.OUTPUT_DIR\n        if self.cfg.RESUME:\n            directory = self.cfg.RESUME\n        self.start_epoch = self.resume_model_if_exist(directory)\n\n        # Initialize summary writer\n        writer_dir = osp.join(self.output_dir, \"tensorboard\")\n        mkdir_if_missing(writer_dir)\n        self.init_writer(writer_dir)\n\n        # Remember the starting time (for computing the elapsed time)\n        self.time_start = time.time()\n\n    def after_train(self):\n        print(\"Finished training\")\n\n        do_test = not self.cfg.TEST.NO_TEST\n        if do_test:\n            if self.cfg.TEST.FINAL_MODEL == \"best_val\":\n                print(\"Deploy the model with the best val performance\")\n                self.load_model(self.output_dir)\n            self.test()\n\n        # Show elapsed time\n        elapsed = round(time.time() - self.time_start)\n        elapsed = str(datetime.timedelta(seconds=elapsed))\n        print(\"Elapsed: {}\".format(elapsed))\n\n        # Close writer\n        self.close_writer()\n\n    def after_epoch(self):\n        last_epoch = (self.epoch + 1) == self.max_epoch\n        do_test = not self.cfg.TEST.NO_TEST\n        meet_checkpoint_freq = (\n            (self.epoch + 1) % self.cfg.TRAIN.CHECKPOINT_FREQ == 0\n            if self.cfg.TRAIN.CHECKPOINT_FREQ > 0 else False\n        )\n\n        if do_test and self.cfg.TEST.FINAL_MODEL == \"best_val\":\n            curr_result = self.test(split=\"val\")\n            is_best = curr_result > self.best_result\n            if is_best:\n                self.best_result = curr_result\n                self.save_model(\n                    self.epoch,\n                    self.output_dir,\n                    model_name=\"model-best.pth.tar\"\n 
               )\n\n        if meet_checkpoint_freq or last_epoch:\n            self.save_model(self.epoch, self.output_dir)\n\n    @torch.no_grad()\n    def output_test(self, split=None):\n        \"\"\"testing pipline, which could also output the results.\"\"\"\n        self.set_model_mode(\"eval\")\n        self.evaluator.reset()\n\n        output_file = osp.join(self.cfg.OUTPUT_DIR, 'output.json')\n        res_json = {}\n\n        if split is None:\n            split = self.cfg.TEST.SPLIT\n\n        if split == \"val\" and self.val_loader is not None:\n            data_loader = self.val_loader\n            print(\"Do evaluation on {} set\".format(split))\n        else:\n            data_loader = self.test_loader\n            print(\"Do evaluation on test set\")\n\n        for batch_idx, batch in enumerate(tqdm(data_loader)):\n            img_path = batch['impath']\n            input, label = self.parse_batch_test(batch)\n            output = self.model_inference(input)\n            self.evaluator.process(output, label)\n            for i in range(len(img_path)):\n                res_json[img_path[i]] = {\n                    'predict': output[i].cpu().numpy().tolist(),\n                    'gt': label[i].cpu().numpy().tolist()\n                }\n        with open(output_file, 'w') as f:\n            json.dump(res_json, f)\n        results = self.evaluator.evaluate()\n\n        for k, v in results.items():\n            tag = \"{}/{}\".format(split, k)\n            self.write_scalar(tag, v, self.epoch)\n\n        return list(results.values())[0]\n\n    @torch.no_grad()\n    def test(self, split=None):\n        \"\"\"A generic testing pipeline.\"\"\"\n        self.set_model_mode(\"eval\")\n        self.evaluator.reset()\n\n        if split is None:\n            split = self.cfg.TEST.SPLIT\n\n        if split == \"val\" and self.val_loader is not None:\n            data_loader = self.val_loader\n            print(\"Do evaluation on {} set\".format(split))\n        
else:\n            data_loader = self.test_loader\n            print(\"Do evaluation on test set\")\n\n        for batch_idx, batch in enumerate(tqdm(data_loader)):\n            input, label = self.parse_batch_test(batch)\n            output = self.model_inference(input)\n            self.evaluator.process(output, label)\n\n        results = self.evaluator.evaluate()\n\n        for k, v in results.items():\n            tag = \"{}/{}\".format(split, k)\n            self.write_scalar(tag, v, self.epoch)\n\n        return list(results.values())[0]\n\n    def model_inference(self, input):\n        return self.model(input)\n\n    def parse_batch_test(self, batch):\n        input = batch[\"img\"]\n        label = batch[\"label\"]\n\n        input = input.to(self.device)\n        label = label.to(self.device)\n\n        return input, label\n\n    def get_current_lr(self, names=None):\n        names = self.get_model_names(names)\n        name = names[0]\n        return self._optims[name].param_groups[0][\"lr\"]\n\n\nclass TrainerXU(SimpleTrainer):\n    \"\"\"A base trainer using both labeled and unlabeled data.\n\n    In the context of domain adaptation, labeled and unlabeled data\n    come from source and target domains respectively.\n\n    When it comes to semi-supervised learning, all data comes from the\n    same domain.\n    \"\"\"\n\n    def run_epoch(self):\n        self.set_model_mode(\"train\")\n        losses = MetricMeter()\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n\n        # Decide to iterate over labeled or unlabeled dataset\n        len_train_loader_x = len(self.train_loader_x)\n        len_train_loader_u = len(self.train_loader_u)\n        if self.cfg.TRAIN.COUNT_ITER == \"train_x\":\n            self.num_batches = len_train_loader_x\n        elif self.cfg.TRAIN.COUNT_ITER == \"train_u\":\n            self.num_batches = len_train_loader_u\n        elif self.cfg.TRAIN.COUNT_ITER == \"smaller_one\":\n            
self.num_batches = min(len_train_loader_x, len_train_loader_u)\n        else:\n            raise ValueError\n\n        train_loader_x_iter = iter(self.train_loader_x)\n        train_loader_u_iter = iter(self.train_loader_u)\n\n        end = time.time()\n        for self.batch_idx in range(self.num_batches):\n            try:\n                batch_x = next(train_loader_x_iter)\n            except StopIteration:\n                train_loader_x_iter = iter(self.train_loader_x)\n                batch_x = next(train_loader_x_iter)\n\n            try:\n                batch_u = next(train_loader_u_iter)\n            except StopIteration:\n                train_loader_u_iter = iter(self.train_loader_u)\n                batch_u = next(train_loader_u_iter)\n\n            data_time.update(time.time() - end)\n            loss_summary = self.forward_backward(batch_x, batch_u)\n            batch_time.update(time.time() - end)\n            losses.update(loss_summary)\n\n            if (\n                self.batch_idx + 1\n            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:\n                nb_remain = 0\n                nb_remain += self.num_batches - self.batch_idx - 1\n                nb_remain += (\n                    self.max_epoch - self.epoch - 1\n                ) * self.num_batches\n                eta_seconds = batch_time.avg * nb_remain\n                eta = str(datetime.timedelta(seconds=int(eta_seconds)))\n                print(\n                    \"epoch [{0}/{1}][{2}/{3}]\\t\"\n                    \"time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t\"\n                    \"data {data_time.val:.3f} ({data_time.avg:.3f})\\t\"\n                    \"eta {eta}\\t\"\n                    \"{losses}\\t\"\n                    \"lr {lr:.6e}\".format(\n                        self.epoch + 1,\n                        self.max_epoch,\n                        self.batch_idx + 1,\n                        self.num_batches,\n     
                   batch_time=batch_time,\n                        data_time=data_time,\n                        eta=eta,\n                        losses=losses,\n                        lr=self.get_current_lr(),\n                    )\n                )\n\n            n_iter = self.epoch * self.num_batches + self.batch_idx\n            for name, meter in losses.meters.items():\n                self.write_scalar(\"train/\" + name, meter.avg, n_iter)\n            self.write_scalar(\"train/lr\", self.get_current_lr(), n_iter)\n\n            end = time.time()\n\n    def parse_batch_train(self, batch_x, batch_u):\n        input_x = batch_x[\"img\"]\n        label_x = batch_x[\"label\"]\n        input_u = batch_u[\"img\"]\n\n        input_x = input_x.to(self.device)\n        label_x = label_x.to(self.device)\n        input_u = input_u.to(self.device)\n\n        return input_x, label_x, input_u\n\n\nclass TrainerX(SimpleTrainer):\n    \"\"\"A base trainer using labeled data only.\"\"\"\n\n    def run_epoch(self):\n        self.set_model_mode(\"train\")\n        losses = MetricMeter()\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        self.num_batches = len(self.train_loader_x)\n\n        end = time.time()\n        for self.batch_idx, batch in enumerate(self.train_loader_x):\n            data_time.update(time.time() - end)\n            loss_summary = self.forward_backward(batch)\n            batch_time.update(time.time() - end)\n            losses.update(loss_summary)\n\n            if (\n                self.batch_idx + 1\n            ) % self.cfg.TRAIN.PRINT_FREQ == 0 or self.num_batches < self.cfg.TRAIN.PRINT_FREQ:\n                nb_remain = 0\n                nb_remain += self.num_batches - self.batch_idx - 1\n                nb_remain += (\n                    self.max_epoch - self.epoch - 1\n                ) * self.num_batches\n                eta_seconds = batch_time.avg * nb_remain\n                eta = 
str(datetime.timedelta(seconds=int(eta_seconds)))\n                print(\n                    \"epoch [{0}/{1}][{2}/{3}]\\t\"\n                    \"time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t\"\n                    \"data {data_time.val:.3f} ({data_time.avg:.3f})\\t\"\n                    \"eta {eta}\\t\"\n                    \"{losses}\\t\"\n                    \"lr {lr:.6e}\".format(\n                        self.epoch + 1,\n                        self.max_epoch,\n                        self.batch_idx + 1,\n                        self.num_batches,\n                        batch_time=batch_time,\n                        data_time=data_time,\n                        eta=eta,\n                        losses=losses,\n                        lr=self.get_current_lr(),\n                    )\n                )\n\n            n_iter = self.epoch * self.num_batches + self.batch_idx\n            for name, meter in losses.meters.items():\n                self.write_scalar(\"train/\" + name, meter.avg, n_iter)\n            self.write_scalar(\"train/lr\", self.get_current_lr(), n_iter)\n\n            end = time.time()\n\n    def parse_batch_train(self, batch):\n        input = batch[\"img\"]\n        label = batch[\"label\"]\n        domain = batch[\"domain\"]\n\n        input = input.to(self.device)\n        label = label.to(self.device)\n        domain = domain.to(self.device)\n\n        return input, label, domain\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/evaluation/__init__.py",
    "content": "from .build import build_evaluator, EVALUATOR_REGISTRY  # isort:skip\n\nfrom .evaluator import EvaluatorBase, Classification\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/evaluation/build.py",
    "content": "from dassl.utils import Registry, check_availability\n\nEVALUATOR_REGISTRY = Registry(\"EVALUATOR\")\n\n\ndef build_evaluator(cfg, **kwargs):\n    avai_evaluators = EVALUATOR_REGISTRY.registered_names()\n    check_availability(cfg.TEST.EVALUATOR, avai_evaluators)\n    if cfg.VERBOSE:\n        print(\"Loading evaluator: {}\".format(cfg.TEST.EVALUATOR))\n    return EVALUATOR_REGISTRY.get(cfg.TEST.EVALUATOR)(cfg, **kwargs)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py",
    "content": "import numpy as np\nimport os.path as osp\nfrom collections import OrderedDict, defaultdict\nimport torch\nfrom sklearn.metrics import f1_score, confusion_matrix\n\nfrom .build import EVALUATOR_REGISTRY\n\n\nclass EvaluatorBase:\n    \"\"\"Base evaluator.\"\"\"\n\n    def __init__(self, cfg):\n        self.cfg = cfg\n\n    def reset(self):\n        raise NotImplementedError\n\n    def process(self, mo, gt):\n        raise NotImplementedError\n\n    def evaluate(self):\n        raise NotImplementedError\n\n\n@EVALUATOR_REGISTRY.register()\nclass Classification(EvaluatorBase):\n    \"\"\"Evaluator for classification.\"\"\"\n\n    def __init__(self, cfg, lab2cname=None, **kwargs):\n        super().__init__(cfg)\n        self._lab2cname = lab2cname\n        self._correct = 0\n        self._total = 0\n        self._per_class_res = None\n        self._y_true = []\n        self._y_pred = []\n        if cfg.TEST.PER_CLASS_RESULT:\n            assert lab2cname is not None\n            self._per_class_res = defaultdict(list)\n\n    def reset(self):\n        self._correct = 0\n        self._total = 0\n        self._y_true = []\n        self._y_pred = []\n        if self._per_class_res is not None:\n            self._per_class_res = defaultdict(list)\n\n    def process(self, mo, gt):\n        # mo (torch.Tensor): model output [batch, num_classes]\n        # gt (torch.LongTensor): ground truth [batch]\n        pred = mo.max(1)[1]\n        matches = pred.eq(gt).float()\n        self._correct += int(matches.sum().item())\n        self._total += gt.shape[0]\n\n        self._y_true.extend(gt.data.cpu().numpy().tolist())\n        self._y_pred.extend(pred.data.cpu().numpy().tolist())\n\n        if self._per_class_res is not None:\n            for i, label in enumerate(gt):\n                label = label.item()\n                matches_i = int(matches[i].item())\n                self._per_class_res[label].append(matches_i)\n\n    def evaluate(self):\n        results = 
OrderedDict()\n        acc = 100.0 * self._correct / self._total\n        err = 100.0 - acc\n        macro_f1 = 100.0 * f1_score(\n            self._y_true,\n            self._y_pred,\n            average=\"macro\",\n            labels=np.unique(self._y_true)\n        )\n\n        # The first value will be returned by trainer.test()\n        results[\"accuracy\"] = acc\n        results[\"error_rate\"] = err\n        results[\"macro_f1\"] = macro_f1\n\n        print(\n            \"=> result\\n\"\n            f\"* total: {self._total:,}\\n\"\n            f\"* correct: {self._correct:,}\\n\"\n            f\"* accuracy: {acc:.2f}%\\n\"\n            f\"* error: {err:.2f}%\\n\"\n            f\"* macro_f1: {macro_f1:.2f}%\"\n        )\n\n        if self._per_class_res is not None:\n            labels = list(self._per_class_res.keys())\n            labels.sort()\n\n            print(\"=> per-class result\")\n            accs = []\n\n            for label in labels:\n                classname = self._lab2cname[label]\n                res = self._per_class_res[label]\n                correct = sum(res)\n                total = len(res)\n                acc = 100.0 * correct / total\n                accs.append(acc)\n                print(\n                    \"* class: {} ({})\\t\"\n                    \"total: {:,}\\t\"\n                    \"correct: {:,}\\t\"\n                    \"acc: {:.2f}%\".format(\n                        label, classname, total, correct, acc\n                    )\n                )\n            mean_acc = np.mean(accs)\n            print(\"* average: {:.2f}%\".format(mean_acc))\n\n            results[\"perclass_accuracy\"] = mean_acc\n\n        if self.cfg.TEST.COMPUTE_CMAT:\n            cmat = confusion_matrix(\n                self._y_true, self._y_pred, normalize=\"true\"\n            )\n            save_path = osp.join(self.cfg.OUTPUT_DIR, \"cmat.pt\")\n            torch.save(cmat, save_path)\n            print('Confusion matrix is saved to 
\"{}\"'.format(save_path))\n\n        return results\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/metrics/__init__.py",
    "content": "from .accuracy import compute_accuracy\nfrom .distance import (\n    cosine_distance, compute_distance_matrix, euclidean_squared_distance\n)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py",
    "content": "def compute_accuracy(output, target, topk=(1, )):\n    \"\"\"Computes the accuracy over the k top predictions for\n    the specified values of k.\n\n    Args:\n        output (torch.Tensor): prediction matrix with shape (batch_size, num_classes).\n        target (torch.LongTensor): ground truth labels with shape (batch_size).\n        topk (tuple, optional): accuracy at top-k will be computed. For example,\n            topk=(1, 5) means accuracy at top-1 and top-5 will be computed.\n\n    Returns:\n        list: accuracy (in %) at each top-k, as one-element tensors.\n    \"\"\"\n    maxk = max(topk)\n    batch_size = target.size(0)\n\n    if isinstance(output, (tuple, list)):\n        output = output[0]\n\n    _, pred = output.topk(maxk, 1, True, True)\n    pred = pred.t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n    res = []\n    for k in topk:\n        # reshape rather than view: correct may be non-contiguous after\n        # the transpose above, in which case view(-1) raises an error\n        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n        acc = correct_k.mul_(100.0 / batch_size)\n        res.append(acc)\n\n    return res\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/metrics/distance.py",
    "content": "\"\"\"\nSource: https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport torch\nfrom torch.nn import functional as F\n\n\ndef compute_distance_matrix(input1, input2, metric=\"euclidean\"):\n    \"\"\"A wrapper function for computing distance matrix.\n\n    Each input matrix has the shape (n_data, feature_dim).\n\n    Args:\n        input1 (torch.Tensor): 2-D feature matrix.\n        input2 (torch.Tensor): 2-D feature matrix.\n        metric (str, optional): \"euclidean\" or \"cosine\".\n            Default is \"euclidean\".\n\n    Returns:\n        torch.Tensor: distance matrix.\n    \"\"\"\n    # check input\n    assert isinstance(input1, torch.Tensor)\n    assert isinstance(input2, torch.Tensor)\n    assert input1.dim() == 2, \"Expected 2-D tensor, but got {}-D\".format(\n        input1.dim()\n    )\n    assert input2.dim() == 2, \"Expected 2-D tensor, but got {}-D\".format(\n        input2.dim()\n    )\n    assert input1.size(1) == input2.size(1)\n\n    if metric == \"euclidean\":\n        distmat = euclidean_squared_distance(input1, input2)\n    elif metric == \"cosine\":\n        distmat = cosine_distance(input1, input2)\n    else:\n        raise ValueError(\n            \"Unknown distance metric: {}. 
\"\n            'Please choose either \"euclidean\" or \"cosine\"'.format(metric)\n        )\n\n    return distmat\n\n\ndef euclidean_squared_distance(input1, input2):\n    \"\"\"Computes euclidean squared distance.\n\n    Args:\n        input1 (torch.Tensor): 2-D feature matrix.\n        input2 (torch.Tensor): 2-D feature matrix.\n\n    Returns:\n        torch.Tensor: distance matrix.\n    \"\"\"\n    m, n = input1.size(0), input2.size(0)\n    mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)\n    mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()\n    distmat = mat1 + mat2\n    # ||x_i||^2 + ||y_j||^2 - 2 * <x_i, y_j>, using the keyword form of\n    # addmm_ (the positional (beta, alpha, mat1, mat2) overload is deprecated)\n    distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)\n    return distmat\n\n\ndef cosine_distance(input1, input2):\n    \"\"\"Computes cosine distance.\n\n    Args:\n        input1 (torch.Tensor): 2-D feature matrix.\n        input2 (torch.Tensor): 2-D feature matrix.\n\n    Returns:\n        torch.Tensor: distance matrix.\n    \"\"\"\n    input1_normed = F.normalize(input1, p=2, dim=1)\n    input2_normed = F.normalize(input2, p=2, dim=1)\n    distmat = 1 - torch.mm(input1_normed, input2_normed.t())\n    return distmat\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/__init__.py",
    "content": "from .head import HEAD_REGISTRY, build_head\nfrom .network import NETWORK_REGISTRY, build_network\nfrom .backbone import BACKBONE_REGISTRY, Backbone, build_backbone\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/__init__.py",
    "content": "from .build import build_backbone, BACKBONE_REGISTRY  # isort:skip\nfrom .backbone import Backbone  # isort:skip\n\nfrom .vgg import vgg16\nfrom .resnet import (\n    resnet18, resnet34, resnet50, resnet101, resnet152, resnet18_ms_l1,\n    resnet50_ms_l1, resnet18_ms_l12, resnet50_ms_l12, resnet101_ms_l1,\n    resnet18_ms_l123, resnet50_ms_l123, resnet101_ms_l12, resnet101_ms_l123,\n    resnet18_efdmix_l1, resnet50_efdmix_l1, resnet18_efdmix_l12,\n    resnet50_efdmix_l12, resnet101_efdmix_l1, resnet18_efdmix_l123,\n    resnet50_efdmix_l123, resnet101_efdmix_l12, resnet101_efdmix_l123\n)\nfrom .alexnet import alexnet\nfrom .mobilenetv2 import mobilenetv2\nfrom .wide_resnet import wide_resnet_16_4, wide_resnet_28_2\nfrom .cnn_digitsdg import cnn_digitsdg\nfrom .efficientnet import (\n    efficientnet_b0, efficientnet_b1, efficientnet_b2, efficientnet_b3,\n    efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7\n)\nfrom .shufflenetv2 import (\n    shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5,\n    shufflenet_v2_x2_0\n)\nfrom .cnn_digitsingle import cnn_digitsingle\nfrom .preact_resnet18 import preact_resnet18\nfrom .cnn_digit5_m3sda import cnn_digit5_m3sda\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\nmodel_urls = {\n    \"alexnet\": \"https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth\",\n}\n\n\nclass AlexNet(Backbone):\n\n    def __init__(self):\n        super().__init__()\n        self.features = nn.Sequential(\n            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.MaxPool2d(kernel_size=3, stride=2),\n        )\n        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n        # Note that self.classifier outputs features rather than logits\n        self.classifier = nn.Sequential(\n            nn.Dropout(),\n            nn.Linear(256 * 6 * 6, 4096),\n            nn.ReLU(inplace=True),\n            nn.Dropout(),\n            nn.Linear(4096, 4096),\n            nn.ReLU(inplace=True),\n        )\n\n        self._out_features = 4096\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        return self.classifier(x)\n\n\ndef init_pretrained_weights(model, model_url):\n    pretrain_dict = model_zoo.load_url(model_url)\n    model.load_state_dict(pretrain_dict, strict=False)\n\n\n@BACKBONE_REGISTRY.register()\ndef alexnet(pretrained=True, **kwargs):\n    model = AlexNet()\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"alexnet\"])\n\n    return model\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py",
    "content": "import torch.nn as nn\n\n\nclass Backbone(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n    def forward(self):\n        pass\n\n    @property\n    def out_features(self):\n        \"\"\"Output feature dimension.\"\"\"\n        if self.__dict__.get(\"_out_features\") is None:\n            return None\n        return self._out_features\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py",
    "content": "from dassl.utils import Registry, check_availability\n\nBACKBONE_REGISTRY = Registry(\"BACKBONE\")\n\n\ndef build_backbone(name, verbose=True, **kwargs):\n    avai_backbones = BACKBONE_REGISTRY.registered_names()\n    check_availability(name, avai_backbones)\n    if verbose:\n        print(\"Backbone: {}\".format(name))\n    return BACKBONE_REGISTRY.get(name)(**kwargs)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py",
    "content": "\"\"\"\nReference\n\nhttps://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA\n\"\"\"\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\n\nclass FeatureExtractor(Backbone):\n\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=2)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=1, padding=2)\n        self.bn2 = nn.BatchNorm2d(64)\n        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2)\n        self.bn3 = nn.BatchNorm2d(128)\n        self.fc1 = nn.Linear(8192, 3072)\n        self.bn1_fc = nn.BatchNorm1d(3072)\n        self.fc2 = nn.Linear(3072, 2048)\n        self.bn2_fc = nn.BatchNorm1d(2048)\n\n        self._out_features = 2048\n\n    def _check_input(self, x):\n        H, W = x.shape[2:]\n        assert (\n            H == 32 and W == 32\n        ), \"Input to network must be 32x32, \" \"but got {}x{}\".format(H, W)\n\n    def forward(self, x):\n        self._check_input(x)\n        x = F.relu(self.bn1(self.conv1(x)))\n        x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)\n        x = F.relu(self.bn2(self.conv2(x)))\n        x = F.max_pool2d(x, stride=2, kernel_size=3, padding=1)\n        x = F.relu(self.bn3(self.conv3(x)))\n        x = x.view(x.size(0), 8192)\n        x = F.relu(self.bn1_fc(self.fc1(x)))\n        x = F.dropout(x, training=self.training)\n        x = F.relu(self.bn2_fc(self.fc2(x)))\n        return x\n\n\n@BACKBONE_REGISTRY.register()\ndef cnn_digit5_m3sda(**kwargs):\n    \"\"\"\n    This architecture was used for the Digit-5 dataset in:\n\n        - Peng et al. Moment Matching for Multi-Source\n        Domain Adaptation. ICCV 2019.\n    \"\"\"\n    return FeatureExtractor()\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsdg.py",
    "content": "import torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.utils import init_network_weights\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\n\nclass Convolution(nn.Module):\n\n    def __init__(self, c_in, c_out):\n        super().__init__()\n        self.conv = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1)\n        self.relu = nn.ReLU(True)\n\n    def forward(self, x):\n        return self.relu(self.conv(x))\n\n\nclass ConvNet(Backbone):\n\n    def __init__(self, c_hidden=64):\n        super().__init__()\n        self.conv1 = Convolution(3, c_hidden)\n        self.conv2 = Convolution(c_hidden, c_hidden)\n        self.conv3 = Convolution(c_hidden, c_hidden)\n        self.conv4 = Convolution(c_hidden, c_hidden)\n\n        self._out_features = 2**2 * c_hidden\n\n    def _check_input(self, x):\n        H, W = x.shape[2:]\n        assert (\n            H == 32 and W == 32\n        ), \"Input to network must be 32x32, \" \"but got {}x{}\".format(H, W)\n\n    def forward(self, x):\n        self._check_input(x)\n        x = self.conv1(x)\n        x = F.max_pool2d(x, 2)\n        x = self.conv2(x)\n        x = F.max_pool2d(x, 2)\n        x = self.conv3(x)\n        x = F.max_pool2d(x, 2)\n        x = self.conv4(x)\n        x = F.max_pool2d(x, 2)\n        return x.view(x.size(0), -1)\n\n\n@BACKBONE_REGISTRY.register()\ndef cnn_digitsdg(**kwargs):\n    \"\"\"\n    This architecture was used for DigitsDG dataset in:\n\n        - Zhou et al. Deep Domain-Adversarial Image Generation\n        for Domain Generalisation. AAAI 2020.\n    \"\"\"\n    model = ConvNet(c_hidden=64)\n    init_network_weights(model, init_type=\"kaiming\")\n    return model\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsingle.py",
    "content": "\"\"\"\nThis model is built based on\nhttps://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py\n\"\"\"\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.utils import init_network_weights\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\n\nclass CNN(Backbone):\n\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 64, 5)\n        self.conv2 = nn.Conv2d(64, 128, 5)\n        self.fc3 = nn.Linear(5 * 5 * 128, 1024)\n        self.fc4 = nn.Linear(1024, 1024)\n\n        self._out_features = 1024\n\n    def _check_input(self, x):\n        H, W = x.shape[2:]\n        assert (\n            H == 32 and W == 32\n        ), \"Input to network must be 32x32, \" \"but got {}x{}\".format(H, W)\n\n    def forward(self, x):\n        self._check_input(x)\n        x = self.conv1(x)\n        x = F.relu(x)\n        x = F.max_pool2d(x, 2)\n\n        x = self.conv2(x)\n        x = F.relu(x)\n        x = F.max_pool2d(x, 2)\n\n        x = x.view(x.size(0), -1)\n\n        x = self.fc3(x)\n        x = F.relu(x)\n\n        x = self.fc4(x)\n        x = F.relu(x)\n\n        return x\n\n\n@BACKBONE_REGISTRY.register()\ndef cnn_digitsingle(**kwargs):\n    model = CNN()\n    init_network_weights(model, init_type=\"kaiming\")\n    return model\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/__init__.py",
    "content": "\"\"\"\nSource: https://github.com/lukemelas/EfficientNet-PyTorch.\n\"\"\"\n__version__ = \"0.6.4\"\nfrom .model import (\n    EfficientNet, efficientnet_b0, efficientnet_b1, efficientnet_b2,\n    efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6,\n    efficientnet_b7\n)\nfrom .utils import (\n    BlockArgs, BlockDecoder, GlobalParams, efficientnet, get_model_params\n)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/model.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .utils import (\n    Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats,\n    get_model_params, efficientnet_params, get_same_padding_conv2d,\n    load_pretrained_weights, calculate_output_image_size\n)\nfrom ..build import BACKBONE_REGISTRY\nfrom ..backbone import Backbone\n\n\nclass MBConvBlock(nn.Module):\n    \"\"\"\n    Mobile Inverted Residual Bottleneck Block\n\n    Args:\n        block_args (namedtuple): BlockArgs, see above\n        global_params (namedtuple): GlobalParam, see above\n\n    Attributes:\n        has_se (bool): Whether the block contains a Squeeze and Excitation layer.\n    \"\"\"\n\n    def __init__(self, block_args, global_params, image_size=None):\n        super().__init__()\n        self._block_args = block_args\n        self._bn_mom = 1 - global_params.batch_norm_momentum\n        self._bn_eps = global_params.batch_norm_epsilon\n        self.has_se = (self._block_args.se_ratio\n                       is not None) and (0 < self._block_args.se_ratio <= 1)\n        self.id_skip = block_args.id_skip  # skip connection and drop connect\n\n        # Expansion phase\n        inp = self._block_args.input_filters  # number of input channels\n        oup = (\n            self._block_args.input_filters * self._block_args.expand_ratio\n        )  # number of output channels\n        if self._block_args.expand_ratio != 1:\n            Conv2d = get_same_padding_conv2d(image_size=image_size)\n            self._expand_conv = Conv2d(\n                in_channels=inp, out_channels=oup, kernel_size=1, bias=False\n            )\n            self._bn0 = nn.BatchNorm2d(\n                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps\n            )\n            # image_size = calculate_output_image_size(image_size, 1) <-- this would do nothing\n\n        # Depthwise convolution phase\n        k = self._block_args.kernel_size\n        
s = self._block_args.stride\n        Conv2d = get_same_padding_conv2d(image_size=image_size)\n        self._depthwise_conv = Conv2d(\n            in_channels=oup,\n            out_channels=oup,\n            groups=oup,  # groups makes it depthwise\n            kernel_size=k,\n            stride=s,\n            bias=False,\n        )\n        self._bn1 = nn.BatchNorm2d(\n            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps\n        )\n        image_size = calculate_output_image_size(image_size, s)\n\n        # Squeeze and Excitation layer, if desired\n        if self.has_se:\n            Conv2d = get_same_padding_conv2d(image_size=(1, 1))\n            num_squeezed_channels = max(\n                1,\n                int(\n                    self._block_args.input_filters * self._block_args.se_ratio\n                )\n            )\n            self._se_reduce = Conv2d(\n                in_channels=oup,\n                out_channels=num_squeezed_channels,\n                kernel_size=1\n            )\n            self._se_expand = Conv2d(\n                in_channels=num_squeezed_channels,\n                out_channels=oup,\n                kernel_size=1\n            )\n\n        # Output phase\n        final_oup = self._block_args.output_filters\n        Conv2d = get_same_padding_conv2d(image_size=image_size)\n        self._project_conv = Conv2d(\n            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False\n        )\n        self._bn2 = nn.BatchNorm2d(\n            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps\n        )\n        self._swish = MemoryEfficientSwish()\n\n    def forward(self, inputs, drop_connect_rate=None):\n        \"\"\"\n        :param inputs: input tensor\n        :param drop_connect_rate: drop connect rate (float, between 0 and 1)\n        :return: output of block\n        \"\"\"\n\n        # Expansion and Depthwise Convolution\n        x = inputs\n        if self._block_args.expand_ratio 
!= 1:\n            x = self._swish(self._bn0(self._expand_conv(inputs)))\n        x = self._swish(self._bn1(self._depthwise_conv(x)))\n\n        # Squeeze and Excitation\n        if self.has_se:\n            x_squeezed = F.adaptive_avg_pool2d(x, 1)\n            x_squeezed = self._se_expand(\n                self._swish(self._se_reduce(x_squeezed))\n            )\n            x = torch.sigmoid(x_squeezed) * x\n\n        x = self._bn2(self._project_conv(x))\n\n        # Skip connection and drop connect\n        input_filters, output_filters = (\n            self._block_args.input_filters,\n            self._block_args.output_filters,\n        )\n        if (\n            self.id_skip and self._block_args.stride == 1\n            and input_filters == output_filters\n        ):\n            if drop_connect_rate:\n                x = drop_connect(\n                    x, p=drop_connect_rate, training=self.training\n                )\n            x = x + inputs  # skip connection\n        return x\n\n    def set_swish(self, memory_efficient=True):\n        \"\"\"Sets swish function as memory efficient (for training) or standard (for export)\"\"\"\n        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()\n\n\nclass EfficientNet(Backbone):\n    \"\"\"\n    An EfficientNet model. 
Most easily loaded with the .from_name or .from_pretrained methods\n\n    Args:\n        blocks_args (list): A list of BlockArgs to construct blocks\n        global_params (namedtuple): A set of GlobalParams shared between blocks\n\n    Example:\n        model = EfficientNet.from_pretrained('efficientnet-b0')\n\n    \"\"\"\n\n    def __init__(self, blocks_args=None, global_params=None):\n        super().__init__()\n        assert isinstance(blocks_args, list), \"blocks_args should be a list\"\n        assert len(blocks_args) > 0, \"blocks_args must not be empty\"\n        self._global_params = global_params\n        self._blocks_args = blocks_args\n\n        # Batch norm parameters\n        bn_mom = 1 - self._global_params.batch_norm_momentum\n        bn_eps = self._global_params.batch_norm_epsilon\n\n        # Get stem static or dynamic convolution depending on image size\n        image_size = global_params.image_size\n        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)\n\n        # Stem\n        in_channels = 3  # rgb\n        out_channels = round_filters(\n            32, self._global_params\n        )  # number of output channels\n        self._conv_stem = Conv2d(\n            in_channels, out_channels, kernel_size=3, stride=2, bias=False\n        )\n        self._bn0 = nn.BatchNorm2d(\n            num_features=out_channels, momentum=bn_mom, eps=bn_eps\n        )\n        image_size = calculate_output_image_size(image_size, 2)\n\n        # Build blocks\n        self._blocks = nn.ModuleList([])\n        for block_args in self._blocks_args:\n\n            # Scale input/output filters by the width multiplier and the\n            # number of repeats by the depth multiplier.\n            block_args = block_args._replace(\n                input_filters=round_filters(\n                    block_args.input_filters, self._global_params\n                ),\n                output_filters=round_filters(\n                    block_args.output_filters, self._global_params\n                ),\n   
             num_repeat=round_repeats(\n                    block_args.num_repeat, self._global_params\n                ),\n            )\n\n            # The first block needs to take care of stride and filter size increase.\n            self._blocks.append(\n                MBConvBlock(\n                    block_args, self._global_params, image_size=image_size\n                )\n            )\n            image_size = calculate_output_image_size(\n                image_size, block_args.stride\n            )\n            if block_args.num_repeat > 1:\n                block_args = block_args._replace(\n                    input_filters=block_args.output_filters, stride=1\n                )\n            for _ in range(block_args.num_repeat - 1):\n                self._blocks.append(\n                    MBConvBlock(\n                        block_args, self._global_params, image_size=image_size\n                    )\n                )\n                # image_size = calculate_output_image_size(image_size, block_args.stride) # ?\n\n        # Head\n        in_channels = block_args.output_filters  # output of final block\n        out_channels = round_filters(1280, self._global_params)\n        Conv2d = get_same_padding_conv2d(image_size=image_size)\n        self._conv_head = Conv2d(\n            in_channels, out_channels, kernel_size=1, bias=False\n        )\n        self._bn1 = nn.BatchNorm2d(\n            num_features=out_channels, momentum=bn_mom, eps=bn_eps\n        )\n\n        # Final linear layer\n        self._avg_pooling = nn.AdaptiveAvgPool2d(1)\n        self._dropout = nn.Dropout(self._global_params.dropout_rate)\n        # self._fc = nn.Linear(out_channels, self._global_params.num_classes)\n        self._swish = MemoryEfficientSwish()\n\n        self._out_features = out_channels\n\n    def set_swish(self, memory_efficient=True):\n        \"\"\"Sets swish function as memory efficient (for training) or standard (for export)\"\"\"\n        self._swish = 
MemoryEfficientSwish() if memory_efficient else Swish()\n        for block in self._blocks:\n            block.set_swish(memory_efficient)\n\n    def extract_features(self, inputs):\n        \"\"\"Returns output of the final convolution layer\"\"\"\n\n        # Stem\n        x = self._swish(self._bn0(self._conv_stem(inputs)))\n\n        # Blocks\n        for idx, block in enumerate(self._blocks):\n            drop_connect_rate = self._global_params.drop_connect_rate\n            if drop_connect_rate:\n                drop_connect_rate *= float(idx) / len(self._blocks)\n            x = block(x, drop_connect_rate=drop_connect_rate)\n\n        # Head\n        x = self._swish(self._bn1(self._conv_head(x)))\n\n        return x\n\n    def forward(self, inputs):\n        \"\"\"\n        Calls extract_features to extract features, applies global\n        average pooling and dropout, and returns the pooled feature vector.\n        The final linear layer is disabled, so no logits are computed.\n        \"\"\"\n        bs = inputs.size(0)\n        # Convolution layers\n        x = self.extract_features(inputs)\n\n        # Pooling and dropout (final linear layer disabled)\n        x = self._avg_pooling(x)\n        x = x.view(bs, -1)\n        x = self._dropout(x)\n        # x = self._fc(x)\n        return x\n\n    @classmethod\n    def from_name(cls, model_name, override_params=None):\n        cls._check_model_name_is_valid(model_name)\n        blocks_args, global_params = get_model_params(\n            model_name, override_params\n        )\n        return cls(blocks_args, global_params)\n\n    @classmethod\n    def from_pretrained(\n        cls, model_name, advprop=False, num_classes=1000, in_channels=3\n    ):\n        model = cls.from_name(\n            model_name, override_params={\"num_classes\": num_classes}\n        )\n        load_pretrained_weights(\n            model, model_name, load_fc=(num_classes == 1000), advprop=advprop\n        )\n        model._change_in_channels(in_channels)\n        return model\n\n    @classmethod\n    def get_image_size(cls, model_name):\n        
cls._check_model_name_is_valid(model_name)\n        _, _, res, _ = efficientnet_params(model_name)\n        return res\n\n    @classmethod\n    def _check_model_name_is_valid(cls, model_name):\n        \"\"\"Validates model name.\"\"\"\n        valid_models = [\"efficientnet-b\" + str(i) for i in range(9)]\n        if model_name not in valid_models:\n            raise ValueError(\n                \"model_name should be one of: \" + \", \".join(valid_models)\n            )\n\n    def _change_in_channels(model, in_channels):\n        if in_channels != 3:\n            Conv2d = get_same_padding_conv2d(\n                image_size=model._global_params.image_size\n            )\n            out_channels = round_filters(32, model._global_params)\n            model._conv_stem = Conv2d(\n                in_channels, out_channels, kernel_size=3, stride=2, bias=False\n            )\n\n\ndef build_efficientnet(name, pretrained):\n    if pretrained:\n        return EfficientNet.from_pretrained(\"efficientnet-{}\".format(name))\n    else:\n        return EfficientNet.from_name(\"efficientnet-{}\".format(name))\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b0(pretrained=True, **kwargs):\n    return build_efficientnet(\"b0\", pretrained)\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b1(pretrained=True, **kwargs):\n    return build_efficientnet(\"b1\", pretrained)\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b2(pretrained=True, **kwargs):\n    return build_efficientnet(\"b2\", pretrained)\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b3(pretrained=True, **kwargs):\n    return build_efficientnet(\"b3\", pretrained)\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b4(pretrained=True, **kwargs):\n    return build_efficientnet(\"b4\", pretrained)\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b5(pretrained=True, **kwargs):\n    return build_efficientnet(\"b5\", pretrained)\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b6(pretrained=True, 
**kwargs):\n    return build_efficientnet(\"b6\", pretrained)\n\n\n@BACKBONE_REGISTRY.register()\ndef efficientnet_b7(pretrained=True, **kwargs):\n    return build_efficientnet(\"b7\", pretrained)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/utils.py",
    "content": "\"\"\"\nThis file contains helper functions for building the model and for loading model parameters.\nThese helper functions are built to mirror those in the official TensorFlow implementation.\n\"\"\"\n\nimport re\nimport math\nimport collections\nfrom functools import partial\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.utils import model_zoo\n\n########################################################################\n############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############\n########################################################################\n\n# Parameters for the entire model (stem, all blocks, and head)\nGlobalParams = collections.namedtuple(\n    \"GlobalParams\",\n    [\n        \"batch_norm_momentum\",\n        \"batch_norm_epsilon\",\n        \"dropout_rate\",\n        \"num_classes\",\n        \"width_coefficient\",\n        \"depth_coefficient\",\n        \"depth_divisor\",\n        \"min_depth\",\n        \"drop_connect_rate\",\n        \"image_size\",\n    ],\n)\n\n# Parameters for an individual model block\nBlockArgs = collections.namedtuple(\n    \"BlockArgs\",\n    [\n        \"kernel_size\",\n        \"num_repeat\",\n        \"input_filters\",\n        \"output_filters\",\n        \"expand_ratio\",\n        \"id_skip\",\n        \"stride\",\n        \"se_ratio\",\n    ],\n)\n\n# Change namedtuple defaults\nGlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)\nBlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)\n\n\nclass SwishImplementation(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, i):\n        result = i * torch.sigmoid(i)\n        ctx.save_for_backward(i)\n        return result\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        # saved_tensors replaces the deprecated saved_variables attribute\n        i = ctx.saved_tensors[0]\n        sigmoid_i = torch.sigmoid(i)\n        return grad_output * (sigmoid_i * (1 + i * (1-sigmoid_i)))\n\n\nclass 
MemoryEfficientSwish(nn.Module):\n\n    def forward(self, x):\n        return SwishImplementation.apply(x)\n\n\nclass Swish(nn.Module):\n\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\ndef round_filters(filters, global_params):\n    \"\"\"Calculate and round number of filters based on width multiplier.\"\"\"\n    multiplier = global_params.width_coefficient\n    if not multiplier:\n        return filters\n    divisor = global_params.depth_divisor\n    min_depth = global_params.min_depth\n    filters *= multiplier\n    min_depth = min_depth or divisor\n    new_filters = max(min_depth, int(filters + divisor/2) // divisor * divisor)\n    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%\n        new_filters += divisor\n    return int(new_filters)\n\n\ndef round_repeats(repeats, global_params):\n    \"\"\"Round number of block repeats based on depth multiplier.\"\"\"\n    multiplier = global_params.depth_coefficient\n    if not multiplier:\n        return repeats\n    return int(math.ceil(multiplier * repeats))\n\n\ndef drop_connect(inputs, p, training):\n    \"\"\"Drop connect: during training, zero each sample's output with\n    probability p and rescale the kept samples by 1 / (1 - p).\"\"\"\n    if not training:\n        return inputs\n    batch_size = inputs.shape[0]\n    keep_prob = 1 - p\n    random_tensor = keep_prob\n    random_tensor += torch.rand(\n        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device\n    )\n    binary_tensor = torch.floor(random_tensor)\n    output = inputs / keep_prob * binary_tensor\n    return output\n\n\ndef get_same_padding_conv2d(image_size=None):\n    \"\"\"Chooses static padding if you have specified an image size, and dynamic padding otherwise.\n    Static padding is necessary for ONNX exporting of models.\"\"\"\n    if image_size is None:\n        return Conv2dDynamicSamePadding\n    else:\n        return partial(Conv2dStaticSamePadding, image_size=image_size)\n\n\ndef get_width_and_height_from_size(x):\n    \"\"\"Obtains width and height from an int, list, or tuple\"\"\"\n    if isinstance(x, 
int):\n        return x, x\n    if isinstance(x, list) or isinstance(x, tuple):\n        return x\n    else:\n        raise TypeError()\n\n\ndef calculate_output_image_size(input_image_size, stride):\n    \"\"\"\n    Calculates the output image size when using Conv2dSamePadding with a stride.\n    Necessary for static padding. Thanks to mannatsingh for pointing this out.\n    \"\"\"\n    if input_image_size is None:\n        return None\n    image_height, image_width = get_width_and_height_from_size(\n        input_image_size\n    )\n    stride = stride if isinstance(stride, int) else stride[0]\n    image_height = int(math.ceil(image_height / stride))\n    image_width = int(math.ceil(image_width / stride))\n    return [image_height, image_width]\n\n\nclass Conv2dDynamicSamePadding(nn.Conv2d):\n    \"\"\"2D Convolutions like TensorFlow, for a dynamic image size\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        stride=1,\n        dilation=1,\n        groups=1,\n        bias=True,\n    ):\n        super().__init__(\n            in_channels, out_channels, kernel_size, stride, 0, dilation,\n            groups, bias\n        )\n        self.stride = self.stride if len(self.stride\n                                         ) == 2 else [self.stride[0]] * 2\n\n    def forward(self, x):\n        ih, iw = x.size()[-2:]\n        kh, kw = self.weight.size()[-2:]\n        sh, sw = self.stride\n        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)\n        pad_h = max(\n            (oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0\n        )\n        pad_w = max(\n            (ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0\n        )\n        if pad_h > 0 or pad_w > 0:\n            x = F.pad(\n                x,\n                [pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2]\n            )\n        return F.conv2d(\n            x,\n            self.weight,\n            
self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.groups,\n        )\n\n\nclass Conv2dStaticSamePadding(nn.Conv2d):\n    \"\"\"2D Convolutions like TensorFlow, for a fixed image size\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        image_size=None,\n        **kwargs\n    ):\n        super().__init__(in_channels, out_channels, kernel_size, **kwargs)\n        self.stride = self.stride if len(self.stride\n                                         ) == 2 else [self.stride[0]] * 2\n\n        # Calculate padding based on image size and save it\n        assert image_size is not None\n        ih, iw = (image_size,\n                  image_size) if isinstance(image_size, int) else image_size\n        kh, kw = self.weight.size()[-2:]\n        sh, sw = self.stride\n        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)\n        pad_h = max(\n            (oh-1) * self.stride[0] + (kh-1) * self.dilation[0] + 1 - ih, 0\n        )\n        pad_w = max(\n            (ow-1) * self.stride[1] + (kw-1) * self.dilation[1] + 1 - iw, 0\n        )\n        if pad_h > 0 or pad_w > 0:\n            self.static_padding = nn.ZeroPad2d(\n                (pad_w // 2, pad_w - pad_w//2, pad_h // 2, pad_h - pad_h//2)\n            )\n        else:\n            self.static_padding = Identity()\n\n    def forward(self, x):\n        x = self.static_padding(x)\n        x = F.conv2d(\n            x,\n            self.weight,\n            self.bias,\n            self.stride,\n            self.padding,\n            self.dilation,\n            self.groups,\n        )\n        return x\n\n\nclass Identity(nn.Module):\n\n    def __init__(self, ):\n        super(Identity, self).__init__()\n\n    def forward(self, input):\n        return input\n\n\n########################################################################\n############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS 
##############\n########################################################################\n\n\ndef efficientnet_params(model_name):\n    \"\"\"Map EfficientNet model name to parameter coefficients.\"\"\"\n    params_dict = {\n        # Coefficients:   width,depth,res,dropout\n        \"efficientnet-b0\": (1.0, 1.0, 224, 0.2),\n        \"efficientnet-b1\": (1.0, 1.1, 240, 0.2),\n        \"efficientnet-b2\": (1.1, 1.2, 260, 0.3),\n        \"efficientnet-b3\": (1.2, 1.4, 300, 0.3),\n        \"efficientnet-b4\": (1.4, 1.8, 380, 0.4),\n        \"efficientnet-b5\": (1.6, 2.2, 456, 0.4),\n        \"efficientnet-b6\": (1.8, 2.6, 528, 0.5),\n        \"efficientnet-b7\": (2.0, 3.1, 600, 0.5),\n        \"efficientnet-b8\": (2.2, 3.6, 672, 0.5),\n        \"efficientnet-l2\": (4.3, 5.3, 800, 0.5),\n    }\n    return params_dict[model_name]\n\n\nclass BlockDecoder(object):\n    \"\"\"Block Decoder for readability, straight from the official TensorFlow repository\"\"\"\n\n    @staticmethod\n    def _decode_block_string(block_string):\n        \"\"\"Gets a block through a string notation of arguments.\"\"\"\n        assert isinstance(block_string, str)\n\n        ops = block_string.split(\"_\")\n        options = {}\n        for op in ops:\n            splits = re.split(r\"(\\d.*)\", op)\n            if len(splits) >= 2:\n                key, value = splits[:2]\n                options[key] = value\n\n        # Check stride\n        assert (\"s\" in options and len(options[\"s\"]) == 1) or (\n            len(options[\"s\"]) == 2 and options[\"s\"][0] == options[\"s\"][1]\n        )\n\n        return BlockArgs(\n            kernel_size=int(options[\"k\"]),\n            num_repeat=int(options[\"r\"]),\n            input_filters=int(options[\"i\"]),\n            output_filters=int(options[\"o\"]),\n            expand_ratio=int(options[\"e\"]),\n            id_skip=(\"noskip\" not in block_string),\n            se_ratio=float(options[\"se\"]) if \"se\" in options else None,\n           
 stride=[int(options[\"s\"][0])],\n        )\n\n    @staticmethod\n    def _encode_block_string(block):\n        \"\"\"Encodes a block to a string.\"\"\"\n        # BlockArgs stores `stride` either as a one-element list or as an int\n        stride = block.stride if isinstance(block.stride, int) else block.stride[0]\n        args = [\n            \"r%d\" % block.num_repeat,\n            \"k%d\" % block.kernel_size,\n            \"s%d%d\" % (stride, stride),\n            \"e%s\" % block.expand_ratio,\n            \"i%d\" % block.input_filters,\n            \"o%d\" % block.output_filters,\n        ]\n        if block.se_ratio is not None and 0 < block.se_ratio <= 1:\n            args.append(\"se%s\" % block.se_ratio)\n        if block.id_skip is False:\n            args.append(\"noskip\")\n        return \"_\".join(args)\n\n    @staticmethod\n    def decode(string_list):\n        \"\"\"\n        Decodes a list of string notations to specify blocks inside the network.\n\n        :param string_list: a list of strings, each string is a notation of block\n        :return: a list of BlockArgs namedtuples of block args\n        \"\"\"\n        assert isinstance(string_list, list)\n        blocks_args = []\n        for block_string in string_list:\n            blocks_args.append(BlockDecoder._decode_block_string(block_string))\n        return blocks_args\n\n    @staticmethod\n    def encode(blocks_args):\n        \"\"\"\n        Encodes a list of BlockArgs to a list of strings.\n\n        :param blocks_args: a list of BlockArgs namedtuples of block args\n        :return: a list of strings, each string is a notation of block\n        \"\"\"\n        block_strings = []\n        for block in blocks_args:\n            block_strings.append(BlockDecoder._encode_block_string(block))\n        return block_strings\n\n\ndef efficientnet(\n    width_coefficient=None,\n    depth_coefficient=None,\n    dropout_rate=0.2,\n    drop_connect_rate=0.2,\n    image_size=None,\n    num_classes=1000,\n):\n    \"\"\"Creates the block args and global params for an EfficientNet model.\"\"\"\n\n    blocks_args = [\n        \"r1_k3_s11_e1_i32_o16_se0.25\",\n        \"r2_k3_s22_e6_i16_o24_se0.25\",\n        
\"r2_k5_s22_e6_i24_o40_se0.25\",\n        \"r3_k3_s22_e6_i40_o80_se0.25\",\n        \"r3_k5_s11_e6_i80_o112_se0.25\",\n        \"r4_k5_s22_e6_i112_o192_se0.25\",\n        \"r1_k3_s11_e6_i192_o320_se0.25\",\n    ]\n    blocks_args = BlockDecoder.decode(blocks_args)\n\n    global_params = GlobalParams(\n        batch_norm_momentum=0.99,\n        batch_norm_epsilon=1e-3,\n        dropout_rate=dropout_rate,\n        drop_connect_rate=drop_connect_rate,\n        # data_format='channels_last',  # removed, this is always true in PyTorch\n        num_classes=num_classes,\n        width_coefficient=width_coefficient,\n        depth_coefficient=depth_coefficient,\n        depth_divisor=8,\n        min_depth=None,\n        image_size=image_size,\n    )\n\n    return blocks_args, global_params\n\n\ndef get_model_params(model_name, override_params):\n    \"\"\"Get the block args and global params for a given model\"\"\"\n    if model_name.startswith(\"efficientnet\"):\n        w, d, s, p = efficientnet_params(model_name)\n        # note: all models have drop connect rate = 0.2\n        blocks_args, global_params = efficientnet(\n            width_coefficient=w,\n            depth_coefficient=d,\n            dropout_rate=p,\n            image_size=s\n        )\n    else:\n        raise NotImplementedError(\n            \"model name is not pre-defined: %s\" % model_name\n        )\n    if override_params:\n        # ValueError will be raised here if override_params has fields not included in global_params.\n        global_params = global_params._replace(**override_params)\n    return blocks_args, global_params\n\n\nurl_map = {\n    \"efficientnet-b0\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth\",\n    \"efficientnet-b1\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth\",\n    \"efficientnet-b2\":\n    
\"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth\",\n    \"efficientnet-b3\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth\",\n    \"efficientnet-b4\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth\",\n    \"efficientnet-b5\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth\",\n    \"efficientnet-b6\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth\",\n    \"efficientnet-b7\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth\",\n}\n\nurl_map_advprop = {\n    \"efficientnet-b0\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth\",\n    \"efficientnet-b1\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth\",\n    \"efficientnet-b2\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth\",\n    \"efficientnet-b3\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth\",\n    \"efficientnet-b4\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth\",\n    \"efficientnet-b5\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth\",\n    \"efficientnet-b6\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth\",\n    \"efficientnet-b7\":\n    \"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth\",\n    \"efficientnet-b8\":\n    
\"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth\",\n}\n\n\ndef load_pretrained_weights(model, model_name, load_fc=True, advprop=False):\n    \"\"\"Loads pretrained weights, and downloads if loading for the first time.\"\"\"\n    # AutoAugment or Advprop (different preprocessing)\n    url_map_ = url_map_advprop if advprop else url_map\n    state_dict = model_zoo.load_url(url_map_[model_name])\n    model.load_state_dict(state_dict, strict=False)\n    \"\"\"\n    if load_fc:\n        model.load_state_dict(state_dict)\n    else:\n        state_dict.pop('_fc.weight')\n        state_dict.pop('_fc.bias')\n        res = model.load_state_dict(state_dict, strict=False)\n        assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'\n\n    print('Loaded pretrained weights for {}'.format(model_name))\n    \"\"\"\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py",
    "content": "import torch.utils.model_zoo as model_zoo\nfrom torch import nn\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\nmodel_urls = {\n    \"mobilenet_v2\":\n    \"https://download.pytorch.org/models/mobilenet_v2-b0353104.pth\",\n}\n\n\ndef _make_divisible(v, divisor, min_value=None):\n    \"\"\"\n    This function is taken from the original tf repo.\n    It ensures that all layers have a channel number that is divisible by 8\n    It can be seen here:\n    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py\n    :param v:\n    :param divisor:\n    :param min_value:\n    :return:\n    \"\"\"\n    if min_value is None:\n        min_value = divisor\n    new_v = max(min_value, int(v + divisor/2) // divisor * divisor)\n    # Make sure that round down does not go down by more than 10%.\n    if new_v < 0.9 * v:\n        new_v += divisor\n    return new_v\n\n\nclass ConvBNReLU(nn.Sequential):\n\n    def __init__(\n        self, in_planes, out_planes, kernel_size=3, stride=1, groups=1\n    ):\n        padding = (kernel_size-1) // 2\n        super().__init__(\n            nn.Conv2d(\n                in_planes,\n                out_planes,\n                kernel_size,\n                stride,\n                padding,\n                groups=groups,\n                bias=False,\n            ),\n            nn.BatchNorm2d(out_planes),\n            nn.ReLU6(inplace=True),\n        )\n\n\nclass InvertedResidual(nn.Module):\n\n    def __init__(self, inp, oup, stride, expand_ratio):\n        super().__init__()\n        self.stride = stride\n        assert stride in [1, 2]\n\n        hidden_dim = int(round(inp * expand_ratio))\n        self.use_res_connect = self.stride == 1 and inp == oup\n\n        layers = []\n        if expand_ratio != 1:\n            # pw\n            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))\n        layers.extend(\n            [\n                # dw\n                
ConvBNReLU(\n                    hidden_dim, hidden_dim, stride=stride, groups=hidden_dim\n                ),\n                # pw-linear\n                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),\n                nn.BatchNorm2d(oup),\n            ]\n        )\n        self.conv = nn.Sequential(*layers)\n\n    def forward(self, x):\n        if self.use_res_connect:\n            return x + self.conv(x)\n        else:\n            return self.conv(x)\n\n\nclass MobileNetV2(Backbone):\n\n    def __init__(\n        self,\n        width_mult=1.0,\n        inverted_residual_setting=None,\n        round_nearest=8,\n        block=None,\n    ):\n        \"\"\"\n        MobileNet V2.\n\n        Args:\n            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount\n            inverted_residual_setting: Network structure\n            round_nearest (int): Round the number of channels in each layer to be a multiple of this number\n            Set to 1 to turn off rounding\n            block: Module specifying inverted residual building block for mobilenet\n        \"\"\"\n        super().__init__()\n\n        if block is None:\n            block = InvertedResidual\n        input_channel = 32\n        last_channel = 1280\n\n        if inverted_residual_setting is None:\n            inverted_residual_setting = [\n                # t, c, n, s\n                [1, 16, 1, 1],\n                [6, 24, 2, 2],\n                [6, 32, 3, 2],\n                [6, 64, 4, 2],\n                [6, 96, 3, 1],\n                [6, 160, 3, 2],\n                [6, 320, 1, 1],\n            ]\n\n        # only check the first element, assuming user knows t,c,n,s are required\n        if (\n            len(inverted_residual_setting) == 0\n            or len(inverted_residual_setting[0]) != 4\n        ):\n            raise ValueError(\n                \"inverted_residual_setting should be non-empty \"\n                \"or a 4-element list, got 
{}\".\n                format(inverted_residual_setting)\n            )\n\n        # building first layer\n        input_channel = _make_divisible(\n            input_channel * width_mult, round_nearest\n        )\n        self.last_channel = _make_divisible(\n            last_channel * max(1.0, width_mult), round_nearest\n        )\n        features = [ConvBNReLU(3, input_channel, stride=2)]\n        # building inverted residual blocks\n        for t, c, n, s in inverted_residual_setting:\n            output_channel = _make_divisible(c * width_mult, round_nearest)\n            for i in range(n):\n                stride = s if i == 0 else 1\n                features.append(\n                    block(\n                        input_channel, output_channel, stride, expand_ratio=t\n                    )\n                )\n                input_channel = output_channel\n        # building last several layers\n        features.append(\n            ConvBNReLU(input_channel, self.last_channel, kernel_size=1)\n        )\n        # make it nn.Sequential\n        self.features = nn.Sequential(*features)\n\n        self._out_features = self.last_channel\n\n        # weight initialization\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode=\"fan_out\")\n                if m.bias is not None:\n                    nn.init.zeros_(m.bias)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.ones_(m.weight)\n                nn.init.zeros_(m.bias)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                nn.init.zeros_(m.bias)\n\n    def _forward_impl(self, x):\n        # This exists since TorchScript doesn't support inheritance, so the superclass method\n        # (this one) needs to have a name other than `forward` that can be accessed in a subclass\n        x = self.features(x)\n        x = x.mean([2, 3])\n        return 
x\n\n    def forward(self, x):\n        return self._forward_impl(x)\n\n\ndef init_pretrained_weights(model, model_url):\n    \"\"\"Initializes model with pretrained weights.\n\n    Layers that don't match with pretrained layers in name or size are kept unchanged.\n    \"\"\"\n    if model_url is None:\n        import warnings\n\n        warnings.warn(\n            \"ImageNet pretrained weights are unavailable for this model\"\n        )\n        return\n    pretrain_dict = model_zoo.load_url(model_url)\n    model_dict = model.state_dict()\n    pretrain_dict = {\n        k: v\n        for k, v in pretrain_dict.items()\n        if k in model_dict and model_dict[k].size() == v.size()\n    }\n    model_dict.update(pretrain_dict)\n    model.load_state_dict(model_dict)\n\n\n@BACKBONE_REGISTRY.register()\ndef mobilenetv2(pretrained=True, **kwargs):\n    model = MobileNetV2(**kwargs)\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"mobilenet_v2\"])\n    return model\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\n\nclass PreActBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, in_planes, planes, stride=1):\n        super().__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(\n            in_planes,\n            planes,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            bias=False\n        )\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(\n            planes, planes, kernel_size=3, stride=1, padding=1, bias=False\n        )\n\n        if stride != 1 or in_planes != self.expansion * planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(\n                    in_planes,\n                    self.expansion * planes,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False,\n                )\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(x))\n        shortcut = self.shortcut(out) if hasattr(self, \"shortcut\") else x\n        out = self.conv1(out)\n        out = self.conv2(F.relu(self.bn2(out)))\n        out += shortcut\n        return out\n\n\nclass PreActBottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, in_planes, planes, stride=1):\n        super().__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(\n            planes,\n            planes,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            bias=False\n        )\n        self.bn3 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(\n            planes, self.expansion * planes, kernel_size=1, bias=False\n        )\n\n        if stride != 1 or 
in_planes != self.expansion * planes:\n            self.shortcut = nn.Sequential(\n                nn.Conv2d(\n                    in_planes,\n                    self.expansion * planes,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False,\n                )\n            )\n\n    def forward(self, x):\n        out = F.relu(self.bn1(x))\n        shortcut = self.shortcut(out) if hasattr(self, \"shortcut\") else x\n        out = self.conv1(out)\n        out = self.conv2(F.relu(self.bn2(out)))\n        out = self.conv3(F.relu(self.bn3(out)))\n        out += shortcut\n        return out\n\n\nclass PreActResNet(Backbone):\n\n    def __init__(self, block, num_blocks):\n        super().__init__()\n        self.in_planes = 64\n\n        self.conv1 = nn.Conv2d(\n            3, 64, kernel_size=3, stride=1, padding=1, bias=False\n        )\n        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)\n        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)\n\n        self._out_features = 512 * block.expansion\n\n    def _make_layer(self, block, planes, num_blocks, stride):\n        strides = [stride] + [1] * (num_blocks-1)\n        layers = []\n        for stride in strides:\n            layers.append(block(self.in_planes, planes, stride))\n            self.in_planes = planes * block.expansion\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.layer1(out)\n        out = self.layer2(out)\n        out = self.layer3(out)\n        out = self.layer4(out)\n        out = F.avg_pool2d(out, 4)\n        out = out.view(out.size(0), -1)\n        return out\n\n\n\"\"\"\nPreact-ResNet18 was used for the CIFAR10 and\nSVHN datasets (both are SSL tasks) in\n\n- Wang et al. 
Semi-Supervised Learning by\nAugmented Distribution Alignment. ICCV 2019.\n\"\"\"\n\n\n@BACKBONE_REGISTRY.register()\ndef preact_resnet18(**kwargs):\n    return PreActResNet(PreActBlock, [2, 2, 2, 2])\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py",
    "content": "import torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\nmodel_urls = {\n    \"resnet18\": \"https://download.pytorch.org/models/resnet18-5c106cde.pth\",\n    \"resnet34\": \"https://download.pytorch.org/models/resnet34-333f7ec4.pth\",\n    \"resnet50\": \"https://download.pytorch.org/models/resnet50-19c8e357.pth\",\n    \"resnet101\": \"https://download.pytorch.org/models/resnet101-5d3b4d8f.pth\",\n    \"resnet152\": \"https://download.pytorch.org/models/resnet152-b121ed2d.pth\",\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(\n        in_planes,\n        out_planes,\n        kernel_size=3,\n        stride=stride,\n        padding=1,\n        bias=False\n    )\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super().__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super().__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(\n            planes,\n        
    planes,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            bias=False\n        )\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(\n            planes, planes * self.expansion, kernel_size=1, bias=False\n        )\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(Backbone):\n\n    def __init__(\n        self,\n        block,\n        layers,\n        ms_class=None,\n        ms_layers=[],\n        ms_p=0.5,\n        ms_a=0.1,\n        **kwargs\n    ):\n        self.inplanes = 64\n        super().__init__()\n\n        # backbone network\n        self.conv1 = nn.Conv2d(\n            3, 64, kernel_size=7, stride=2, padding=3, bias=False\n        )\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.global_avgpool = nn.AdaptiveAvgPool2d(1)\n\n        self._out_features = 512 * block.expansion\n\n        self.mixstyle = None\n        if ms_layers:\n            self.mixstyle = ms_class(p=ms_p, alpha=ms_a)\n            for 
layer_name in ms_layers:\n                assert layer_name in [\"layer1\", \"layer2\", \"layer3\"]\n            print(f\"Insert {ms_class.__name__} after {ms_layers}\")\n        self.ms_layers = ms_layers\n\n        self._init_params()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(\n                    self.inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False,\n                ),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def _init_params(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(\n                    m.weight, mode=\"fan_out\", nonlinearity=\"relu\"\n                )\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm1d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def featuremaps(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n        x = self.layer1(x)\n        if 
\"layer1\" in self.ms_layers:\n            x = self.mixstyle(x)\n        x = self.layer2(x)\n        if \"layer2\" in self.ms_layers:\n            x = self.mixstyle(x)\n        x = self.layer3(x)\n        if \"layer3\" in self.ms_layers:\n            x = self.mixstyle(x)\n        return self.layer4(x)\n\n    def forward(self, x):\n        f = self.featuremaps(x)\n        v = self.global_avgpool(f)\n        return v.view(v.size(0), -1)\n\n\ndef init_pretrained_weights(model, model_url):\n    pretrain_dict = model_zoo.load_url(model_url)\n    model.load_state_dict(pretrain_dict, strict=False)\n\n\n\"\"\"\nResidual network configurations:\n--\nresnet18: block=BasicBlock, layers=[2, 2, 2, 2]\nresnet34: block=BasicBlock, layers=[3, 4, 6, 3]\nresnet50: block=Bottleneck, layers=[3, 4, 6, 3]\nresnet101: block=Bottleneck, layers=[3, 4, 23, 3]\nresnet152: block=Bottleneck, layers=[3, 8, 36, 3]\n\"\"\"\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet18(pretrained=True, **kwargs):\n    model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2])\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet18\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet34(pretrained=True, **kwargs):\n    model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3])\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet34\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet50(pretrained=True, **kwargs):\n    model = ResNet(block=Bottleneck, layers=[3, 4, 6, 3])\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet50\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet101(pretrained=True, **kwargs):\n    model = ResNet(block=Bottleneck, layers=[3, 4, 23, 3])\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet101\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet152(pretrained=True, **kwargs):\n    model = ResNet(block=Bottleneck, 
layers=[3, 8, 36, 3])\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet152\"])\n\n    return model\n\n\n\"\"\"\nResidual networks with mixstyle\n\"\"\"\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet18_ms_l123(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=BasicBlock,\n        layers=[2, 2, 2, 2],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\", \"layer2\", \"layer3\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet18\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet18_ms_l12(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=BasicBlock,\n        layers=[2, 2, 2, 2],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\", \"layer2\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet18\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet18_ms_l1(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=BasicBlock,\n        layers=[2, 2, 2, 2],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\"]\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet18\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet50_ms_l123(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 6, 3],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\", \"layer2\", \"layer3\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet50\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet50_ms_l12(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 6, 3],\n      
  ms_class=MixStyle,\n        ms_layers=[\"layer1\", \"layer2\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet50\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet50_ms_l1(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 6, 3],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\"]\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet50\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet101_ms_l123(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 23, 3],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\", \"layer2\", \"layer3\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet101\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet101_ms_l12(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 23, 3],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\", \"layer2\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet101\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet101_ms_l1(pretrained=True, **kwargs):\n    from dassl.modeling.ops import MixStyle\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 23, 3],\n        ms_class=MixStyle,\n        ms_layers=[\"layer1\"]\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet101\"])\n\n    return model\n\n\n\"\"\"\nResidual networks with efdmix\n\"\"\"\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet18_efdmix_l123(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=BasicBlock,\n   
     layers=[2, 2, 2, 2],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\", \"layer2\", \"layer3\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet18\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet18_efdmix_l12(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=BasicBlock,\n        layers=[2, 2, 2, 2],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\", \"layer2\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet18\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet18_efdmix_l1(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=BasicBlock,\n        layers=[2, 2, 2, 2],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\"]\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet18\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet50_efdmix_l123(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 6, 3],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\", \"layer2\", \"layer3\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet50\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet50_efdmix_l12(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 6, 3],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\", \"layer2\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet50\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet50_efdmix_l1(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=Bottleneck,\n        
layers=[3, 4, 6, 3],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\"]\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet50\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet101_efdmix_l123(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 23, 3],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\", \"layer2\", \"layer3\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet101\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet101_efdmix_l12(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 23, 3],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\", \"layer2\"],\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet101\"])\n\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef resnet101_efdmix_l1(pretrained=True, **kwargs):\n    from dassl.modeling.ops import EFDMix\n\n    model = ResNet(\n        block=Bottleneck,\n        layers=[3, 4, 23, 3],\n        ms_class=EFDMix,\n        ms_layers=[\"layer1\"]\n    )\n\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"resnet101\"])\n\n    return model\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py",
    "content": "\"\"\"\nCode source: https://github.com/pytorch/vision\n\"\"\"\nimport torch\nimport torch.utils.model_zoo as model_zoo\nfrom torch import nn\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\nmodel_urls = {\n    \"shufflenetv2_x0.5\":\n    \"https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth\",\n    \"shufflenetv2_x1.0\":\n    \"https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth\",\n    \"shufflenetv2_x1.5\": None,\n    \"shufflenetv2_x2.0\": None,\n}\n\n\ndef channel_shuffle(x, groups):\n    batchsize, num_channels, height, width = x.data.size()\n    channels_per_group = num_channels // groups\n\n    # reshape\n    x = x.view(batchsize, groups, channels_per_group, height, width)\n\n    x = torch.transpose(x, 1, 2).contiguous()\n\n    # flatten\n    x = x.view(batchsize, -1, height, width)\n\n    return x\n\n\nclass InvertedResidual(nn.Module):\n\n    def __init__(self, inp, oup, stride):\n        super().__init__()\n\n        if not (1 <= stride <= 3):\n            raise ValueError(\"illegal stride value\")\n        self.stride = stride\n\n        branch_features = oup // 2\n        assert (self.stride != 1) or (inp == branch_features << 1)\n\n        if self.stride > 1:\n            self.branch1 = nn.Sequential(\n                self.depthwise_conv(\n                    inp, inp, kernel_size=3, stride=self.stride, padding=1\n                ),\n                nn.BatchNorm2d(inp),\n                nn.Conv2d(\n                    inp,\n                    branch_features,\n                    kernel_size=1,\n                    stride=1,\n                    padding=0,\n                    bias=False\n                ),\n                nn.BatchNorm2d(branch_features),\n                nn.ReLU(inplace=True),\n            )\n\n        self.branch2 = nn.Sequential(\n            nn.Conv2d(\n                inp if (self.stride > 1) else branch_features,\n                branch_features,\n       
         kernel_size=1,\n                stride=1,\n                padding=0,\n                bias=False,\n            ),\n            nn.BatchNorm2d(branch_features),\n            nn.ReLU(inplace=True),\n            self.depthwise_conv(\n                branch_features,\n                branch_features,\n                kernel_size=3,\n                stride=self.stride,\n                padding=1,\n            ),\n            nn.BatchNorm2d(branch_features),\n            nn.Conv2d(\n                branch_features,\n                branch_features,\n                kernel_size=1,\n                stride=1,\n                padding=0,\n                bias=False,\n            ),\n            nn.BatchNorm2d(branch_features),\n            nn.ReLU(inplace=True),\n        )\n\n    @staticmethod\n    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):\n        return nn.Conv2d(\n            i, o, kernel_size, stride, padding, bias=bias, groups=i\n        )\n\n    def forward(self, x):\n        if self.stride == 1:\n            x1, x2 = x.chunk(2, dim=1)\n            out = torch.cat((x1, self.branch2(x2)), dim=1)\n        else:\n            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)\n\n        out = channel_shuffle(out, 2)\n\n        return out\n\n\nclass ShuffleNetV2(Backbone):\n\n    def __init__(self, stages_repeats, stages_out_channels, **kwargs):\n        super().__init__()\n        if len(stages_repeats) != 3:\n            raise ValueError(\n                \"expected stages_repeats as list of 3 positive ints\"\n            )\n        if len(stages_out_channels) != 5:\n            raise ValueError(\n                \"expected stages_out_channels as list of 5 positive ints\"\n            )\n        self._stage_out_channels = stages_out_channels\n\n        input_channels = 3\n        output_channels = self._stage_out_channels[0]\n        self.conv1 = nn.Sequential(\n            nn.Conv2d(input_channels, output_channels, 3, 2, 1, 
bias=False),\n            nn.BatchNorm2d(output_channels),\n            nn.ReLU(inplace=True),\n        )\n        input_channels = output_channels\n\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n\n        stage_names = [\"stage{}\".format(i) for i in [2, 3, 4]]\n        for name, repeats, output_channels in zip(\n            stage_names, stages_repeats, self._stage_out_channels[1:]\n        ):\n            seq = [InvertedResidual(input_channels, output_channels, 2)]\n            for i in range(repeats - 1):\n                seq.append(\n                    InvertedResidual(output_channels, output_channels, 1)\n                )\n            setattr(self, name, nn.Sequential(*seq))\n            input_channels = output_channels\n\n        output_channels = self._stage_out_channels[-1]\n        self.conv5 = nn.Sequential(\n            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),\n            nn.BatchNorm2d(output_channels),\n            nn.ReLU(inplace=True),\n        )\n        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))\n\n        self._out_features = output_channels\n\n    def featuremaps(self, x):\n        x = self.conv1(x)\n        x = self.maxpool(x)\n        x = self.stage2(x)\n        x = self.stage3(x)\n        x = self.stage4(x)\n        x = self.conv5(x)\n        return x\n\n    def forward(self, x):\n        f = self.featuremaps(x)\n        v = self.global_avgpool(f)\n        return v.view(v.size(0), -1)\n\n\ndef init_pretrained_weights(model, model_url):\n    \"\"\"Initializes model with pretrained weights.\n\n    Layers that don't match with pretrained layers in name or size are kept unchanged.\n    \"\"\"\n    if model_url is None:\n        import warnings\n\n        warnings.warn(\n            \"ImageNet pretrained weights are unavailable for this model\"\n        )\n        return\n    pretrain_dict = model_zoo.load_url(model_url)\n    model_dict = model.state_dict()\n    pretrain_dict = {\n      
  k: v\n        for k, v in pretrain_dict.items()\n        if k in model_dict and model_dict[k].size() == v.size()\n    }\n    model_dict.update(pretrain_dict)\n    model.load_state_dict(model_dict)\n\n\n@BACKBONE_REGISTRY.register()\ndef shufflenet_v2_x0_5(pretrained=True, **kwargs):\n    model = ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"shufflenetv2_x0.5\"])\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef shufflenet_v2_x1_0(pretrained=True, **kwargs):\n    model = ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"shufflenetv2_x1.0\"])\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef shufflenet_v2_x1_5(pretrained=True, **kwargs):\n    model = ShuffleNetV2([4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"shufflenetv2_x1.5\"])\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef shufflenet_v2_x2_0(pretrained=True, **kwargs):\n    model = ShuffleNetV2([4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)\n    if pretrained:\n        init_pretrained_weights(model, model_urls[\"shufflenetv2_x2.0\"])\n    return model\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\ntry:\n    from torch.hub import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\n\nmodel_urls = {\n    \"vgg11\": \"https://download.pytorch.org/models/vgg11-bbd30ac9.pth\",\n    \"vgg13\": \"https://download.pytorch.org/models/vgg13-c768596a.pth\",\n    \"vgg16\": \"https://download.pytorch.org/models/vgg16-397923af.pth\",\n    \"vgg19\": \"https://download.pytorch.org/models/vgg19-dcbb9e9d.pth\",\n    \"vgg11_bn\": \"https://download.pytorch.org/models/vgg11_bn-6002323d.pth\",\n    \"vgg13_bn\": \"https://download.pytorch.org/models/vgg13_bn-abd245e5.pth\",\n    \"vgg16_bn\": \"https://download.pytorch.org/models/vgg16_bn-6c64b313.pth\",\n    \"vgg19_bn\": \"https://download.pytorch.org/models/vgg19_bn-c79401a0.pth\",\n}\n\n\nclass VGG(Backbone):\n\n    def __init__(self, features, init_weights=True):\n        super().__init__()\n        self.features = features\n        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))\n        # Note that self.classifier outputs features rather than logits\n        self.classifier = nn.Sequential(\n            nn.Linear(512 * 7 * 7, 4096),\n            nn.ReLU(True),\n            nn.Dropout(),\n            nn.Linear(4096, 4096),\n            nn.ReLU(True),\n            nn.Dropout(),\n        )\n\n        self._out_features = 4096\n\n        if init_weights:\n            self._initialize_weights()\n\n    def forward(self, x):\n        x = self.features(x)\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n        return self.classifier(x)\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(\n                    m.weight, mode=\"fan_out\", nonlinearity=\"relu\"\n                )\n                if m.bias is not None:\n      
              nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                nn.init.normal_(m.weight, 0, 0.01)\n                nn.init.constant_(m.bias, 0)\n\n\ndef make_layers(cfg, batch_norm=False):\n    layers = []\n    in_channels = 3\n    for v in cfg:\n        if v == \"M\":\n            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]\n        else:\n            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)\n            if batch_norm:\n                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]\n            else:\n                layers += [conv2d, nn.ReLU(inplace=True)]\n            in_channels = v\n    return nn.Sequential(*layers)\n\n\ncfgs = {\n    \"A\": [64, \"M\", 128, \"M\", 256, 256, \"M\", 512, 512, \"M\", 512, 512, \"M\"],\n    \"B\":\n    [64, 64, \"M\", 128, 128, \"M\", 256, 256, \"M\", 512, 512, \"M\", 512, 512, \"M\"],\n    \"D\": [\n        64,\n        64,\n        \"M\",\n        128,\n        128,\n        \"M\",\n        256,\n        256,\n        256,\n        \"M\",\n        512,\n        512,\n        512,\n        \"M\",\n        512,\n        512,\n        512,\n        \"M\",\n    ],\n    \"E\": [\n        64,\n        64,\n        \"M\",\n        128,\n        128,\n        \"M\",\n        256,\n        256,\n        256,\n        256,\n        \"M\",\n        512,\n        512,\n        512,\n        512,\n        \"M\",\n        512,\n        512,\n        512,\n        512,\n        \"M\",\n    ],\n}\n\n\ndef _vgg(arch, cfg, batch_norm, pretrained):\n    init_weights = False if pretrained else True\n    model = VGG(\n        make_layers(cfgs[cfg], batch_norm=batch_norm),\n        init_weights=init_weights\n    )\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch], progress=True)\n        
model.load_state_dict(state_dict, strict=False)\n    return model\n\n\n@BACKBONE_REGISTRY.register()\ndef vgg16(pretrained=True, **kwargs):\n    return _vgg(\"vgg16\", \"D\", False, pretrained)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py",
    "content": "\"\"\"\nModified from https://github.com/xternalz/WideResNet-pytorch\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\n\nclass BasicBlock(nn.Module):\n\n    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):\n        super().__init__()\n        self.bn1 = nn.BatchNorm2d(in_planes)\n        self.relu1 = nn.LeakyReLU(0.01, inplace=True)\n        self.conv1 = nn.Conv2d(\n            in_planes,\n            out_planes,\n            kernel_size=3,\n            stride=stride,\n            padding=1,\n            bias=False\n        )\n        self.bn2 = nn.BatchNorm2d(out_planes)\n        self.relu2 = nn.LeakyReLU(0.01, inplace=True)\n        self.conv2 = nn.Conv2d(\n            out_planes,\n            out_planes,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n            bias=False\n        )\n        self.droprate = dropRate\n        self.equalInOut = in_planes == out_planes\n        self.convShortcut = (\n            (not self.equalInOut) and nn.Conv2d(\n                in_planes,\n                out_planes,\n                kernel_size=1,\n                stride=stride,\n                padding=0,\n                bias=False,\n            ) or None\n        )\n\n    def forward(self, x):\n        if not self.equalInOut:\n            x = self.relu1(self.bn1(x))\n        else:\n            out = self.relu1(self.bn1(x))\n        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))\n        if self.droprate > 0:\n            out = F.dropout(out, p=self.droprate, training=self.training)\n        out = self.conv2(out)\n        return torch.add(x if self.equalInOut else self.convShortcut(x), out)\n\n\nclass NetworkBlock(nn.Module):\n\n    def __init__(\n        self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0\n    ):\n        super().__init__()\n        self.layer = 
self._make_layer(\n            block, in_planes, out_planes, nb_layers, stride, dropRate\n        )\n\n    def _make_layer(\n        self, block, in_planes, out_planes, nb_layers, stride, dropRate\n    ):\n        layers = []\n        for i in range(int(nb_layers)):\n            layers.append(\n                block(\n                    i == 0 and in_planes or out_planes,\n                    out_planes,\n                    i == 0 and stride or 1,\n                    dropRate,\n                )\n            )\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.layer(x)\n\n\nclass WideResNet(Backbone):\n\n    def __init__(self, depth, widen_factor, dropRate=0.0):\n        super().__init__()\n        nChannels = [\n            16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor\n        ]\n        assert (depth-4) % 6 == 0\n        n = (depth-4) / 6\n        block = BasicBlock\n        # 1st conv before any network block\n        self.conv1 = nn.Conv2d(\n            3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False\n        )\n        # 1st block\n        self.block1 = NetworkBlock(\n            n, nChannels[0], nChannels[1], block, 1, dropRate\n        )\n        # 2nd block\n        self.block2 = NetworkBlock(\n            n, nChannels[1], nChannels[2], block, 2, dropRate\n        )\n        # 3rd block\n        self.block3 = NetworkBlock(\n            n, nChannels[2], nChannels[3], block, 2, dropRate\n        )\n        # global average pooling and classifier\n        self.bn1 = nn.BatchNorm2d(nChannels[3])\n        self.relu = nn.LeakyReLU(0.01, inplace=True)\n\n        self._out_features = nChannels[3]\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(\n                    m.weight, mode=\"fan_out\", nonlinearity=\"relu\"\n                )\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n    
            m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.block1(out)\n        out = self.block2(out)\n        out = self.block3(out)\n        out = self.relu(self.bn1(out))\n        out = F.adaptive_avg_pool2d(out, 1)\n        return out.view(out.size(0), -1)\n\n\n@BACKBONE_REGISTRY.register()\ndef wide_resnet_28_2(**kwargs):\n    return WideResNet(28, 2)\n\n\n@BACKBONE_REGISTRY.register()\ndef wide_resnet_16_4(**kwargs):\n    return WideResNet(16, 4)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/head/__init__.py",
    "content": "from .build import build_head, HEAD_REGISTRY  # isort:skip\n\nfrom .mlp import mlp\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/head/build.py",
    "content": "from dassl.utils import Registry, check_availability\n\nHEAD_REGISTRY = Registry(\"HEAD\")\n\n\ndef build_head(name, verbose=True, **kwargs):\n    avai_heads = HEAD_REGISTRY.registered_names()\n    check_availability(name, avai_heads)\n    if verbose:\n        print(\"Head: {}\".format(name))\n    return HEAD_REGISTRY.get(name)(**kwargs)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py",
    "content": "import functools\nimport torch.nn as nn\n\nfrom .build import HEAD_REGISTRY\n\n\nclass MLP(nn.Module):\n\n    def __init__(\n        self,\n        in_features=2048,\n        hidden_layers=[],\n        activation=\"relu\",\n        bn=True,\n        dropout=0.0,\n    ):\n        super().__init__()\n        if isinstance(hidden_layers, int):\n            hidden_layers = [hidden_layers]\n\n        assert len(hidden_layers) > 0\n        self.out_features = hidden_layers[-1]\n\n        mlp = []\n\n        if activation == \"relu\":\n            act_fn = functools.partial(nn.ReLU, inplace=True)\n        elif activation == \"leaky_relu\":\n            act_fn = functools.partial(nn.LeakyReLU, inplace=True)\n        else:\n            raise NotImplementedError\n\n        for hidden_dim in hidden_layers:\n            mlp += [nn.Linear(in_features, hidden_dim)]\n            if bn:\n                mlp += [nn.BatchNorm1d(hidden_dim)]\n            mlp += [act_fn()]\n            if dropout > 0:\n                mlp += [nn.Dropout(dropout)]\n            in_features = hidden_dim\n\n        self.mlp = nn.Sequential(*mlp)\n\n    def forward(self, x):\n        return self.mlp(x)\n\n\n@HEAD_REGISTRY.register()\ndef mlp(**kwargs):\n    return MLP(**kwargs)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/network/__init__.py",
    "content": "from .build import build_network, NETWORK_REGISTRY  # isort:skip\n\nfrom .ddaig_fcn import (\n    fcn_3x32_gctx, fcn_3x64_gctx, fcn_3x32_gctx_stn, fcn_3x64_gctx_stn\n)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/network/build.py",
    "content": "from dassl.utils import Registry, check_availability\n\nNETWORK_REGISTRY = Registry(\"NETWORK\")\n\n\ndef build_network(name, verbose=True, **kwargs):\n    avai_models = NETWORK_REGISTRY.registered_names()\n    check_availability(name, avai_models)\n    if verbose:\n        print(\"Network: {}\".format(name))\n    return NETWORK_REGISTRY.get(name)(**kwargs)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py",
    "content": "\"\"\"\nCredit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n\"\"\"\nimport functools\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom .build import NETWORK_REGISTRY\n\n\ndef init_network_weights(model, init_type=\"normal\", gain=0.02):\n\n    def _init_func(m):\n        classname = m.__class__.__name__\n        if hasattr(m, \"weight\") and (\n            classname.find(\"Conv\") != -1 or classname.find(\"Linear\") != -1\n        ):\n            if init_type == \"normal\":\n                nn.init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == \"xavier\":\n                nn.init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == \"kaiming\":\n                nn.init.kaiming_normal_(m.weight.data, a=0, mode=\"fan_in\")\n            elif init_type == \"orthogonal\":\n                nn.init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError(\n                    \"initialization method {} is not implemented\".\n                    format(init_type)\n                )\n            if hasattr(m, \"bias\") and m.bias is not None:\n                nn.init.constant_(m.bias.data, 0.0)\n        elif classname.find(\"BatchNorm2d\") != -1:\n            nn.init.constant_(m.weight.data, 1.0)\n            nn.init.constant_(m.bias.data, 0.0)\n        elif classname.find(\"InstanceNorm2d\") != -1:\n            if m.weight is not None and m.bias is not None:\n                nn.init.constant_(m.weight.data, 1.0)\n                nn.init.constant_(m.bias.data, 0.0)\n\n    model.apply(_init_func)\n\n\ndef get_norm_layer(norm_type=\"instance\"):\n    if norm_type == \"batch\":\n        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)\n    elif norm_type == \"instance\":\n        norm_layer = functools.partial(\n            nn.InstanceNorm2d, affine=False, track_running_stats=False\n        )\n    elif norm_type == 
\"none\":\n        norm_layer = None\n    else:\n        raise NotImplementedError(\n            \"normalization layer [%s] is not found\" % norm_type\n        )\n    return norm_layer\n\n\nclass ResnetBlock(nn.Module):\n\n    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):\n        super().__init__()\n        self.conv_block = self.build_conv_block(\n            dim, padding_type, norm_layer, use_dropout, use_bias\n        )\n\n    def build_conv_block(\n        self, dim, padding_type, norm_layer, use_dropout, use_bias\n    ):\n        conv_block = []\n        p = 0\n        if padding_type == \"reflect\":\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == \"replicate\":\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == \"zero\":\n            p = 1\n        else:\n            raise NotImplementedError(\n                \"padding [%s] is not implemented\" % padding_type\n            )\n\n        conv_block += [\n            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),\n            norm_layer(dim),\n            nn.ReLU(True),\n        ]\n        if use_dropout:\n            conv_block += [nn.Dropout(0.5)]\n\n        p = 0\n        if padding_type == \"reflect\":\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == \"replicate\":\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == \"zero\":\n            p = 1\n        else:\n            raise NotImplementedError(\n                \"padding [%s] is not implemented\" % padding_type\n            )\n        conv_block += [\n            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),\n            norm_layer(dim),\n        ]\n\n        return nn.Sequential(*conv_block)\n\n    def forward(self, x):\n        return x + self.conv_block(x)\n\n\nclass LocNet(nn.Module):\n    \"\"\"Localization network.\"\"\"\n\n    def __init__(\n        self,\n        
input_nc,\n        nc=32,\n        n_blocks=3,\n        use_dropout=False,\n        padding_type=\"zero\",\n        image_size=32,\n    ):\n        super().__init__()\n\n        backbone = []\n        backbone += [\n            nn.Conv2d(\n                input_nc, nc, kernel_size=3, stride=2, padding=1, bias=False\n            )\n        ]\n        backbone += [nn.BatchNorm2d(nc)]\n        backbone += [nn.ReLU(True)]\n        for _ in range(n_blocks):\n            backbone += [\n                ResnetBlock(\n                    nc,\n                    padding_type=padding_type,\n                    norm_layer=nn.BatchNorm2d,\n                    use_dropout=use_dropout,\n                    use_bias=False,\n                )\n            ]\n            backbone += [nn.MaxPool2d(2, stride=2)]\n        self.backbone = nn.Sequential(*backbone)\n        reduced_imsize = int(image_size * 0.5**(n_blocks + 1))\n        self.fc_loc = nn.Linear(nc * reduced_imsize**2, 2 * 2)\n\n    def forward(self, x):\n        x = self.backbone(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc_loc(x)\n        x = torch.tanh(x)\n        x = x.view(-1, 2, 2)\n        theta = x.data.new_zeros(x.size(0), 2, 3)\n        theta[:, :, :2] = x\n        return theta\n\n\nclass FCN(nn.Module):\n    \"\"\"Fully convolutional network.\"\"\"\n\n    def __init__(\n        self,\n        input_nc,\n        output_nc,\n        nc=32,\n        n_blocks=3,\n        norm_layer=nn.BatchNorm2d,\n        use_dropout=False,\n        padding_type=\"reflect\",\n        gctx=True,\n        stn=False,\n        image_size=32,\n    ):\n        super().__init__()\n\n        backbone = []\n\n        p = 0\n        if padding_type == \"reflect\":\n            backbone += [nn.ReflectionPad2d(1)]\n        elif padding_type == \"replicate\":\n            backbone += [nn.ReplicationPad2d(1)]\n        elif padding_type == \"zero\":\n            p = 1\n        else:\n            raise NotImplementedError\n        
backbone += [\n            nn.Conv2d(\n                input_nc, nc, kernel_size=3, stride=1, padding=p, bias=False\n            )\n        ]\n        backbone += [norm_layer(nc)]\n        backbone += [nn.ReLU(True)]\n\n        for _ in range(n_blocks):\n            backbone += [\n                ResnetBlock(\n                    nc,\n                    padding_type=padding_type,\n                    norm_layer=norm_layer,\n                    use_dropout=use_dropout,\n                    use_bias=False,\n                )\n            ]\n        self.backbone = nn.Sequential(*backbone)\n\n        # global context fusion layer\n        self.gctx_fusion = None\n        if gctx:\n            self.gctx_fusion = nn.Sequential(\n                nn.Conv2d(\n                    2 * nc, nc, kernel_size=1, stride=1, padding=0, bias=False\n                ),\n                norm_layer(nc),\n                nn.ReLU(True),\n            )\n\n        self.regress = nn.Sequential(\n            nn.Conv2d(\n                nc, output_nc, kernel_size=1, stride=1, padding=0, bias=True\n            ),\n            nn.Tanh(),\n        )\n\n        self.locnet = None\n        if stn:\n            self.locnet = LocNet(\n                input_nc, nc=nc, n_blocks=n_blocks, image_size=image_size\n            )\n\n    def init_loc_layer(self):\n        \"\"\"Initialize the weights/bias with identity transformation.\"\"\"\n        if self.locnet is not None:\n            self.locnet.fc_loc.weight.data.zero_()\n            self.locnet.fc_loc.bias.data.copy_(\n                torch.tensor([1, 0, 0, 1], dtype=torch.float)\n            )\n\n    def stn(self, x):\n        \"\"\"Spatial transformer network.\"\"\"\n        theta = self.locnet(x)\n        grid = F.affine_grid(theta, x.size())\n        return F.grid_sample(x, grid), theta\n\n    def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):\n        \"\"\"\n        Args:\n            x (torch.Tensor): input mini-batch.\n   
         lmda (float): multiplier for perturbation.\n            return_p (bool): return perturbation.\n            return_stn_output (bool): return the output of stn.\n        \"\"\"\n        theta = None\n        if self.locnet is not None:\n            x, theta = self.stn(x)\n        input = x\n\n        x = self.backbone(x)\n        if self.gctx_fusion is not None:\n            c = F.adaptive_avg_pool2d(x, (1, 1))\n            c = c.expand_as(x)\n            x = torch.cat([x, c], 1)\n            x = self.gctx_fusion(x)\n\n        p = self.regress(x)\n        x_p = input + lmda*p\n\n        if return_stn_output:\n            return x_p, p, input\n\n        if return_p:\n            return x_p, p\n\n        return x_p\n\n\n@NETWORK_REGISTRY.register()\ndef fcn_3x32_gctx(**kwargs):\n    norm_layer = get_norm_layer(norm_type=\"instance\")\n    net = FCN(3, 3, nc=32, n_blocks=3, norm_layer=norm_layer)\n    init_network_weights(net, init_type=\"normal\", gain=0.02)\n    return net\n\n\n@NETWORK_REGISTRY.register()\ndef fcn_3x64_gctx(**kwargs):\n    norm_layer = get_norm_layer(norm_type=\"instance\")\n    net = FCN(3, 3, nc=64, n_blocks=3, norm_layer=norm_layer)\n    init_network_weights(net, init_type=\"normal\", gain=0.02)\n    return net\n\n\n@NETWORK_REGISTRY.register()\ndef fcn_3x32_gctx_stn(image_size=32, **kwargs):\n    norm_layer = get_norm_layer(norm_type=\"instance\")\n    net = FCN(\n        3,\n        3,\n        nc=32,\n        n_blocks=3,\n        norm_layer=norm_layer,\n        stn=True,\n        image_size=image_size\n    )\n    init_network_weights(net, init_type=\"normal\", gain=0.02)\n    net.init_loc_layer()\n    return net\n\n\n@NETWORK_REGISTRY.register()\ndef fcn_3x64_gctx_stn(image_size=224, **kwargs):\n    norm_layer = get_norm_layer(norm_type=\"instance\")\n    net = FCN(\n        3,\n        3,\n        nc=64,\n        n_blocks=3,\n        norm_layer=norm_layer,\n        stn=True,\n        image_size=image_size\n    )\n    
init_network_weights(net, init_type=\"normal\", gain=0.02)\n    net.init_loc_layer()\n    return net\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py",
    "content": "from .mmd import MaximumMeanDiscrepancy\nfrom .dsbn import DSBN1d, DSBN2d\nfrom .mixup import mixup\nfrom .efdmix import (\n    EFDMix, random_efdmix, activate_efdmix, run_with_efdmix, deactivate_efdmix,\n    crossdomain_efdmix, run_without_efdmix\n)\nfrom .mixstyle import (\n    MixStyle, random_mixstyle, activate_mixstyle, run_with_mixstyle,\n    deactivate_mixstyle, crossdomain_mixstyle, run_without_mixstyle\n)\nfrom .transnorm import TransNorm1d, TransNorm2d\nfrom .sequential2 import Sequential2\nfrom .reverse_grad import ReverseGrad\nfrom .cross_entropy import cross_entropy\nfrom .optimal_transport import SinkhornDivergence, MinibatchEnergyDistance\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\n\ndef cross_entropy(input, target, label_smooth=0, reduction=\"mean\"):\n    \"\"\"Cross entropy loss.\n\n    Args:\n        input (torch.Tensor): logit matrix with shape of (batch, num_classes).\n        target (torch.LongTensor): int label vector with shape of (batch,).\n        label_smooth (float, optional): label smoothing hyper-parameter.\n            Default is 0.\n        reduction (str, optional): how the losses for a mini-batch\n            will be aggregated. Choices are 'mean', 'sum' and 'none'.\n            Default is 'mean'.\n    \"\"\"\n    num_classes = input.shape[1]\n    log_prob = F.log_softmax(input, dim=1)\n    # Build the one-hot target on the same device and dtype as the logits\n    target = torch.zeros_like(log_prob).scatter_(\n        1, target.unsqueeze(1).to(log_prob.device), 1\n    )\n    target = (1-label_smooth) * target + label_smooth/num_classes\n    loss = (-target * log_prob).sum(1)\n    if reduction == \"mean\":\n        return loss.mean()\n    elif reduction == \"sum\":\n        return loss.sum()\n    elif reduction == \"none\":\n        return loss\n    else:\n        raise ValueError(\"Unknown reduction: {}\".format(reduction))\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py",
    "content": "import torch.nn as nn\n\n\nclass _DSBN(nn.Module):\n    \"\"\"Domain Specific Batch Normalization.\n\n    Args:\n        num_features (int): number of features.\n        n_domain (int): number of domains.\n        bn_type (str): type of bn. Choices are ['1d', '2d'].\n    \"\"\"\n\n    def __init__(self, num_features, n_domain, bn_type):\n        super().__init__()\n        if bn_type == \"1d\":\n            BN = nn.BatchNorm1d\n        elif bn_type == \"2d\":\n            BN = nn.BatchNorm2d\n        else:\n            raise ValueError\n\n        self.bn = nn.ModuleList(BN(num_features) for _ in range(n_domain))\n\n        self.valid_domain_idxs = list(range(n_domain))\n        self.n_domain = n_domain\n        self.domain_idx = 0\n\n    def select_bn(self, domain_idx=0):\n        assert domain_idx in self.valid_domain_idxs\n        self.domain_idx = domain_idx\n\n    def forward(self, x):\n        return self.bn[self.domain_idx](x)\n\n\nclass DSBN1d(_DSBN):\n\n    def __init__(self, num_features, n_domain):\n        super().__init__(num_features, n_domain, \"1d\")\n\n\nclass DSBN2d(_DSBN):\n\n    def __init__(self, num_features, n_domain):\n        super().__init__(num_features, n_domain, \"2d\")\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py",
    "content": "import random\nfrom contextlib import contextmanager\nimport torch\nimport torch.nn as nn\n\n\ndef deactivate_efdmix(m):\n    if type(m) == EFDMix:\n        m.set_activation_status(False)\n\n\ndef activate_efdmix(m):\n    if type(m) == EFDMix:\n        m.set_activation_status(True)\n\n\ndef random_efdmix(m):\n    if type(m) == EFDMix:\n        m.update_mix_method(\"random\")\n\n\ndef crossdomain_efdmix(m):\n    if type(m) == EFDMix:\n        m.update_mix_method(\"crossdomain\")\n\n\n@contextmanager\ndef run_without_efdmix(model):\n    # Assume EFDMix was initially activated\n    try:\n        model.apply(deactivate_efdmix)\n        yield\n    finally:\n        model.apply(activate_efdmix)\n\n\n@contextmanager\ndef run_with_efdmix(model, mix=None):\n    # Assume EFDMix was initially deactivated\n    if mix == \"random\":\n        model.apply(random_efdmix)\n\n    elif mix == \"crossdomain\":\n        model.apply(crossdomain_efdmix)\n\n    try:\n        model.apply(activate_efdmix)\n        yield\n    finally:\n        model.apply(deactivate_efdmix)\n\n\nclass EFDMix(nn.Module):\n    \"\"\"EFDMix.\n\n    Reference:\n      Zhang et al. Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization. 
CVPR 2022.\n    \"\"\"\n\n    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix=\"random\"):\n        \"\"\"\n        Args:\n          p (float): probability of using EFDMix.\n          alpha (float): parameter of the Beta distribution.\n          eps (float): scaling parameter to avoid numerical issues.\n          mix (str): how to mix.\n        \"\"\"\n        super().__init__()\n        self.p = p\n        self.beta = torch.distributions.Beta(alpha, alpha)\n        self.eps = eps\n        self.alpha = alpha\n        self.mix = mix\n        self._activated = True\n\n    def __repr__(self):\n        return (\n            f\"EFDMix(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})\"\n        )\n\n    def set_activation_status(self, status=True):\n        self._activated = status\n\n    def update_mix_method(self, mix=\"random\"):\n        self.mix = mix\n\n    def forward(self, x):\n        if not self.training or not self._activated:\n            return x\n\n        if random.random() > self.p:\n            return x\n\n        B, C, W, H = x.size(0), x.size(1), x.size(2), x.size(3)\n        x_view = x.view(B, C, -1)\n        value_x, index_x = torch.sort(x_view)  # sort inputs\n        lmda = self.beta.sample((B, 1, 1))\n        lmda = lmda.to(x.device)\n\n        if self.mix == \"random\":\n            # random shuffle\n            perm = torch.randperm(B)\n\n        elif self.mix == \"crossdomain\":\n            # split into two halves and swap the order\n            perm = torch.arange(B - 1, -1, -1)  # inverse index\n            perm_b, perm_a = perm.chunk(2)\n            perm_b = perm_b[torch.randperm(perm_b.shape[0])]\n            perm_a = perm_a[torch.randperm(perm_a.shape[0])]\n            perm = torch.cat([perm_b, perm_a], 0)\n\n        else:\n            raise NotImplementedError\n\n        inverse_index = index_x.argsort(-1)\n        x_view_copy = value_x[perm].gather(-1, inverse_index)\n        new_x = x_view + (x_view_copy - 
x_view.detach()) * (1-lmda)\n        return new_x.view(B, C, W, H)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py",
    "content": "import random\nfrom contextlib import contextmanager\nimport torch\nimport torch.nn as nn\n\n\ndef deactivate_mixstyle(m):\n    if type(m) == MixStyle:\n        m.set_activation_status(False)\n\n\ndef activate_mixstyle(m):\n    if type(m) == MixStyle:\n        m.set_activation_status(True)\n\n\ndef random_mixstyle(m):\n    if type(m) == MixStyle:\n        m.update_mix_method(\"random\")\n\n\ndef crossdomain_mixstyle(m):\n    if type(m) == MixStyle:\n        m.update_mix_method(\"crossdomain\")\n\n\n@contextmanager\ndef run_without_mixstyle(model):\n    # Assume MixStyle was initially activated\n    try:\n        model.apply(deactivate_mixstyle)\n        yield\n    finally:\n        model.apply(activate_mixstyle)\n\n\n@contextmanager\ndef run_with_mixstyle(model, mix=None):\n    # Assume MixStyle was initially deactivated\n    if mix == \"random\":\n        model.apply(random_mixstyle)\n\n    elif mix == \"crossdomain\":\n        model.apply(crossdomain_mixstyle)\n\n    try:\n        model.apply(activate_mixstyle)\n        yield\n    finally:\n        model.apply(deactivate_mixstyle)\n\n\nclass MixStyle(nn.Module):\n    \"\"\"MixStyle.\n\n    Reference:\n      Zhou et al. Domain Generalization with MixStyle. 
ICLR 2021.\n    \"\"\"\n\n    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix=\"random\"):\n        \"\"\"\n        Args:\n          p (float): probability of using MixStyle.\n          alpha (float): parameter of the Beta distribution.\n          eps (float): scaling parameter to avoid numerical issues.\n          mix (str): how to mix.\n        \"\"\"\n        super().__init__()\n        self.p = p\n        self.beta = torch.distributions.Beta(alpha, alpha)\n        self.eps = eps\n        self.alpha = alpha\n        self.mix = mix\n        self._activated = True\n\n    def __repr__(self):\n        return (\n            f\"MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps}, mix={self.mix})\"\n        )\n\n    def set_activation_status(self, status=True):\n        self._activated = status\n\n    def update_mix_method(self, mix=\"random\"):\n        self.mix = mix\n\n    def forward(self, x):\n        if not self.training or not self._activated:\n            return x\n\n        if random.random() > self.p:\n            return x\n\n        B = x.size(0)\n\n        mu = x.mean(dim=[2, 3], keepdim=True)\n        var = x.var(dim=[2, 3], keepdim=True)\n        sig = (var + self.eps).sqrt()\n        mu, sig = mu.detach(), sig.detach()\n        x_normed = (x-mu) / sig\n\n        lmda = self.beta.sample((B, 1, 1, 1))\n        lmda = lmda.to(x.device)\n\n        if self.mix == \"random\":\n            # random shuffle\n            perm = torch.randperm(B)\n\n        elif self.mix == \"crossdomain\":\n            # split into two halves and swap the order\n            perm = torch.arange(B - 1, -1, -1)  # inverse index\n            perm_b, perm_a = perm.chunk(2)\n            perm_b = perm_b[torch.randperm(perm_b.shape[0])]\n            perm_a = perm_a[torch.randperm(perm_a.shape[0])]\n            perm = torch.cat([perm_b, perm_a], 0)\n\n        else:\n            raise NotImplementedError\n\n        mu2, sig2 = mu[perm], sig[perm]\n        mu_mix = mu*lmda + mu2 * 
(1-lmda)\n        sig_mix = sig*lmda + sig2 * (1-lmda)\n\n        return x_normed*sig_mix + mu_mix\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py",
    "content": "import torch\n\n\ndef mixup(x1, x2, y1, y2, beta, preserve_order=False):\n    \"\"\"Mixup.\n\n    Args:\n        x1 (torch.Tensor): data with shape of (b, c, h, w).\n        x2 (torch.Tensor): data with shape of (b, c, h, w).\n        y1 (torch.Tensor): label with shape of (b, n).\n        y2 (torch.Tensor): label with shape of (b, n).\n        beta (float): hyper-parameter for Beta sampling.\n        preserve_order (bool): apply lmda=max(lmda, 1-lmda).\n            Default is False.\n    \"\"\"\n    lmda = torch.distributions.Beta(beta, beta).sample([x1.shape[0], 1, 1, 1])\n    if preserve_order:\n        lmda = torch.max(lmda, 1 - lmda)\n    lmda = lmda.to(x1.device)\n    xmix = x1*lmda + x2 * (1-lmda)\n    lmda = lmda[:, :, 0, 0]\n    ymix = y1*lmda + y2 * (1-lmda)\n    return xmix, ymix\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\nclass MaximumMeanDiscrepancy(nn.Module):\n\n    def __init__(self, kernel_type=\"rbf\", normalize=False):\n        super().__init__()\n        self.kernel_type = kernel_type\n        self.normalize = normalize\n\n    def forward(self, x, y):\n        # x, y: two batches of data with shape (batch, dim)\n        # MMD^2(x, y) = k(x, x') - 2k(x, y) + k(y, y')\n        if self.normalize:\n            x = F.normalize(x, dim=1)\n            y = F.normalize(y, dim=1)\n        if self.kernel_type == \"linear\":\n            return self.linear_mmd(x, y)\n        elif self.kernel_type == \"poly\":\n            return self.poly_mmd(x, y)\n        elif self.kernel_type == \"rbf\":\n            return self.rbf_mmd(x, y)\n        else:\n            raise NotImplementedError\n\n    def linear_mmd(self, x, y):\n        # k(x, y) = x^T y\n        k_xx = self.remove_self_distance(torch.mm(x, x.t()))\n        k_yy = self.remove_self_distance(torch.mm(y, y.t()))\n        k_xy = torch.mm(x, y.t())\n        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()\n\n    def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):\n        # k(x, y) = (alpha * x^T y + c)^d\n        k_xx = self.remove_self_distance(torch.mm(x, x.t()))\n        k_xx = (alpha*k_xx + c).pow(d)\n        k_yy = self.remove_self_distance(torch.mm(y, y.t()))\n        k_yy = (alpha*k_yy + c).pow(d)\n        k_xy = torch.mm(x, y.t())\n        k_xy = (alpha*k_xy + c).pow(d)\n        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()\n\n    def rbf_mmd(self, x, y):\n        # k_xx\n        d_xx = self.euclidean_squared_distance(x, x)\n        d_xx = self.remove_self_distance(d_xx)\n        k_xx = self.rbf_kernel_mixture(d_xx)\n        # k_yy\n        d_yy = self.euclidean_squared_distance(y, y)\n        d_yy = self.remove_self_distance(d_yy)\n        k_yy = self.rbf_kernel_mixture(d_yy)\n        # k_xy\n        d_xy = 
self.euclidean_squared_distance(x, y)\n        k_xy = self.rbf_kernel_mixture(d_xy)\n        return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()\n\n    @staticmethod\n    def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]):\n        K = 0\n        for sigma in sigmas:\n            gamma = 1.0 / (2.0 * sigma**2)\n            K += torch.exp(-gamma * exponent)\n        return K\n\n    @staticmethod\n    def remove_self_distance(distmat):\n        tmp_list = []\n        for i, row in enumerate(distmat):\n            row1 = torch.cat([row[:i], row[i + 1:]])\n            tmp_list.append(row1)\n        return torch.stack(tmp_list)\n\n    @staticmethod\n    def euclidean_squared_distance(x, y):\n        m, n = x.size(0), y.size(0)\n        distmat = (\n            torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +\n            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()\n        )\n        # distmat.addmm_(1, -2, x, y.t())\n        distmat.addmm_(x, y.t(), beta=1, alpha=-2)\n        return distmat\n\n\nif __name__ == \"__main__\":\n    mmd = MaximumMeanDiscrepancy(kernel_type=\"rbf\")\n    input1, input2 = torch.rand(3, 100), torch.rand(3, 100)\n    d = mmd(input1, input2)\n    print(d.item())\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\nclass OptimalTransport(nn.Module):\n\n    @staticmethod\n    def distance(batch1, batch2, dist_metric=\"cosine\"):\n        if dist_metric == \"cosine\":\n            batch1 = F.normalize(batch1, p=2, dim=1)\n            batch2 = F.normalize(batch2, p=2, dim=1)\n            dist_mat = 1 - torch.mm(batch1, batch2.t())\n        elif dist_metric == \"euclidean\":\n            m, n = batch1.size(0), batch2.size(0)\n            dist_mat = (\n                torch.pow(batch1, 2).sum(dim=1, keepdim=True).expand(m, n) +\n                torch.pow(batch2, 2).sum(dim=1, keepdim=True).expand(n, m).t()\n            )\n            # squared euclidean distance: ||a||^2 + ||b||^2 - 2 * a^T b\n            dist_mat.addmm_(batch1, batch2.t(), beta=1, alpha=-2)\n        elif dist_metric == \"fast_euclidean\":\n            batch1 = batch1.unsqueeze(-2)\n            batch2 = batch2.unsqueeze(-3)\n            dist_mat = torch.sum((torch.abs(batch1 - batch2))**2, -1)\n        else:\n            raise ValueError(\n                \"Unknown cost function: {}. Expected to be one of \"\n                \"[cosine | euclidean | fast_euclidean]\".format(dist_metric)\n            )\n        return dist_mat\n\n\nclass SinkhornDivergence(OptimalTransport):\n    thre = 1e-3\n\n    def __init__(\n        self,\n        dist_metric=\"cosine\",\n        eps=0.01,\n        max_iter=5,\n        bp_to_sinkhorn=False\n    ):\n        super().__init__()\n        self.dist_metric = dist_metric\n        self.eps = eps\n        self.max_iter = max_iter\n        self.bp_to_sinkhorn = bp_to_sinkhorn\n\n    def forward(self, x, y):\n        # x, y: two batches of data with shape (batch, dim)\n        W_xy = self.transport_cost(x, y)\n        W_xx = self.transport_cost(x, x)\n        W_yy = self.transport_cost(y, y)\n        return 2*W_xy - W_xx - W_yy\n\n    def transport_cost(self, x, y, return_pi=False):\n        C = self.distance(x, y, dist_metric=self.dist_metric)\n        pi = self.sinkhorn_iterate(C, self.eps, self.max_iter, self.thre)\n        if not self.bp_to_sinkhorn:\n            pi = pi.detach()\n        cost = torch.sum(pi * C)\n        if return_pi:\n            return cost, pi\n        return cost\n\n    @staticmethod\n    def sinkhorn_iterate(C, eps, max_iter, thre):\n        nx, ny = C.shape\n        mu = torch.ones(nx, dtype=C.dtype, device=C.device) * (1.0/nx)\n        nu = torch.ones(ny, dtype=C.dtype, device=C.device) * (1.0/ny)\n        u = torch.zeros_like(mu)\n        v = torch.zeros_like(nu)\n\n        def M(_C, _u, _v):\n            \"\"\"Modified cost for logarithmic updates.\n            Eq: M_{ij} = (-c_{ij} + u_i + v_j) / epsilon\n            \"\"\"\n            return (-_C + _u.unsqueeze(-1) + _v.unsqueeze(-2)) / eps\n\n        real_iter = 0  # check if algorithm terminates before max_iter\n        # Sinkhorn iterations\n        for i in range(max_iter):\n            u0 = u\n            u = eps * (\n                torch.log(mu + 1e-8) - torch.logsumexp(M(C, u, v), dim=1)\n            ) + u\n            v = (\n                
eps * (\n                    torch.log(nu + 1e-8) -\n                    torch.logsumexp(M(C, u, v).permute(1, 0), dim=1)\n                ) + v\n            )\n            err = (u - u0).abs().sum()\n            real_iter += 1\n            if err.item() < thre:\n                break\n        # Transport plan pi = diag(a)*K*diag(b), where K = exp(-C/eps),\n        # a = exp(u/eps) and b = exp(v/eps); equivalently pi = exp(M)\n        return torch.exp(M(C, u, v))\n\n\nclass MinibatchEnergyDistance(SinkhornDivergence):\n\n    def __init__(\n        self,\n        dist_metric=\"cosine\",\n        eps=0.01,\n        max_iter=5,\n        bp_to_sinkhorn=False\n    ):\n        super().__init__(\n            dist_metric=dist_metric,\n            eps=eps,\n            max_iter=max_iter,\n            bp_to_sinkhorn=bp_to_sinkhorn,\n        )\n\n    def forward(self, x, y):\n        x1, x2 = torch.split(x, x.size(0) // 2, dim=0)\n        y1, y2 = torch.split(y, y.size(0) // 2, dim=0)\n        cost = 0\n        cost += self.transport_cost(x1, y1)\n        cost += self.transport_cost(x1, y2)\n        cost += self.transport_cost(x2, y1)\n        cost += self.transport_cost(x2, y2)\n        cost -= 2 * self.transport_cost(x1, x2)\n        cost -= 2 * self.transport_cost(y1, y2)\n        return cost\n\n\nif __name__ == \"__main__\":\n    # example: https://dfdazac.github.io/sinkhorn.html\n    import numpy as np\n\n    n_points = 5\n    a = np.array([[i, 0] for i in range(n_points)])\n    b = np.array([[i, 1] for i in range(n_points)])\n    x = torch.tensor(a, dtype=torch.float)\n    y = torch.tensor(b, dtype=torch.float)\n    sinkhorn = SinkhornDivergence(\n        dist_metric=\"euclidean\", eps=0.01, max_iter=5\n    )\n    dist, pi = sinkhorn.transport_cost(x, y, True)\n    print(\"transport cost: {:.4f}\".format(dist.item()))\n    print(pi)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py",
    "content": "import torch.nn as nn\nfrom torch.autograd import Function\n\n\nclass _ReverseGrad(Function):\n\n    @staticmethod\n    def forward(ctx, input, grad_scaling):\n        ctx.grad_scaling = grad_scaling\n        return input.view_as(input)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_scaling = ctx.grad_scaling\n        return -grad_scaling * grad_output, None\n\n\nreverse_grad = _ReverseGrad.apply\n\n\nclass ReverseGrad(nn.Module):\n    \"\"\"Gradient reversal layer.\n\n    It acts as an identity layer in the forward pass,\n    but multiplies the gradient by -grad_scaling in\n    the backward pass.\n    \"\"\"\n\n    def forward(self, x, grad_scaling=1.0):\n        assert grad_scaling >= 0, (\n            \"grad_scaling must be non-negative, \"\n            \"but got {}\".format(grad_scaling)\n        )\n        return reverse_grad(x, grad_scaling)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py",
    "content": "import torch.nn as nn\n\n\nclass Sequential2(nn.Sequential):\n    \"\"\"An alternative sequential container to nn.Sequential,\n    which accepts an arbitrary number of input arguments.\n    \"\"\"\n\n    def forward(self, *inputs):\n        for module in self._modules.values():\n            if isinstance(inputs, tuple):\n                inputs = module(*inputs)\n            else:\n                inputs = module(inputs)\n        return inputs\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py",
    "content": "import torch\nimport torch.nn as nn\n\n\nclass _TransNorm(nn.Module):\n    \"\"\"Transferable normalization.\n\n    Reference:\n        - Wang et al. Transferable Normalization: Towards Improving\n        Transferability of Deep Neural Networks. NeurIPS 2019.\n\n    Args:\n        num_features (int): number of features.\n        eps (float): value added to the variance for numerical stability.\n        momentum (float): weight given to the previous running statistics,\n            i.e. running = momentum * running + (1 - momentum) * batch_stat\n            (the opposite of the nn.BatchNorm convention).\n        adaptive_alpha (bool): apply domain adaptive alpha.\n\n    In training mode, the input batch is expected to be a source batch\n    followed by a target batch of the same size.\n    \"\"\"\n\n    def __init__(\n        self, num_features, eps=1e-5, momentum=0.1, adaptive_alpha=True\n    ):\n        super().__init__()\n        self.num_features = num_features\n        self.eps = eps\n        self.momentum = momentum\n        self.adaptive_alpha = adaptive_alpha\n\n        self.register_buffer(\"running_mean_s\", torch.zeros(num_features))\n        self.register_buffer(\"running_var_s\", torch.ones(num_features))\n        self.register_buffer(\"running_mean_t\", torch.zeros(num_features))\n        self.register_buffer(\"running_var_t\", torch.ones(num_features))\n\n        self.weight = nn.Parameter(torch.ones(num_features))\n        self.bias = nn.Parameter(torch.zeros(num_features))\n\n    def reset_running_stats(self):\n        self.running_mean_s.zero_()\n        self.running_var_s.fill_(1)\n        self.running_mean_t.zero_()\n        self.running_var_t.fill_(1)\n\n    # Backward-compatible alias for the previous (misspelled) name.\n    resnet_running_stats = reset_running_stats\n\n    def reset_parameters(self):\n        nn.init.ones_(self.weight)\n        nn.init.zeros_(self.bias)\n\n    def _check_input(self, x):\n        raise NotImplementedError\n\n    def _compute_alpha(self, mean_s, var_s, mean_t, var_t):\n        C = self.num_features\n        ratio_s = mean_s / (var_s + self.eps).sqrt()\n        ratio_t = mean_t / (var_t + self.eps).sqrt()\n        dist = (ratio_s - ratio_t).abs()\n        dist_inv = 1 / (1+dist)\n        return C * dist_inv / dist_inv.sum()\n\n    def forward(self, input):\n        self._check_input(input)\n        C = 
self.num_features\n        if input.dim() == 2:\n            new_shape = (1, C)\n        elif input.dim() == 4:\n            new_shape = (1, C, 1, 1)\n        else:\n            raise ValueError\n\n        weight = self.weight.view(*new_shape)\n        bias = self.bias.view(*new_shape)\n\n        if not self.training:\n            mean_t = self.running_mean_t.view(*new_shape)\n            var_t = self.running_var_t.view(*new_shape)\n            output = (input-mean_t) / (var_t + self.eps).sqrt()\n            output = output*weight + bias\n\n            if self.adaptive_alpha:\n                mean_s = self.running_mean_s.view(*new_shape)\n                var_s = self.running_var_s.view(*new_shape)\n                alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)\n                alpha = alpha.reshape(*new_shape)\n                output = (1 + alpha.detach()) * output\n\n            return output\n\n        input_s, input_t = torch.split(input, input.shape[0] // 2, dim=0)\n\n        x_s = input_s.transpose(0, 1).reshape(C, -1)\n        mean_s = x_s.mean(1)\n        var_s = x_s.var(1)\n        self.running_mean_s.mul_(self.momentum)\n        self.running_mean_s.add_((1 - self.momentum) * mean_s.data)\n        self.running_var_s.mul_(self.momentum)\n        self.running_var_s.add_((1 - self.momentum) * var_s.data)\n        mean_s = mean_s.reshape(*new_shape)\n        var_s = var_s.reshape(*new_shape)\n        output_s = (input_s-mean_s) / (var_s + self.eps).sqrt()\n        output_s = output_s*weight + bias\n\n        x_t = input_t.transpose(0, 1).reshape(C, -1)\n        mean_t = x_t.mean(1)\n        var_t = x_t.var(1)\n        self.running_mean_t.mul_(self.momentum)\n        self.running_mean_t.add_((1 - self.momentum) * mean_t.data)\n        self.running_var_t.mul_(self.momentum)\n        self.running_var_t.add_((1 - self.momentum) * var_t.data)\n        mean_t = mean_t.reshape(*new_shape)\n        var_t = var_t.reshape(*new_shape)\n        output_t = 
(input_t-mean_t) / (var_t + self.eps).sqrt()\n        output_t = output_t*weight + bias\n\n        output = torch.cat([output_s, output_t], 0)\n\n        if self.adaptive_alpha:\n            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)\n            alpha = alpha.reshape(*new_shape)\n            output = (1 + alpha.detach()) * output\n\n        return output\n\n\nclass TransNorm1d(_TransNorm):\n\n    def _check_input(self, x):\n        if x.dim() != 2:\n            raise ValueError(\n                \"Expected the input to be 2-D, \"\n                \"but got {}-D\".format(x.dim())\n            )\n\n\nclass TransNorm2d(_TransNorm):\n\n    def _check_input(self, x):\n        if x.dim() != 4:\n            raise ValueError(\n                \"Expected the input to be 4-D, \"\n                \"but got {}-D\".format(x.dim())\n            )\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py",
    "content": "import numpy as np\nimport torch\n\n\ndef sharpen_prob(p, temperature=2):\n    \"\"\"Sharpen a probability matrix with a temperature.\n\n    Each entry is raised to the power of temperature and every row is\n    renormalized to sum to one, so temperature > 1 sharpens.\n\n    Args:\n        p (torch.Tensor): probability matrix (batch_size, n_classes)\n        temperature (float): temperature.\n    \"\"\"\n    p = p.pow(temperature)\n    return p / p.sum(1, keepdim=True)\n\n\ndef reverse_index(data, label):\n    \"\"\"Reverse order.\"\"\"\n    inv_idx = torch.arange(data.size(0) - 1, -1, -1).long()\n    return data[inv_idx], label[inv_idx]\n\n\ndef shuffle_index(data, label):\n    \"\"\"Shuffle order.\"\"\"\n    rnd_idx = torch.randperm(data.shape[0])\n    return data[rnd_idx], label[rnd_idx]\n\n\ndef create_onehot(label, num_classes):\n    \"\"\"Create one-hot tensor.\n\n    We suggest using torch.nn.functional.one_hot instead.\n\n    Args:\n        label (torch.Tensor): 1-D tensor.\n        num_classes (int): number of classes.\n    \"\"\"\n    onehot = torch.zeros(label.shape[0], num_classes)\n    return onehot.scatter(1, label.unsqueeze(1).data.cpu(), 1)\n\n\ndef sigmoid_rampup(current, rampup_length):\n    \"\"\"Sigmoid-shaped rampup: exp(-5 * (1 - current / rampup_length)^2).\n\n    Args:\n        current (int): current step.\n        rampup_length (int): maximum step.\n    \"\"\"\n    assert rampup_length > 0\n    current = np.clip(current, 0.0, rampup_length)\n    phase = 1.0 - current/rampup_length\n    return float(np.exp(-5.0 * phase * phase))\n\n\ndef linear_rampup(current, rampup_length):\n    \"\"\"Linear rampup.\n\n    Args:\n        current (int): current step.\n        rampup_length (int): maximum step.\n    \"\"\"\n    assert rampup_length > 0\n    ratio = np.clip(current / rampup_length, 0.0, 1.0)\n    return float(ratio)\n\n\ndef ema_model_update(model, ema_model, alpha):\n    \"\"\"Exponential moving average of model parameters.\n\n    Args:\n        model (nn.Module): model being trained.\n        ema_model (nn.Module): ema of the model.\n        alpha (float): ema decay rate.\n    \"\"\"\n    for ema_param, param in 
zip(ema_model.parameters(), model.parameters()):\n        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/__init__.py",
    "content": "from .optimizer import build_optimizer\nfrom .lr_scheduler import build_lr_scheduler\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/lr_scheduler.py",
    "content": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport torch\nfrom torch.optim.lr_scheduler import _LRScheduler\n\nAVAI_SCHEDS = [\"single_step\", \"multi_step\", \"cosine\"]\n\n\nclass _BaseWarmupScheduler(_LRScheduler):\n\n    def __init__(\n        self,\n        optimizer,\n        successor,\n        warmup_epoch,\n        last_epoch=-1,\n        verbose=False\n    ):\n        self.successor = successor\n        self.warmup_epoch = warmup_epoch\n        super().__init__(optimizer, last_epoch, verbose)\n\n    def get_lr(self):\n        raise NotImplementedError\n\n    def step(self, epoch=None):\n        if self.last_epoch >= self.warmup_epoch:\n            self.successor.step(epoch)\n            self._last_lr = self.successor.get_last_lr()\n        else:\n            super().step(epoch)\n\n\nclass ConstantWarmupScheduler(_BaseWarmupScheduler):\n\n    def __init__(\n        self,\n        optimizer,\n        successor,\n        warmup_epoch,\n        cons_lr,\n        last_epoch=-1,\n        verbose=False\n    ):\n        self.cons_lr = cons_lr\n        super().__init__(\n            optimizer, successor, warmup_epoch, last_epoch, verbose\n        )\n\n    def get_lr(self):\n        if self.last_epoch >= self.warmup_epoch:\n            return self.successor.get_last_lr()\n        return [self.cons_lr for _ in self.base_lrs]\n\n\nclass LinearWarmupScheduler(_BaseWarmupScheduler):\n\n    def __init__(\n        self,\n        optimizer,\n        successor,\n        warmup_epoch,\n        min_lr,\n        last_epoch=-1,\n        verbose=False\n    ):\n        self.min_lr = min_lr\n        super().__init__(\n            optimizer, successor, warmup_epoch, last_epoch, verbose\n        )\n\n    def get_lr(self):\n        if self.last_epoch >= self.warmup_epoch:\n            return self.successor.get_last_lr()\n        if self.last_epoch == 0:\n            return [self.min_lr for _ in self.base_lrs]\n        return [\n     
       lr * self.last_epoch / self.warmup_epoch for lr in self.base_lrs\n        ]\n\n\ndef build_lr_scheduler(optimizer, optim_cfg):\n    \"\"\"A function wrapper for building a learning rate scheduler.\n\n    Args:\n        optimizer (Optimizer): an Optimizer.\n        optim_cfg (CfgNode): optimization config.\n    \"\"\"\n    lr_scheduler = optim_cfg.LR_SCHEDULER\n    stepsize = optim_cfg.STEPSIZE\n    gamma = optim_cfg.GAMMA\n    max_epoch = optim_cfg.MAX_EPOCH\n\n    if lr_scheduler not in AVAI_SCHEDS:\n        raise ValueError(\n            \"Unsupported scheduler: {}. Must be one of {}\".format(\n                lr_scheduler, AVAI_SCHEDS\n            )\n        )\n\n    if lr_scheduler == \"single_step\":\n        if isinstance(stepsize, (list, tuple)):\n            stepsize = stepsize[-1]\n\n        if not isinstance(stepsize, int):\n            raise TypeError(\n                \"For single_step lr_scheduler, stepsize must \"\n                \"be an integer, but got {}\".format(type(stepsize))\n            )\n\n        if stepsize <= 0:\n            stepsize = max_epoch\n\n        scheduler = torch.optim.lr_scheduler.StepLR(\n            optimizer, step_size=stepsize, gamma=gamma\n        )\n\n    elif lr_scheduler == \"multi_step\":\n        if not isinstance(stepsize, (list, tuple)):\n            raise TypeError(\n                \"For multi_step lr_scheduler, stepsize must \"\n                \"be a list, but got {}\".format(type(stepsize))\n            )\n\n        scheduler = torch.optim.lr_scheduler.MultiStepLR(\n            optimizer, milestones=stepsize, gamma=gamma\n        )\n\n    elif lr_scheduler == \"cosine\":\n        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n            optimizer, float(max_epoch)\n        )\n\n    if optim_cfg.WARMUP_EPOCH > 0:\n        if not optim_cfg.WARMUP_RECOUNT:\n            scheduler.last_epoch = optim_cfg.WARMUP_EPOCH\n\n        if optim_cfg.WARMUP_TYPE == \"constant\":\n            scheduler = 
ConstantWarmupScheduler(\n                optimizer, scheduler, optim_cfg.WARMUP_EPOCH,\n                optim_cfg.WARMUP_CONS_LR\n            )\n\n        elif optim_cfg.WARMUP_TYPE == \"linear\":\n            scheduler = LinearWarmupScheduler(\n                optimizer, scheduler, optim_cfg.WARMUP_EPOCH,\n                optim_cfg.WARMUP_MIN_LR\n            )\n\n        else:\n            raise ValueError(\n                \"Unsupported warmup type: {}. Must be one of \"\n                \"[constant | linear]\".format(optim_cfg.WARMUP_TYPE)\n            )\n\n    return scheduler\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/optimizer.py",
    "content": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport warnings\nimport torch\nimport torch.nn as nn\n\nfrom .radam import RAdam\n\nAVAI_OPTIMS = [\"adam\", \"amsgrad\", \"sgd\", \"rmsprop\", \"radam\", \"adamw\"]\n\n\ndef build_optimizer(model, optim_cfg):\n    \"\"\"A function wrapper for building an optimizer.\n\n    Args:\n        model (nn.Module or iterable): model.\n        optim_cfg (CfgNode): optimization config.\n    \"\"\"\n    optim = optim_cfg.NAME\n    lr = optim_cfg.LR\n    weight_decay = optim_cfg.WEIGHT_DECAY\n    momentum = optim_cfg.MOMENTUM\n    sgd_dampening = optim_cfg.SGD_DAMPNING\n    sgd_nesterov = optim_cfg.SGD_NESTEROV\n    rmsprop_alpha = optim_cfg.RMSPROP_ALPHA\n    adam_beta1 = optim_cfg.ADAM_BETA1\n    adam_beta2 = optim_cfg.ADAM_BETA2\n    staged_lr = optim_cfg.STAGED_LR\n    new_layers = optim_cfg.NEW_LAYERS\n    base_lr_mult = optim_cfg.BASE_LR_MULT\n\n    if optim not in AVAI_OPTIMS:\n        raise ValueError(\n            \"Unsupported optim: {}. 
Must be one of {}\".format(\n                optim, AVAI_OPTIMS\n            )\n        )\n\n    if staged_lr:\n        if not isinstance(model, nn.Module):\n            raise TypeError(\n                \"When staged_lr is True, model given to \"\n                \"build_optimizer() must be an instance of nn.Module\"\n            )\n\n        if isinstance(model, nn.DataParallel):\n            model = model.module\n\n        if not new_layers:\n            warnings.warn(\n                \"new_layers is empty, therefore, staged_lr is useless\"\n            )\n            new_layers = []\n        elif isinstance(new_layers, str):\n            new_layers = [new_layers]\n\n        base_params = []\n        base_layers = []\n        new_params = []\n\n        for name, module in model.named_children():\n            if name in new_layers:\n                new_params += [p for p in module.parameters()]\n            else:\n                base_params += [p for p in module.parameters()]\n                base_layers.append(name)\n\n        param_groups = [\n            {\n                \"params\": base_params,\n                \"lr\": lr * base_lr_mult\n            },\n            {\n                \"params\": new_params\n            },\n        ]\n\n    else:\n        if isinstance(model, nn.Module):\n            param_groups = model.parameters()\n        else:\n            param_groups = model\n\n    if optim == \"adam\":\n        optimizer = torch.optim.Adam(\n            param_groups,\n            lr=lr,\n            weight_decay=weight_decay,\n            betas=(adam_beta1, adam_beta2),\n        )\n\n    elif optim == \"amsgrad\":\n        optimizer = torch.optim.Adam(\n            param_groups,\n            lr=lr,\n            weight_decay=weight_decay,\n            betas=(adam_beta1, adam_beta2),\n            amsgrad=True,\n        )\n\n    elif optim == \"sgd\":\n        optimizer = torch.optim.SGD(\n            param_groups,\n            lr=lr,\n            
momentum=momentum,\n            weight_decay=weight_decay,\n            dampening=sgd_dampening,\n            nesterov=sgd_nesterov,\n        )\n\n    elif optim == \"rmsprop\":\n        optimizer = torch.optim.RMSprop(\n            param_groups,\n            lr=lr,\n            momentum=momentum,\n            weight_decay=weight_decay,\n            alpha=rmsprop_alpha,\n        )\n\n    elif optim == \"radam\":\n        optimizer = RAdam(\n            param_groups,\n            lr=lr,\n            weight_decay=weight_decay,\n            betas=(adam_beta1, adam_beta2),\n        )\n\n    elif optim == \"adamw\":\n        optimizer = torch.optim.AdamW(\n            param_groups,\n            lr=lr,\n            weight_decay=weight_decay,\n            betas=(adam_beta1, adam_beta2),\n        )\n\n    return optimizer\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/radam.py",
    "content": "\"\"\"\nImported from: https://github.com/LiyuanLucasLiu/RAdam\n\nhttps://arxiv.org/abs/1908.03265\n\n@article{liu2019radam,\n  title={On the Variance of the Adaptive Learning Rate and Beyond},\n  author={Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei},\n  journal={arXiv preprint arXiv:1908.03265},\n  year={2019}\n}\n\"\"\"\nimport math\nimport torch\nfrom torch.optim.optimizer import Optimizer\n\n\nclass RAdam(Optimizer):\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        betas=(0.9, 0.999),\n        eps=1e-8,\n        weight_decay=0,\n        degenerated_to_sgd=True,\n    ):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\n                \"Invalid beta parameter at index 0: {}\".format(betas[0])\n            )\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\n                \"Invalid beta parameter at index 1: {}\".format(betas[1])\n            )\n\n        self.degenerated_to_sgd = degenerated_to_sgd\n        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n        self.buffer = [[None, None, None] for ind in range(10)]\n        super(RAdam, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(RAdam, self).__setstate__(state)\n\n    def step(self, closure=None):\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data.float()\n                if grad.is_sparse:\n                    raise RuntimeError(\n                        \"RAdam does 
not support sparse gradients\"\n                    )\n\n                p_data_fp32 = p.data.float()\n\n                state = self.state[p]\n\n                if len(state) == 0:\n                    state[\"step\"] = 0\n                    state[\"exp_avg\"] = torch.zeros_like(p_data_fp32)\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p_data_fp32)\n                else:\n                    state[\"exp_avg\"] = state[\"exp_avg\"].type_as(p_data_fp32)\n                    state[\"exp_avg_sq\"] = state[\"exp_avg_sq\"].type_as(\n                        p_data_fp32\n                    )\n\n                exp_avg, exp_avg_sq = state[\"exp_avg\"], state[\"exp_avg_sq\"]\n                beta1, beta2 = group[\"betas\"]\n\n                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)\n                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)\n\n                state[\"step\"] += 1\n                buffered = self.buffer[int(state[\"step\"] % 10)]\n                if state[\"step\"] == buffered[0]:\n                    N_sma, step_size = buffered[1], buffered[2]\n                else:\n                    buffered[0] = state[\"step\"]\n                    beta2_t = beta2**state[\"step\"]\n                    N_sma_max = 2 / (1-beta2) - 1\n                    N_sma = (\n                        N_sma_max - 2 * state[\"step\"] * beta2_t / (1-beta2_t)\n                    )\n                    buffered[1] = N_sma\n\n                    # more conservative since it's an approximated value\n                    if N_sma >= 5:\n                        step_size = math.sqrt(\n                            (1-beta2_t) * (N_sma-4) / (N_sma_max-4) *\n                            (N_sma-2) / N_sma * N_sma_max / (N_sma_max-2)\n                        ) / (1 - beta1**state[\"step\"])\n                    elif self.degenerated_to_sgd:\n                        step_size = 1.0 / (1 - beta1**state[\"step\"])\n                    else:\n                        
step_size = -1\n                    buffered[2] = step_size\n\n                # more conservative since it's an approximated value\n                if N_sma >= 5:\n                    if group[\"weight_decay\"] != 0:\n                        p_data_fp32.add_(\n                            p_data_fp32, alpha=-group[\"weight_decay\"] * group[\"lr\"]\n                        )\n                    denom = exp_avg_sq.sqrt().add_(group[\"eps\"])\n                    p_data_fp32.addcdiv_(\n                        exp_avg, denom, value=-step_size * group[\"lr\"]\n                    )\n                    p.data.copy_(p_data_fp32)\n                elif step_size > 0:\n                    if group[\"weight_decay\"] != 0:\n                        p_data_fp32.add_(\n                            p_data_fp32, alpha=-group[\"weight_decay\"] * group[\"lr\"]\n                        )\n                    p_data_fp32.add_(exp_avg, alpha=-step_size * group[\"lr\"])\n                    p.data.copy_(p_data_fp32)\n\n        return loss\n\n\nclass PlainRAdam(Optimizer):\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        betas=(0.9, 0.999),\n        eps=1e-8,\n        weight_decay=0,\n        degenerated_to_sgd=True,\n    ):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\n                \"Invalid beta parameter at index 0: {}\".format(betas[0])\n            )\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\n                \"Invalid beta parameter at index 1: {}\".format(betas[1])\n            )\n\n        self.degenerated_to_sgd = degenerated_to_sgd\n        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n\n        super(PlainRAdam, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        
super(PlainRAdam, self).__setstate__(state)\n\n    def step(self, closure=None):\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data.float()\n                if grad.is_sparse:\n                    raise RuntimeError(\n                        \"RAdam does not support sparse gradients\"\n                    )\n\n                p_data_fp32 = p.data.float()\n\n                state = self.state[p]\n\n                if len(state) == 0:\n                    state[\"step\"] = 0\n                    state[\"exp_avg\"] = torch.zeros_like(p_data_fp32)\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p_data_fp32)\n                else:\n                    state[\"exp_avg\"] = state[\"exp_avg\"].type_as(p_data_fp32)\n                    state[\"exp_avg_sq\"] = state[\"exp_avg_sq\"].type_as(\n                        p_data_fp32\n                    )\n\n                exp_avg, exp_avg_sq = state[\"exp_avg\"], state[\"exp_avg_sq\"]\n                beta1, beta2 = group[\"betas\"]\n\n                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)\n                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)\n\n                state[\"step\"] += 1\n                beta2_t = beta2**state[\"step\"]\n                N_sma_max = 2 / (1-beta2) - 1\n                N_sma = N_sma_max - 2 * state[\"step\"] * beta2_t / (1-beta2_t)\n\n                # more conservative since it's an approximated value\n                if N_sma >= 5:\n                    if group[\"weight_decay\"] != 0:\n                        p_data_fp32.add_(\n                            p_data_fp32, alpha=-group[\"weight_decay\"] * group[\"lr\"]\n                        )\n                    step_size = (\n                        group[\"lr\"] * math.sqrt(\n                            
(1-beta2_t) * (N_sma-4) / (N_sma_max-4) *\n                            (N_sma-2) / N_sma * N_sma_max / (N_sma_max-2)\n                        ) / (1 - beta1**state[\"step\"])\n                    )\n                    denom = exp_avg_sq.sqrt().add_(group[\"eps\"])\n                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)\n                    p.data.copy_(p_data_fp32)\n                elif self.degenerated_to_sgd:\n                    if group[\"weight_decay\"] != 0:\n                        p_data_fp32.add_(\n                            p_data_fp32, alpha=-group[\"weight_decay\"] * group[\"lr\"]\n                        )\n                    step_size = group[\"lr\"] / (1 - beta1**state[\"step\"])\n                    p_data_fp32.add_(exp_avg, alpha=-step_size)\n                    p.data.copy_(p_data_fp32)\n\n        return loss\n\n\nclass AdamW(Optimizer):\n\n    def __init__(\n        self,\n        params,\n        lr=1e-3,\n        betas=(0.9, 0.999),\n        eps=1e-8,\n        weight_decay=0,\n        warmup=0\n    ):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\n                \"Invalid beta parameter at index 0: {}\".format(betas[0])\n            )\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\n                \"Invalid beta parameter at index 1: {}\".format(betas[1])\n            )\n\n        defaults = dict(\n            lr=lr,\n            betas=betas,\n            eps=eps,\n            weight_decay=weight_decay,\n            warmup=warmup\n        )\n        super(AdamW, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(AdamW, self).__setstate__(state)\n\n    def step(self, closure=None):\n        loss = None\n        if closure is not None:\n            loss = 
closure()\n\n        for group in self.param_groups:\n\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data.float()\n                if grad.is_sparse:\n                    raise RuntimeError(\n                        \"Adam does not support sparse gradients, please consider SparseAdam instead\"\n                    )\n\n                p_data_fp32 = p.data.float()\n\n                state = self.state[p]\n\n                if len(state) == 0:\n                    state[\"step\"] = 0\n                    state[\"exp_avg\"] = torch.zeros_like(p_data_fp32)\n                    state[\"exp_avg_sq\"] = torch.zeros_like(p_data_fp32)\n                else:\n                    state[\"exp_avg\"] = state[\"exp_avg\"].type_as(p_data_fp32)\n                    state[\"exp_avg_sq\"] = state[\"exp_avg_sq\"].type_as(\n                        p_data_fp32\n                    )\n\n                exp_avg, exp_avg_sq = state[\"exp_avg\"], state[\"exp_avg_sq\"]\n                beta1, beta2 = group[\"betas\"]\n\n                state[\"step\"] += 1\n\n                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)\n                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)\n\n                denom = exp_avg_sq.sqrt().add_(group[\"eps\"])\n                bias_correction1 = 1 - beta1**state[\"step\"]\n                bias_correction2 = 1 - beta2**state[\"step\"]\n\n                if group[\"warmup\"] > state[\"step\"]:\n                    scheduled_lr = (\n                        1e-8 + state[\"step\"] * group[\"lr\"] / group[\"warmup\"]\n                    )\n                else:\n                    scheduled_lr = group[\"lr\"]\n\n                step_size = (\n                    scheduled_lr * math.sqrt(bias_correction2) /\n                    bias_correction1\n                )\n\n                if group[\"weight_decay\"] != 0:\n                    p_data_fp32.add_(\n                        
p_data_fp32, alpha=-group[\"weight_decay\"] * scheduled_lr\n                    )\n\n                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)\n\n                p.data.copy_(p_data_fp32)\n\n        return loss\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/__init__.py",
    "content": "from .tools import *\nfrom .logger import *\nfrom .meters import *\nfrom .registry import *\nfrom .torchtools import *\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/logger.py",
    "content": "import os\nimport sys\nimport time\nimport os.path as osp\n\nfrom .tools import mkdir_if_missing\n\n__all__ = [\"Logger\", \"setup_logger\"]\n\n\nclass Logger:\n    \"\"\"Write console output to external text file.\n\n    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py>`_\n\n    Args:\n        fpath (str): directory to save logging file.\n\n    Examples::\n       >>> import sys\n       >>> import os.path as osp\n       >>> save_dir = 'output/experiment-1'\n       >>> log_name = 'train.log'\n       >>> sys.stdout = Logger(osp.join(save_dir, log_name))\n    \"\"\"\n\n    def __init__(self, fpath=None):\n        self.console = sys.stdout\n        self.file = None\n        if fpath is not None:\n            mkdir_if_missing(osp.dirname(fpath))\n            self.file = open(fpath, \"w\")\n\n    def __del__(self):\n        self.close()\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, *args):\n        self.close()\n\n    def write(self, msg):\n        self.console.write(msg)\n        if self.file is not None:\n            self.file.write(msg)\n\n    def flush(self):\n        self.console.flush()\n        if self.file is not None:\n            self.file.flush()\n            os.fsync(self.file.fileno())\n\n    def close(self):\n        self.console.close()\n        if self.file is not None:\n            self.file.close()\n\n\ndef setup_logger(output=None):\n    if output is None:\n        return\n\n    if output.endswith(\".txt\") or output.endswith(\".log\"):\n        fpath = output\n    else:\n        fpath = osp.join(output, \"log.txt\")\n\n    if osp.exists(fpath):\n        # make sure the existing log file is not over-written\n        fpath += time.strftime(\"-%Y-%m-%d-%H-%M-%S\")\n\n    sys.stdout = Logger(fpath)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/meters.py",
    "content": "from collections import defaultdict\nimport torch\n\n__all__ = [\"AverageMeter\", \"MetricMeter\"]\n\n\nclass AverageMeter:\n    \"\"\"Compute and store the average and current value.\n\n    Examples::\n        >>> # 1. Initialize a meter to record loss\n        >>> losses = AverageMeter()\n        >>> # 2. Update meter after every mini-batch update\n        >>> losses.update(loss_value, batch_size)\n    \"\"\"\n\n    def __init__(self, ema=False):\n        \"\"\"\n        Args:\n            ema (bool, optional): apply exponential moving average.\n        \"\"\"\n        self.ema = ema\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        if isinstance(val, torch.Tensor):\n            val = val.item()\n\n        self.val = val\n        self.sum += val * n\n        self.count += n\n\n        if self.ema:\n            self.avg = self.avg * 0.9 + self.val * 0.1\n        else:\n            self.avg = self.sum / self.count\n\n\nclass MetricMeter:\n    \"\"\"Store the average and current value for a set of metrics.\n\n    Examples::\n        >>> # 1. Create an instance of MetricMeter\n        >>> metric = MetricMeter()\n        >>> # 2. Update using a dictionary as input\n        >>> input_dict = {'loss_1': value_1, 'loss_2': value_2}\n        >>> metric.update(input_dict)\n        >>> # 3. 
Convert to string and print\n        >>> print(str(metric))\n    \"\"\"\n\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(AverageMeter)\n        self.delimiter = delimiter\n\n    def update(self, input_dict):\n        if input_dict is None:\n            return\n\n        if not isinstance(input_dict, dict):\n            raise TypeError(\n                \"Input to MetricMeter.update() must be a dictionary\"\n            )\n\n        for k, v in input_dict.items():\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            self.meters[k].update(v)\n\n    def __str__(self):\n        output_str = []\n        for name, meter in self.meters.items():\n            output_str.append(f\"{name} {meter.val:.4f} ({meter.avg:.4f})\")\n        return self.delimiter.join(output_str)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/registry.py",
    "content": "\"\"\"\nModified from https://github.com/facebookresearch/fvcore\n\"\"\"\n__all__ = [\"Registry\"]\n\n\nclass Registry:\n    \"\"\"A registry providing name -> object mapping, to support\n    custom modules.\n\n    To create a registry (e.g. a backbone registry):\n\n    .. code-block:: python\n\n        BACKBONE_REGISTRY = Registry('BACKBONE')\n\n    To register an object:\n\n    .. code-block:: python\n\n        @BACKBONE_REGISTRY.register()\n        class MyBackbone(nn.Module):\n            ...\n\n    Or:\n\n    .. code-block:: python\n\n        BACKBONE_REGISTRY.register(MyBackbone)\n    \"\"\"\n\n    def __init__(self, name):\n        self._name = name\n        self._obj_map = dict()\n\n    def _do_register(self, name, obj, force=False):\n        if name in self._obj_map and not force:\n            raise KeyError(\n                'An object named \"{}\" was already '\n                'registered in \"{}\" registry'.format(name, self._name)\n            )\n\n        self._obj_map[name] = obj\n\n    def register(self, obj=None, force=False):\n        if obj is None:\n            # Used as a decorator\n            def wrapper(fn_or_class):\n                name = fn_or_class.__name__\n                self._do_register(name, fn_or_class, force=force)\n                return fn_or_class\n\n            return wrapper\n\n        # Used as a function call\n        name = obj.__name__\n        self._do_register(name, obj, force=force)\n\n    def get(self, name):\n        if name not in self._obj_map:\n            raise KeyError(\n                'Object name \"{}\" does not exist '\n                'in \"{}\" registry'.format(name, self._name)\n            )\n\n        return self._obj_map[name]\n\n    def registered_names(self):\n        return list(self._obj_map.keys())\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/tools.py",
    "content": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport os\nimport sys\nimport json\nimport time\nimport errno\nimport numpy as np\nimport random\nimport os.path as osp\nimport warnings\nfrom difflib import SequenceMatcher\nimport PIL\nimport torch\nfrom PIL import Image\n\n__all__ = [\n    \"mkdir_if_missing\",\n    \"check_isfile\",\n    \"read_json\",\n    \"write_json\",\n    \"set_random_seed\",\n    \"download_url\",\n    \"read_image\",\n    \"collect_env_info\",\n    \"listdir_nohidden\",\n    \"get_most_similar_str_to_a_from_b\",\n    \"check_availability\",\n    \"tolist_if_not\",\n]\n\n\ndef mkdir_if_missing(dirname):\n    \"\"\"Create dirname if it is missing.\"\"\"\n    if not osp.exists(dirname):\n        try:\n            os.makedirs(dirname)\n        except OSError as e:\n            if e.errno != errno.EEXIST:\n                raise\n\n\ndef check_isfile(fpath):\n    \"\"\"Check if the given path is a file.\n\n    Args:\n        fpath (str): file path.\n\n    Returns:\n       bool\n    \"\"\"\n    isfile = osp.isfile(fpath)\n    if not isfile:\n        warnings.warn('No file found at \"{}\"'.format(fpath))\n    return isfile\n\n\ndef read_json(fpath):\n    \"\"\"Read json file from a path.\"\"\"\n    with open(fpath, \"r\") as f:\n        obj = json.load(f)\n    return obj\n\n\ndef write_json(obj, fpath):\n    \"\"\"Writes to a json file.\"\"\"\n    mkdir_if_missing(osp.dirname(fpath))\n    with open(fpath, \"w\") as f:\n        json.dump(obj, f, indent=4, separators=(\",\", \": \"))\n\n\ndef set_random_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef download_url(url, dst):\n    \"\"\"Download file from a url to a destination.\n\n    Args:\n        url (str): url to download file.\n        dst (str): destination path.\n    \"\"\"\n    from six.moves import urllib\n\n    print('* url=\"{}\"'.format(url))\n    
print('* destination=\"{}\"'.format(dst))\n\n    def _reporthook(count, block_size, total_size):\n        global start_time\n        if count == 0:\n            start_time = time.time()\n            return\n        duration = time.time() - start_time\n        progress_size = int(count * block_size)\n        speed = int(progress_size / (1024*duration))\n        percent = int(count * block_size * 100 / total_size)\n        sys.stdout.write(\n            \"\\r...%d%%, %d MB, %d KB/s, %d seconds passed\" %\n            (percent, progress_size / (1024*1024), speed, duration)\n        )\n        sys.stdout.flush()\n\n    urllib.request.urlretrieve(url, dst, _reporthook)\n    sys.stdout.write(\"\\n\")\n\n\ndef read_image(path):\n    \"\"\"Read image from path using ``PIL.Image``.\n\n    Args:\n        path (str): path to an image.\n\n    Returns:\n        PIL image\n    \"\"\"\n    if not osp.exists(path):\n        raise IOError(\"No file exists at {}\".format(path))\n\n    while True:\n        try:\n            img = Image.open(path).convert(\"RGB\")\n            return img\n        except IOError:\n            print(\n                \"Cannot read image from {}, \"\n                \"probably due to heavy IO. 
Will re-try\".format(path)\n            )\n\n\ndef collect_env_info():\n    \"\"\"Return env info as a string.\n\n    Code source: github.com/facebookresearch/maskrcnn-benchmark\n    \"\"\"\n    from torch.utils.collect_env import get_pretty_env_info\n\n    env_str = get_pretty_env_info()\n    env_str += \"\\n        Pillow ({})\".format(PIL.__version__)\n    return env_str\n\n\ndef listdir_nohidden(path, sort=False):\n    \"\"\"List non-hidden items in a directory.\n\n    Args:\n         path (str): directory path.\n         sort (bool): sort the items.\n    \"\"\"\n    items = [f for f in os.listdir(path) if not f.startswith(\".\")]\n    if sort:\n        items.sort()\n    return items\n\n\ndef get_most_similar_str_to_a_from_b(a, b):\n    \"\"\"Return the most similar string to a in b.\n\n    Args:\n        a (str): probe string.\n        b (list): a list of candidate strings.\n    \"\"\"\n    highest_sim = 0\n    chosen = None\n    for candidate in b:\n        sim = SequenceMatcher(None, a, candidate).ratio()\n        if sim >= highest_sim:\n            highest_sim = sim\n            chosen = candidate\n    return chosen\n\n\ndef check_availability(requested, available):\n    \"\"\"Check if an element is available in a list.\n\n    Args:\n        requested (str): probe string.\n        available (list): a list of available strings.\n    \"\"\"\n    if requested not in available:\n        psb_ans = get_most_similar_str_to_a_from_b(requested, available)\n        raise ValueError(\n            \"The requested one is expected \"\n            \"to belong to {}, but got [{}] \"\n            \"(do you mean [{}]?)\".format(available, requested, psb_ans)\n        )\n\n\ndef tolist_if_not(x):\n    \"\"\"Convert to a list.\"\"\"\n    if not isinstance(x, list):\n        x = [x]\n    return x\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/torchtools.py",
    "content": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport pickle\nimport shutil\nimport os.path as osp\nimport warnings\nfrom functools import partial\nfrom collections import OrderedDict\nimport torch\nimport torch.nn as nn\n\nfrom .tools import mkdir_if_missing\n\n__all__ = [\n    \"save_checkpoint\",\n    \"load_checkpoint\",\n    \"resume_from_checkpoint\",\n    \"open_all_layers\",\n    \"open_specified_layers\",\n    \"count_num_param\",\n    \"load_pretrained_weights\",\n    \"init_network_weights\",\n]\n\n\ndef save_checkpoint(\n    state,\n    save_dir,\n    is_best=False,\n    remove_module_from_keys=True,\n    model_name=\"\"\n):\n    r\"\"\"Save checkpoint.\n\n    Args:\n        state (dict): dictionary.\n        save_dir (str): directory to save checkpoint.\n        is_best (bool, optional): if True, this checkpoint will be copied and named\n            ``model-best.pth.tar``. Default is False.\n        remove_module_from_keys (bool, optional): whether to remove \"module.\"\n            from layer names. Default is True.\n        model_name (str, optional): model name to save.\n\n    Examples::\n        >>> state = {\n        >>>     'state_dict': model.state_dict(),\n        >>>     'epoch': 10,\n        >>>     'optimizer': optimizer.state_dict()\n        >>> }\n        >>> save_checkpoint(state, 'log/my_model')\n    \"\"\"\n    mkdir_if_missing(save_dir)\n\n    if remove_module_from_keys:\n        # remove 'module.' 
in state_dict's keys\n        state_dict = state[\"state_dict\"]\n        new_state_dict = OrderedDict()\n        for k, v in state_dict.items():\n            if k.startswith(\"module.\"):\n                k = k[7:]\n            new_state_dict[k] = v\n        state[\"state_dict\"] = new_state_dict\n\n    # save model\n    epoch = state[\"epoch\"]\n    if not model_name:\n        model_name = \"model.pth.tar-\" + str(epoch)\n    fpath = osp.join(save_dir, model_name)\n    torch.save(state, fpath)\n    print('Checkpoint saved to \"{}\"'.format(fpath))\n\n    # save current model name\n    checkpoint_file = osp.join(save_dir, \"checkpoint\")\n    checkpoint = open(checkpoint_file, \"w+\")\n    checkpoint.write(\"{}\\n\".format(osp.basename(fpath)))\n    checkpoint.close()\n\n    if is_best:\n        best_fpath = osp.join(osp.dirname(fpath), \"model-best.pth.tar\")\n        shutil.copy(fpath, best_fpath)\n        print('Best checkpoint saved to \"{}\"'.format(best_fpath))\n\n\ndef load_checkpoint(fpath):\n    r\"\"\"Load checkpoint.\n\n    ``UnicodeDecodeError`` can be well handled, which means\n    python2-saved files can be read from python3.\n\n    Args:\n        fpath (str): path to checkpoint.\n\n    Returns:\n        dict\n\n    Examples::\n        >>> fpath = 'log/my_model/model.pth.tar-10'\n        >>> checkpoint = load_checkpoint(fpath)\n    \"\"\"\n    if fpath is None:\n        raise ValueError(\"File path is None\")\n\n    if not osp.exists(fpath):\n        raise FileNotFoundError('File is not found at \"{}\"'.format(fpath))\n\n    map_location = None if torch.cuda.is_available() else \"cpu\"\n\n    try:\n        checkpoint = torch.load(fpath, map_location=map_location)\n\n    except UnicodeDecodeError:\n        pickle.load = partial(pickle.load, encoding=\"latin1\")\n        pickle.Unpickler = partial(pickle.Unpickler, encoding=\"latin1\")\n        checkpoint = torch.load(\n            fpath, pickle_module=pickle, map_location=map_location\n        )\n\n   
 except Exception:\n        print('Unable to load checkpoint from \"{}\"'.format(fpath))\n        raise\n\n    return checkpoint\n\n\ndef resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None):\n    r\"\"\"Resume training from a checkpoint.\n\n    This will load (1) model weights, (2) the ``state_dict`` of the\n    optimizer if ``optimizer`` is not None and (3) the ``state_dict`` of the\n    scheduler if ``scheduler`` is not None.\n\n    Args:\n        fdir (str): directory containing the ``checkpoint`` file and the\n            saved model.\n        model (nn.Module): model.\n        optimizer (Optimizer, optional): an Optimizer.\n        scheduler (Scheduler, optional): a learning-rate scheduler.\n\n    Returns:\n        int: the epoch stored in the checkpoint (start_epoch).\n\n    Examples::\n        >>> fdir = 'log/my_model'\n        >>> start_epoch = resume_from_checkpoint(fdir, model, optimizer, scheduler)\n    \"\"\"\n    with open(osp.join(fdir, \"checkpoint\"), \"r\") as checkpoint:\n        model_name = checkpoint.readlines()[0].strip(\"\\n\")\n        fpath = osp.join(fdir, model_name)\n\n    print('Loading checkpoint from \"{}\"'.format(fpath))\n    checkpoint = load_checkpoint(fpath)\n    model.load_state_dict(checkpoint[\"state_dict\"])\n    print(\"Loaded model weights\")\n\n    if optimizer is not None and \"optimizer\" in checkpoint.keys():\n        optimizer.load_state_dict(checkpoint[\"optimizer\"])\n        print(\"Loaded optimizer\")\n\n    if scheduler is not None and \"scheduler\" in checkpoint.keys():\n        scheduler.load_state_dict(checkpoint[\"scheduler\"])\n        print(\"Loaded scheduler\")\n\n    start_epoch = checkpoint[\"epoch\"]\n    print(\"Previous epoch: {}\".format(start_epoch))\n\n    return start_epoch\n\n\ndef adjust_learning_rate(\n    optimizer,\n    base_lr,\n    epoch,\n    stepsize=20,\n    gamma=0.1,\n    linear_decay=False,\n    final_lr=0,\n    max_epoch=100,\n):\n    r\"\"\"Adjust learning rate.\n\n    Deprecated.\n    \"\"\"\n    if linear_decay:\n        # linearly decay learning rate from base_lr to final_lr\n        frac_done = epoch / max_epoch\n        lr = 
frac_done*final_lr + (1.0-frac_done) * base_lr\n    else:\n        # decay learning rate by gamma for every stepsize\n        lr = base_lr * (gamma**(epoch // stepsize))\n\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = lr\n\n\ndef set_bn_to_eval(m):\n    r\"\"\"Set BatchNorm layers to eval mode.\"\"\"\n    # 1. no update for running mean and var\n    # 2. scale and shift parameters are still trainable\n    classname = m.__class__.__name__\n    if classname.find(\"BatchNorm\") != -1:\n        m.eval()\n\n\ndef open_all_layers(model):\n    r\"\"\"Open all layers in model for training.\n\n    Examples::\n        >>> open_all_layers(model)\n    \"\"\"\n    model.train()\n    for p in model.parameters():\n        p.requires_grad = True\n\n\ndef open_specified_layers(model, open_layers):\n    r\"\"\"Open specified layers in model for training while keeping\n    other layers frozen.\n\n    Args:\n        model (nn.Module): neural net model.\n        open_layers (str or list): layers open for training.\n\n    Examples::\n        >>> # Only model.classifier will be updated.\n        >>> open_layers = 'classifier'\n        >>> open_specified_layers(model, open_layers)\n        >>> # Only model.fc and model.classifier will be updated.\n        >>> open_layers = ['fc', 'classifier']\n        >>> open_specified_layers(model, open_layers)\n    \"\"\"\n    if isinstance(model, nn.DataParallel):\n        model = model.module\n\n    if isinstance(open_layers, str):\n        open_layers = [open_layers]\n\n    for layer in open_layers:\n        assert hasattr(\n            model, layer\n        ), '\"{}\" is not an attribute of the model, please provide the correct name'.format(\n            layer\n        )\n\n    for name, module in model.named_children():\n        if name in open_layers:\n            module.train()\n            for p in module.parameters():\n                p.requires_grad = True\n        else:\n            module.eval()\n           
 for p in module.parameters():\n                p.requires_grad = False\n\n\ndef count_num_param(model):\n    r\"\"\"Count the number of parameters in a model.\n\n    Args:\n        model (nn.Module): network model.\n\n    Examples::\n        >>> model_size = count_num_param(model)\n    \"\"\"\n    return sum(p.numel() for p in model.parameters())\n\n\ndef load_pretrained_weights(model, weight_path):\n    r\"\"\"Load pretrained weights into a model.\n\n    Features::\n        - Incompatible layers (unmatched in name or size) will be ignored.\n        - Keys prefixed with \"module.\" (e.g. saved from ``nn.DataParallel``)\n          are handled automatically.\n\n    Args:\n        model (nn.Module): network model.\n        weight_path (str): path to pretrained weights.\n\n    Examples::\n        >>> weight_path = 'log/my_model/model-best.pth.tar'\n        >>> load_pretrained_weights(model, weight_path)\n    \"\"\"\n    checkpoint = load_checkpoint(weight_path)\n    if \"state_dict\" in checkpoint:\n        state_dict = checkpoint[\"state_dict\"]\n    else:\n        state_dict = checkpoint\n\n    model_dict = model.state_dict()\n    new_state_dict = OrderedDict()\n    matched_layers, discarded_layers = [], []\n\n    for k, v in state_dict.items():\n        if k.startswith(\"module.\"):\n            k = k[7:]  # discard module.\n\n        if k in model_dict and model_dict[k].size() == v.size():\n            new_state_dict[k] = v\n            matched_layers.append(k)\n        else:\n            discarded_layers.append(k)\n\n    model_dict.update(new_state_dict)\n    model.load_state_dict(model_dict)\n\n    if len(matched_layers) == 0:\n        warnings.warn(\n            'The pretrained weights \"{}\" cannot be loaded, '\n            \"please check the key names manually \"\n            \"(** ignored and continue **)\".format(weight_path)\n        )\n    else:\n        print(\n            'Successfully loaded pretrained weights from \"{}\"'.\n            format(weight_path)\n        )\n        if len(discarded_layers) > 0:\n   
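         print(\n                \"** The following layers are discarded \"\n                \"due to unmatched keys or layer size: {}\".\n                format(discarded_layers)\n            )\n\n\ndef init_network_weights(model, init_type=\"normal\", gain=0.02):\n    r\"\"\"Initialize the weights of a network in place.\n\n    Conv and Linear weights are initialized according to ``init_type`` and\n    their biases are set to 0; affine BatchNorm and InstanceNorm layers get\n    weight 1 and bias 0.\n\n    Args:\n        model (nn.Module): network model.\n        init_type (str, optional): one of \"normal\", \"xavier\", \"kaiming\" or\n            \"orthogonal\". Default is \"normal\".\n        gain (float, optional): standard deviation for \"normal\", or scaling\n            factor for \"xavier\" and \"orthogonal\"; ignored by \"kaiming\".\n            Default is 0.02.\n    \"\"\"\n\n    def _init_func(m):\n        classname = m.__class__.__name__\n\n        if hasattr(m, \"weight\") and (\n            classname.find(\"Conv\") != -1 or classname.find(\"Linear\") != -1\n        ):\n            if init_type == \"normal\":\n                nn.init.normal_(m.weight.data, 0.0, gain)\n            elif init_type == \"xavier\":\n                nn.init.xavier_normal_(m.weight.data, gain=gain)\n            elif init_type == \"kaiming\":\n                nn.init.kaiming_normal_(m.weight.data, a=0, mode=\"fan_in\")\n            elif init_type == \"orthogonal\":\n                nn.init.orthogonal_(m.weight.data, gain=gain)\n            else:\n                raise NotImplementedError(\n                    \"initialization method {} is not implemented\".\n                    format(init_type)\n                )\n            if hasattr(m, \"bias\") and m.bias is not None:\n                nn.init.constant_(m.bias.data, 0.0)\n\n        elif classname.find(\"BatchNorm\") != -1:\n            # BatchNorm with affine=False has no weight/bias to initialize\n            if m.weight is not None and m.bias is not None:\n                nn.init.constant_(m.weight.data, 1.0)\n                nn.init.constant_(m.bias.data, 0.0)\n\n        elif classname.find(\"InstanceNorm\") != -1:\n            if m.weight is not None and m.bias is not None:\n                nn.init.constant_(m.weight.data, 1.0)\n                nn.init.constant_(m.bias.data, 0.0)\n\n    model.apply(_init_func)\n"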
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/da/cifar_stl.py",
    "content": "import sys\nimport pprint as pp\nimport os.path as osp\nfrom torchvision.datasets import STL10, CIFAR10\n\nfrom dassl.utils import mkdir_if_missing\n\ncifar_label2name = {\n    0: \"airplane\",\n    1: \"car\",  # the original name was 'automobile'\n    2: \"bird\",\n    3: \"cat\",\n    4: \"deer\",\n    5: \"dog\",\n    6: \"frog\",  # conflict class\n    7: \"horse\",\n    8: \"ship\",\n    9: \"truck\",\n}\n\nstl_label2name = {\n    0: \"airplane\",\n    1: \"bird\",\n    2: \"car\",\n    3: \"cat\",\n    4: \"deer\",\n    5: \"dog\",\n    6: \"horse\",\n    7: \"monkey\",  # conflict class\n    8: \"ship\",\n    9: \"truck\",\n}\n\nnew_name2label = {\n    \"airplane\": 0,\n    \"bird\": 1,\n    \"car\": 2,\n    \"cat\": 3,\n    \"deer\": 4,\n    \"dog\": 5,\n    \"horse\": 6,\n    \"ship\": 7,\n    \"truck\": 8,\n}\n\n\ndef extract_and_save_image(dataset, save_dir, discard, label2name):\n    if osp.exists(save_dir):\n        print('Folder \"{}\" already exists'.format(save_dir))\n        return\n\n    print('Extracting images to \"{}\" ...'.format(save_dir))\n    mkdir_if_missing(save_dir)\n\n    for i in range(len(dataset)):\n        img, label = dataset[i]\n        if label == discard:\n            continue\n        class_name = label2name[label]\n        label_new = new_name2label[class_name]\n        class_dir = osp.join(\n            save_dir,\n            str(label_new).zfill(3) + \"_\" + class_name\n        )\n        mkdir_if_missing(class_dir)\n        impath = osp.join(class_dir, str(i + 1).zfill(5) + \".jpg\")\n        img.save(impath)\n\n\ndef download_and_prepare(name, root, discarded_label, label2name):\n    print(\"Dataset: {}\".format(name))\n    print(\"Root: {}\".format(root))\n    print(\"Old labels:\")\n    pp.pprint(label2name)\n    print(\"Discarded label: {}\".format(discarded_label))\n    print(\"New labels:\")\n    pp.pprint(new_name2label)\n\n    if name == \"cifar\":\n        train = CIFAR10(root, train=True, 
download=True)\n        test = CIFAR10(root, train=False)\n    else:\n        train = STL10(root, split=\"train\", download=True)\n        test = STL10(root, split=\"test\")\n\n    train_dir = osp.join(root, name, \"train\")\n    test_dir = osp.join(root, name, \"test\")\n\n    extract_and_save_image(train, train_dir, discarded_label, label2name)\n    extract_and_save_image(test, test_dir, discarded_label, label2name)\n\n\nif __name__ == \"__main__\":\n    download_and_prepare(\"cifar\", sys.argv[1], 6, cifar_label2name)\n    download_and_prepare(\"stl\", sys.argv[1], 7, stl_label2name)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/da/digit5.py",
    "content": "import os\nimport numpy as np\nimport os.path as osp\nimport argparse\nfrom PIL import Image\nfrom scipy.io import loadmat\n\n\ndef mkdir_if_missing(directory):\n    if not osp.exists(directory):\n        os.makedirs(directory)\n\n\ndef extract_and_save(data, label, save_dir):\n    for i, (x, y) in enumerate(zip(data, label)):\n        if x.shape[2] == 1:\n            x = np.repeat(x, 3, axis=2)\n        if y == 10:\n            y = 0\n        x = Image.fromarray(x, mode=\"RGB\")\n        save_path = osp.join(\n            save_dir,\n            str(i + 1).zfill(6) + \"_\" + str(y) + \".jpg\"\n        )\n        x.save(save_path)\n\n\ndef load_mnist(data_dir, raw_data_dir):\n    filepath = osp.join(raw_data_dir, \"mnist_data.mat\")\n    data = loadmat(filepath)\n\n    train_data = np.reshape(data[\"train_32\"], (55000, 32, 32, 1))\n    test_data = np.reshape(data[\"test_32\"], (10000, 32, 32, 1))\n\n    train_label = np.nonzero(data[\"label_train\"])[1]\n    test_label = np.nonzero(data[\"label_test\"])[1]\n\n    return train_data, test_data, train_label, test_label\n\n\ndef load_mnist_m(data_dir, raw_data_dir):\n    filepath = osp.join(raw_data_dir, \"mnistm_with_label.mat\")\n    data = loadmat(filepath)\n\n    train_data = data[\"train\"]\n    test_data = data[\"test\"]\n\n    train_label = np.nonzero(data[\"label_train\"])[1]\n    test_label = np.nonzero(data[\"label_test\"])[1]\n\n    return train_data, test_data, train_label, test_label\n\n\ndef load_svhn(data_dir, raw_data_dir):\n    train = loadmat(osp.join(raw_data_dir, \"svhn_train_32x32.mat\"))\n    train_data = train[\"X\"].transpose(3, 0, 1, 2)\n    train_label = train[\"y\"][:, 0]\n\n    test = loadmat(osp.join(raw_data_dir, \"svhn_test_32x32.mat\"))\n    test_data = test[\"X\"].transpose(3, 0, 1, 2)\n    test_label = test[\"y\"][:, 0]\n\n    return train_data, test_data, train_label, test_label\n\n\ndef load_syn(data_dir, raw_data_dir):\n    filepath = osp.join(raw_data_dir, 
\"syn_number.mat\")\n    data = loadmat(filepath)\n\n    train_data = data[\"train_data\"]\n    test_data = data[\"test_data\"]\n\n    train_label = data[\"train_label\"][:, 0]\n    test_label = data[\"test_label\"][:, 0]\n\n    return train_data, test_data, train_label, test_label\n\n\ndef load_usps(data_dir, raw_data_dir):\n    filepath = osp.join(raw_data_dir, \"usps_28x28.mat\")\n    data = loadmat(filepath)[\"dataset\"]\n\n    train_data = data[0][0].transpose(0, 2, 3, 1)\n    test_data = data[1][0].transpose(0, 2, 3, 1)\n\n    train_data *= 255\n    test_data *= 255\n\n    train_data = train_data.astype(np.uint8)\n    test_data = test_data.astype(np.uint8)\n\n    train_label = data[0][1][:, 0]\n    test_label = data[1][1][:, 0]\n\n    return train_data, test_data, train_label, test_label\n\n\ndef main(data_dir):\n    data_dir = osp.abspath(osp.expanduser(data_dir))\n    raw_data_dir = osp.join(data_dir, \"Digit-Five\")\n\n    if not osp.exists(data_dir):\n        raise FileNotFoundError('\"{}\" does not exist'.format(data_dir))\n\n    datasets = [\"mnist\", \"mnist_m\", \"svhn\", \"syn\", \"usps\"]\n\n    for name in datasets:\n        print(\"Creating {}\".format(name))\n\n        output = eval(\"load_\" + name)(data_dir, raw_data_dir)\n        train_data, test_data, train_label, test_label = output\n\n        print(\"# train: {}\".format(train_data.shape[0]))\n        print(\"# test: {}\".format(test_data.shape[0]))\n\n        train_dir = osp.join(data_dir, name, \"train_images\")\n        mkdir_if_missing(train_dir)\n        test_dir = osp.join(data_dir, name, \"test_images\")\n        mkdir_if_missing(test_dir)\n\n        extract_and_save(train_data, train_label, train_dir)\n        extract_and_save(test_data, test_label, test_dir)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"data_dir\", type=str, help=\"directory containing Digit-Five/\"\n    )\n    args = parser.parse_args()\n    
main(args.data_dir)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/da/visda17.sh",
    "content": "# ------------------------------------------------------------------------\n# ROOT is the root directory where you put your domain datasets.\n# \n# Suppose you wanna put the dataset under $DATA, which stores all the\n# domain datasets, run the following command in your terminal to\n# download VisDa17:\n#\n# $ sh visda17.sh $DATA\n#------------------------------------------------------------------------\n\nROOT=$1\nmkdir $ROOT/visda17\ncd $ROOT/visda17\n\nwget http://csr.bu.edu/ftp/visda17/clf/train.tar\ntar xvf train.tar\n\nwget http://csr.bu.edu/ftp/visda17/clf/validation.tar\ntar xvf validation.tar  \n\nwget http://csr.bu.edu/ftp/visda17/clf/test.tar\ntar xvf test.tar\n\nwget https://raw.githubusercontent.com/VisionLearningGroup/taskcv-2017-public/master/classification/data/image_list.txt -O test/image_list.txt"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/dg/cifar_c.py",
    "content": "\"\"\"\nThis script\n- creates a folder named \"cifar10_c\" (or \"cifar100_c\") under the same directory as 'CIFAR-10-C' (or 'CIFAR-100-C')\n- extracts images from .npy files and saves them as .jpg.\n\"\"\"\nimport os\nimport sys\nimport numpy as np\nimport os.path as osp\nfrom PIL import Image\n\nfrom dassl.utils import mkdir_if_missing\n\n\ndef extract_and_save(images, labels, level, dst):\n    # level denotes the corruption intensity level (0-based)\n    assert 0 <= level <= 4\n\n    for i in range(10000):\n        real_i = i + level*10000\n        im = Image.fromarray(images[real_i])\n        label = int(labels[real_i])\n        category_dir = osp.join(dst, str(label).zfill(3))\n        mkdir_if_missing(category_dir)\n        save_path = osp.join(category_dir, str(i + 1).zfill(5) + \".jpg\")\n        im.save(save_path)\n\n\ndef main(npy_folder):\n    npy_folder = osp.abspath(osp.expanduser(npy_folder))\n    dataset_cap = osp.basename(npy_folder)\n\n    assert dataset_cap in [\"CIFAR-10-C\", \"CIFAR-100-C\"]\n\n    if dataset_cap == \"CIFAR-10-C\":\n        dataset = \"cifar10_c\"\n    else:\n        dataset = \"cifar100_c\"\n\n    if not osp.exists(npy_folder):\n        raise FileNotFoundError(\n            'The given folder \"{}\" does not exist'.format(npy_folder)\n        )\n\n    root = osp.dirname(npy_folder)\n    im_folder = osp.join(root, dataset)\n\n    mkdir_if_missing(im_folder)\n\n    dirnames = os.listdir(npy_folder)\n    dirnames.remove(\"labels.npy\")\n    if \"README.txt\" in dirnames:\n        dirnames.remove(\"README.txt\")\n    assert len(dirnames) == 19\n    labels = np.load(osp.join(npy_folder, \"labels.npy\"))\n\n    for dirname in dirnames:\n        corruption = dirname.split(\".\")[0]\n        corruption_folder = osp.join(im_folder, corruption)\n        mkdir_if_missing(corruption_folder)\n\n        npy_filename = osp.join(npy_folder, dirname)\n        images = np.load(npy_filename)\n        assert images.shape[0] == 50000\n\n        for level in range(5):\n            dst = osp.join(corruption_folder, 
str(level + 1))\n            mkdir_if_missing(dst)\n            print('Saving images to \"{}\"'.format(dst))\n            extract_and_save(images, labels, level, dst)\n\n\nif __name__ == \"__main__\":\n    # sys.argv[1] contains the path to CIFAR-10-C or CIFAR-100-C\n    main(sys.argv[1])\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/ssl/cifar10_cifar100_svhn.py",
    "content": "import sys\nimport os.path as osp\nfrom torchvision.datasets import SVHN, CIFAR10, CIFAR100\n\nfrom dassl.utils import mkdir_if_missing\n\n\ndef extract_and_save_image(dataset, save_dir):\n    if osp.exists(save_dir):\n        print('Folder \"{}\" already exists'.format(save_dir))\n        return\n\n    print('Extracting images to \"{}\" ...'.format(save_dir))\n    mkdir_if_missing(save_dir)\n\n    for i in range(len(dataset)):\n        img, label = dataset[i]\n        class_dir = osp.join(save_dir, str(label).zfill(3))\n        mkdir_if_missing(class_dir)\n        impath = osp.join(class_dir, str(i + 1).zfill(5) + \".jpg\")\n        img.save(impath)\n\n\ndef download_and_prepare(name, root):\n    print(\"Dataset: {}\".format(name))\n    print(\"Root: {}\".format(root))\n\n    if name == \"cifar10\":\n        train = CIFAR10(root, train=True, download=True)\n        test = CIFAR10(root, train=False)\n    elif name == \"cifar100\":\n        train = CIFAR100(root, train=True, download=True)\n        test = CIFAR100(root, train=False)\n    elif name == \"svhn\":\n        train = SVHN(root, split=\"train\", download=True)\n        test = SVHN(root, split=\"test\", download=True)\n    else:\n        raise ValueError\n\n    train_dir = osp.join(root, name, \"train\")\n    test_dir = osp.join(root, name, \"test\")\n\n    extract_and_save_image(train, train_dir)\n    extract_and_save_image(test, test_dir)\n\n\nif __name__ == \"__main__\":\n    download_and_prepare(\"cifar10\", sys.argv[1])\n    download_and_prepare(\"cifar100\", sys.argv[1])\n    download_and_prepare(\"svhn\", sys.argv[1])\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/ssl/stl10.py",
    "content": "import sys\nimport os.path as osp\nfrom torchvision.datasets import STL10\n\nfrom dassl.utils import mkdir_if_missing\n\n\ndef extract_and_save_image(dataset, save_dir):\n    if osp.exists(save_dir):\n        print('Folder \"{}\" already exists'.format(save_dir))\n        return\n\n    print('Extracting images to \"{}\" ...'.format(save_dir))\n    mkdir_if_missing(save_dir)\n\n    for i in range(len(dataset)):\n        img, label = dataset[i]\n        if label == -1:\n            label_name = \"none\"\n        else:\n            label_name = str(label)\n        imname = str(i).zfill(6) + \"_\" + label_name + \".jpg\"\n        impath = osp.join(save_dir, imname)\n        img.save(impath)\n\n\ndef download_and_prepare(root):\n    train = STL10(root, split=\"train\", download=True)\n    test = STL10(root, split=\"test\")\n    unlabeled = STL10(root, split=\"unlabeled\")\n\n    train_dir = osp.join(root, \"train\")\n    test_dir = osp.join(root, \"test\")\n    unlabeled_dir = osp.join(root, \"unlabeled\")\n\n    extract_and_save_image(train, train_dir)\n    extract_and_save_image(test, test_dir)\n    extract_and_save_image(unlabeled, unlabeled_dir)\n\n\nif __name__ == \"__main__\":\n    download_and_prepare(sys.argv[1])\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/linter.sh",
    "content": "echo \"Running isort\"\nisort -y -sp .\necho \"Done\"\n\necho \"Running yapf\"\nyapf -i -r -vv -e build .\necho \"Done\"\n\necho \"Running flake8\"\nflake8 .\necho \"Done\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/requirements.txt",
    "content": "flake8==3.7.9\nyapf==0.29.0\nisort==4.3.21\nyacs\ngdown\ntb-nightly\nfuture\nscipy\nscikit-learn\ntqdm\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/setup.py",
    "content": "import numpy as np\nimport os.path as osp\nfrom setuptools import setup, find_packages\n\n\ndef readme():\n    with open('README.md') as f:\n        content = f.read()\n    return content\n\n\ndef find_version():\n    version_file = 'dassl/__init__.py'\n    with open(version_file, 'r') as f:\n        exec(compile(f.read(), version_file, 'exec'))\n    return locals()['__version__']\n\n\ndef numpy_include():\n    try:\n        numpy_include = np.get_include()\n    except AttributeError:\n        numpy_include = np.get_numpy_include()\n    return numpy_include\n\n\ndef get_requirements(filename='requirements.txt'):\n    here = osp.dirname(osp.realpath(__file__))\n    with open(osp.join(here, filename), 'r') as f:\n        requires = [line.replace('\\n', '') for line in f.readlines()]\n    return requires\n\n\nsetup(\n    name='dassl',\n    version=find_version(),\n    description='Dassl: Domain adaptation and semi-supervised learning',\n    author='Kaiyang Zhou',\n    license='MIT',\n    long_description=readme(),\n    url='https://github.com/KaiyangZhou/Dassl.pytorch',\n    packages=find_packages(),\n    install_requires=get_requirements(),\n    keywords=[\n        'Domain Adaptation', 'Domain Generalization',\n        'Semi-Supervised Learning', 'Pytorch'\n    ]\n)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/tools/parse_test_res.py",
    "content": "\"\"\"\nGoal\n---\n1. Read test results from log.txt files\n2. Compute mean and std across different folders (seeds)\n\nUsage\n---\nAssume the output files are saved under output/my_experiment,\nwhich contains results of different seeds, e.g.,\n\nmy_experiment/\n    seed1/\n        log.txt\n    seed2/\n        log.txt\n    seed3/\n        log.txt\n\nRun the following command from the root directory:\n\n$ python tools/parse_test_res.py output/my_experiment\n\nAdd --ci95 to the arguments if you want the 95% confidence\ninterval instead of the standard deviation:\n\n$ python tools/parse_test_res.py output/my_experiment --ci95\n\nIf my_experiment/ has the following structure,\n\nmy_experiment/\n    exp-1/\n        seed1/\n            log.txt\n            ...\n        seed2/\n            log.txt\n            ...\n        seed3/\n            log.txt\n            ...\n    exp-2/\n        ...\n    exp-3/\n        ...\n\nRun\n\n$ python tools/parse_test_res.py output/my_experiment --multi-exp\n\"\"\"\nimport re\nimport numpy as np\nimport os.path as osp\nimport argparse\nfrom collections import OrderedDict, defaultdict\n\nfrom dassl.utils import check_isfile, listdir_nohidden\n\n\ndef compute_ci95(res):\n    return 1.96 * np.std(res) / np.sqrt(len(res))\n\n\ndef parse_function(*metrics, directory=\"\", args=None, end_signal=None):\n    print(f\"Parsing files in {directory}\")\n    subdirs = listdir_nohidden(directory, sort=True)\n\n    outputs = []\n\n    for subdir in subdirs:\n        fpath = osp.join(directory, subdir, \"log.txt\")\n        assert check_isfile(fpath)\n        good_to_go = False\n        output = OrderedDict()\n\n        with open(fpath, \"r\") as f:\n            lines = f.readlines()\n\n            for line in lines:\n                line = line.strip()\n\n                if line == end_signal:\n                    good_to_go = True\n\n                for metric in metrics:\n                    match = metric[\"regex\"].search(line)\n         
           if match and good_to_go:\n                        if \"file\" not in output:\n                            output[\"file\"] = fpath\n                        num = float(match.group(1))\n                        name = metric[\"name\"]\n                        output[name] = num\n\n        if output:\n            outputs.append(output)\n\n    assert len(outputs) > 0, f\"Nothing found in {directory}\"\n\n    metrics_results = defaultdict(list)\n\n    for output in outputs:\n        msg = \"\"\n        for key, value in output.items():\n            if isinstance(value, float):\n                msg += f\"{key}: {value:.2f}%. \"\n            else:\n                msg += f\"{key}: {value}. \"\n            if key != \"file\":\n                metrics_results[key].append(value)\n        print(msg)\n\n    output_results = OrderedDict()\n\n    print(\"===\")\n    print(f\"Summary of directory: {directory}\")\n    for key, values in metrics_results.items():\n        avg = np.mean(values)\n        std = compute_ci95(values) if args.ci95 else np.std(values)\n        print(f\"* {key}: {avg:.2f}% +- {std:.2f}%\")\n        output_results[key] = avg\n    print(\"===\")\n\n    return output_results\n\n\ndef main(args, end_signal):\n    metric = {\n        \"name\": args.keyword,\n        \"regex\": re.compile(fr\"\\* {args.keyword}: ([\\.\\deE+-]+)%\"),\n    }\n\n    if args.multi_exp:\n        final_results = defaultdict(list)\n\n        for directory in listdir_nohidden(args.directory, sort=True):\n            directory = osp.join(args.directory, directory)\n            results = parse_function(\n                metric, directory=directory, args=args, end_signal=end_signal\n            )\n\n            for key, value in results.items():\n                final_results[key].append(value)\n\n        print(\"Average performance\")\n        for key, values in final_results.items():\n            avg = np.mean(values)\n            print(f\"* {key}: {avg:.2f}%\")\n\n    else:\n  
      parse_function(\n            metric, directory=args.directory, args=args, end_signal=end_signal\n        )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"directory\", type=str, help=\"path to directory\")\n    parser.add_argument(\n        \"--ci95\",\n        action=\"store_true\",\n        help=\"compute 95%% confidence interval\"\n    )\n    parser.add_argument(\n        \"--test-log\", action=\"store_true\", help=\"parse test-only logs\"\n    )\n    parser.add_argument(\n        \"--multi-exp\", action=\"store_true\", help=\"parse multiple experiments\"\n    )\n    parser.add_argument(\n        \"--keyword\",\n        default=\"accuracy\",\n        type=str,\n        help=\"which keyword to extract\"\n    )\n    args = parser.parse_args()\n\n    end_signal = \"Finished training\"\n    if args.test_log:\n        end_signal = \"=> result\"\n\n    main(args, end_signal)\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/tools/replace_text.py",
    "content": "\"\"\"\nReplace text in python files.\n\"\"\"\nimport glob\nimport os.path as osp\nimport argparse\nimport fileinput\n\nEXTENSION = \".py\"\n\n\ndef is_python_file(filename):\n    ext = osp.splitext(filename)[1]\n    return ext == EXTENSION\n\n\ndef update_file(filename, text_to_search, replacement_text):\n    print(\"Processing {}\".format(filename))\n    with fileinput.FileInput(filename, inplace=True, backup=\"\") as file:\n        for line in file:\n            print(line.replace(text_to_search, replacement_text), end=\"\")\n\n\ndef recursive_update(directory, text_to_search, replacement_text):\n    filenames = glob.glob(osp.join(directory, \"*\"))\n\n    for filename in filenames:\n        if osp.isfile(filename):\n            if not is_python_file(filename):\n                continue\n            update_file(filename, text_to_search, replacement_text)\n        elif osp.isdir(filename):\n            recursive_update(filename, text_to_search, replacement_text)\n        else:\n            raise NotImplementedError\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"file_or_dir\", type=str, help=\"path to file or directory\"\n    )\n    parser.add_argument(\"text_to_search\", type=str, help=\"name to be replaced\")\n    parser.add_argument(\"replacement_text\", type=str, help=\"new name\")\n    parser.add_argument(\n        \"--ext\", type=str, default=\".py\", help=\"file extension\"\n    )\n    args = parser.parse_args()\n\n    file_or_dir = args.file_or_dir\n    text_to_search = args.text_to_search\n    replacement_text = args.replacement_text\n    extension = args.ext\n\n    global EXTENSION\n    EXTENSION = extension\n\n    if osp.isfile(file_or_dir):\n        if not is_python_file(file_or_dir):\n            return\n        update_file(file_or_dir, text_to_search, replacement_text)\n    elif osp.isdir(file_or_dir):\n        recursive_update(file_or_dir, text_to_search, replacement_text)\n    else:\n   
     raise NotImplementedError\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/tools/train.py",
    "content": "import argparse\nimport torch\n\nfrom dassl.utils import setup_logger, set_random_seed, collect_env_info\nfrom dassl.config import get_cfg_default\nfrom dassl.engine import build_trainer\n\n\ndef print_args(args, cfg):\n    print(\"***************\")\n    print(\"** Arguments **\")\n    print(\"***************\")\n    optkeys = list(args.__dict__.keys())\n    optkeys.sort()\n    for key in optkeys:\n        print(\"{}: {}\".format(key, args.__dict__[key]))\n    print(\"************\")\n    print(\"** Config **\")\n    print(\"************\")\n    print(cfg)\n\n\ndef reset_cfg(cfg, args):\n    if args.root:\n        cfg.DATASET.ROOT = args.root\n\n    if args.output_dir:\n        cfg.OUTPUT_DIR = args.output_dir\n\n    if args.resume:\n        cfg.RESUME = args.resume\n\n    if args.seed:\n        cfg.SEED = args.seed\n\n    if args.source_domains:\n        cfg.DATASET.SOURCE_DOMAINS = args.source_domains\n\n    if args.target_domains:\n        cfg.DATASET.TARGET_DOMAINS = args.target_domains\n\n    if args.transforms:\n        cfg.INPUT.TRANSFORMS = args.transforms\n\n    if args.trainer:\n        cfg.TRAINER.NAME = args.trainer\n\n    if args.backbone:\n        cfg.MODEL.BACKBONE.NAME = args.backbone\n\n    if args.head:\n        cfg.MODEL.HEAD.NAME = args.head\n\n\ndef extend_cfg(cfg):\n    \"\"\"\n    Add new config variables.\n\n    E.g.\n        from yacs.config import CfgNode as CN\n        cfg.TRAINER.MY_MODEL = CN()\n        cfg.TRAINER.MY_MODEL.PARAM_A = 1.\n        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5\n        cfg.TRAINER.MY_MODEL.PARAM_C = False\n    \"\"\"\n    pass\n\n\ndef setup_cfg(args):\n    cfg = get_cfg_default()\n    extend_cfg(cfg)\n\n    # 1. From the dataset config file\n    if args.dataset_config_file:\n        cfg.merge_from_file(args.dataset_config_file)\n\n    # 2. From the method config file\n    if args.config_file:\n        cfg.merge_from_file(args.config_file)\n\n    # 3. 
From input arguments\n    reset_cfg(cfg, args)\n\n    # 4. From optional input arguments\n    cfg.merge_from_list(args.opts)\n\n    cfg.freeze()\n\n    return cfg\n\n\ndef main(args):\n    cfg = setup_cfg(args)\n    if cfg.SEED >= 0:\n        print(\"Setting fixed seed: {}\".format(cfg.SEED))\n        set_random_seed(cfg.SEED)\n    setup_logger(cfg.OUTPUT_DIR)\n\n    if torch.cuda.is_available() and cfg.USE_CUDA:\n        torch.backends.cudnn.benchmark = True\n\n    print_args(args, cfg)\n    print(\"Collecting env info ...\")\n    print(\"** System info **\\n{}\\n\".format(collect_env_info()))\n\n    trainer = build_trainer(cfg)\n\n    if args.eval_only:\n        trainer.load_model(args.model_dir, epoch=args.load_epoch)\n        trainer.test()\n        return\n\n    if not args.no_train:\n        trainer.train()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--root\", type=str, default=\"\", help=\"path to dataset\")\n    parser.add_argument(\n        \"--output-dir\", type=str, default=\"\", help=\"output directory\"\n    )\n    parser.add_argument(\n        \"--resume\",\n        type=str,\n        default=\"\",\n        help=\"checkpoint directory (from which the training resumes)\",\n    )\n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=-1,\n        help=\"only positive value enables a fixed seed\"\n    )\n    parser.add_argument(\n        \"--source-domains\",\n        type=str,\n        nargs=\"+\",\n        help=\"source domains for DA/DG\"\n    )\n    parser.add_argument(\n        \"--target-domains\",\n        type=str,\n        nargs=\"+\",\n        help=\"target domains for DA/DG\"\n    )\n    parser.add_argument(\n        \"--transforms\", type=str, nargs=\"+\", help=\"data augmentation methods\"\n    )\n    parser.add_argument(\n        \"--config-file\", type=str, default=\"\", help=\"path to config file\"\n    )\n    parser.add_argument(\n        
\"--dataset-config-file\",\n        type=str,\n        default=\"\",\n        help=\"path to config file for dataset setup\",\n    )\n    parser.add_argument(\n        \"--trainer\", type=str, default=\"\", help=\"name of trainer\"\n    )\n    parser.add_argument(\n        \"--backbone\", type=str, default=\"\", help=\"name of CNN backbone\"\n    )\n    parser.add_argument(\"--head\", type=str, default=\"\", help=\"name of head\")\n    parser.add_argument(\n        \"--eval-only\", action=\"store_true\", help=\"evaluation only\"\n    )\n    parser.add_argument(\n        \"--model-dir\",\n        type=str,\n        default=\"\",\n        help=\"load model from this directory for eval-only mode\",\n    )\n    parser.add_argument(\n        \"--load-epoch\",\n        type=int,\n        help=\"load model weights at this epoch for evaluation\"\n    )\n    parser.add_argument(\n        \"--no-train\", action=\"store_true\", help=\"do not call trainer.train()\"\n    )\n    parser.add_argument(\n        \"opts\",\n        default=None,\n        nargs=argparse.REMAINDER,\n        help=\"modify config options using the command-line\",\n    )\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "ProGrad.public/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# Custom\noutput/\ndebug.sh\n"
  },
  {
    "path": "ProGrad.public/DATASETS.md",
    "content": "# How to install datasets\n\nWe suggest putting all datasets under the same folder (say `$DATA`) to ease management, and following the instructions below to organize them so that the source code does not need to be modified. The file structure looks like\n\n```\n$DATA/\n|–– imagenet/\n|–– caltech-101/\n|–– oxford_pets/\n|–– stanford_cars/\n```\n\nIf you already have some datasets installed elsewhere, you can create symbolic links in `$DATA/dataset_name` that point to the original data to avoid duplicate downloads.\n\nDatasets list:\n- [ImageNet](#imagenet)\n- [Caltech101](#caltech101)\n- [OxfordPets](#oxfordpets)\n- [StanfordCars](#stanfordcars)\n- [Flowers102](#flowers102)\n- [Food101](#food101)\n- [FGVCAircraft](#fgvcaircraft)\n- [SUN397](#sun397)\n- [DTD](#dtd)\n- [EuroSAT](#eurosat)\n- [UCF101](#ucf101)\n- [ImageNetV2](#imagenetv2)\n- [ImageNet-Sketch](#imagenet-sketch)\n- [ImageNet-A](#imagenet-a)\n- [ImageNet-R](#imagenet-r)\n\nThe instructions to prepare each dataset are detailed below. To ensure reproducibility and fair comparison for future work, we provide fixed train/val/test splits for all datasets except ImageNet, where the validation set is used as the test set. The fixed splits are either from the original datasets (if available) or created by us.\n\n### ImageNet\n- Create a folder named `imagenet/` under `$DATA`.\n- Create `images/` under `imagenet/`.\n- Download the dataset from the [official website](https://image-net.org/index.php) and extract the training and validation sets to `$DATA/imagenet/images`. 
The directory structure should look like\n```\nimagenet/\n|–– images/\n|   |–– train/ # contains 1,000 folders like n01440764, n01443537, etc.\n|   |–– val/\n```\n- If you have already downloaded the ImageNet dataset, you can create symbolic links to map the training and validation sets to `$DATA/imagenet/images`.\n- Download `classnames.txt` to `$DATA/imagenet/` from this [link](https://drive.google.com/file/d/1-61f_ol79pViBFDG_IDlUQSwoLcn2XXF/view?usp=sharing). The class names are copied from [CLIP](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb).\n\n### Caltech101\n- Create a folder named `caltech-101/` under `$DATA`.\n- Download `101_ObjectCategories.tar.gz` from http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz and extract the file under `$DATA/caltech-101`.\n- Download `split_zhou_Caltech101.json` from this [link](https://drive.google.com/file/d/1hyarUivQE36mY6jSomru6Fjd-JzwcCzN/view?usp=sharing) and put it under `$DATA/caltech-101`.\n\nThe directory structure should look like\n```\ncaltech-101/\n|–– 101_ObjectCategories/\n|–– split_zhou_Caltech101.json\n```\n\n### OxfordPets\n- Create a folder named `oxford_pets/` under `$DATA`.\n- Download the images from https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz.\n- Download the annotations from https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz.\n- Download `split_zhou_OxfordPets.json` from this [link](https://drive.google.com/file/d/1501r8Ber4nNKvmlFVQZ8SeUHTcdTTEqs/view?usp=sharing). 
\n\nThe directory structure should look like\n```\noxford_pets/\n|–– images/\n|–– annotations/\n|–– split_zhou_OxfordPets.json\n```\n\n### StanfordCars\n- Create a folder named `stanford_cars/` under `$DATA`.\n- Download the train images from http://ai.stanford.edu/~jkrause/car196/cars_train.tgz.\n- Download the test images from http://ai.stanford.edu/~jkrause/car196/cars_test.tgz.\n- Download the train labels from https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz.\n- Download the test labels from http://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat.\n- Download `split_zhou_StanfordCars.json` from this [link](https://drive.google.com/file/d/1ObCFbaAgVu0I-k_Au-gIUcefirdAuizT/view?usp=sharing).\n\nThe directory structure should look like\n```\nstanford_cars/\n|–– cars_test/\n|–– cars_test_annos_withlabels.mat\n|–– cars_train/\n|–– devkit/\n|–– split_zhou_StanfordCars.json\n```\n\n### Flowers102\n- Create a folder named `oxford_flowers/` under `$DATA`.\n- Download the images and labels from https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz and https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat respectively.\n- Download `cat_to_name.json` from [here](https://drive.google.com/file/d/1AkcxCXeK_RCGCEC_GvmWxjcjaNhu-at0/view?usp=sharing). 
\n- Download `split_zhou_OxfordFlowers.json` from [here](https://drive.google.com/file/d/1Pp0sRXzZFZq15zVOzKjKBu4A9i01nozT/view?usp=sharing).\n\nThe directory structure should look like\n```\noxford_flowers/\n|–– cat_to_name.json\n|–– imagelabels.mat\n|–– jpg/\n|–– split_zhou_OxfordFlowers.json\n```\n\n### Food101\n- Download the dataset from https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/ and extract the file `food-101.tar.gz` under `$DATA`, resulting in a folder named `$DATA/food-101/`.\n- Download `split_zhou_Food101.json` from [here](https://drive.google.com/file/d/1QK0tGi096I0Ba6kggatX1ee6dJFIcEJl/view?usp=sharing).\n\nThe directory structure should look like\n```\nfood-101/\n|–– images/\n|–– license_agreement.txt\n|–– meta/\n|–– README.txt\n|–– split_zhou_Food101.json\n```\n\n### FGVCAircraft\n- Download the data from https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz.\n- Extract `fgvc-aircraft-2013b.tar.gz` and keep only `data/`.\n- Move `data/` to `$DATA` and rename the folder to `fgvc_aircraft/`.\n\nThe directory structure should look like\n```\nfgvc_aircraft/\n|–– images/\n|–– ... # a bunch of .txt files\n```\n\n### SUN397\n- Create a folder named `sun397/` under `$DATA`.\n- Download the images from http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz.\n- Download the partitions from https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip.\n- Extract these files under `$DATA/sun397/`.\n- Download `split_zhou_SUN397.json` from this [link](https://drive.google.com/file/d/1y2RD81BYuiyvebdN-JymPfyWYcd8_MUq/view?usp=sharing).\n\nThe directory structure should look like\n```\nsun397/\n|–– SUN397/\n|–– split_zhou_SUN397.json\n|–– ... # a bunch of .txt files\n```\n\n### DTD\n- Download the dataset from https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz and extract it to `$DATA`. 
This should result in `$DATA/dtd/`.\n- Download `split_zhou_DescribableTextures.json` from this [link](https://drive.google.com/file/d/1u3_QfB467jqHgNXC00UIzbLZRQCg2S7x/view?usp=sharing).\n\nThe directory structure should look like\n```\ndtd/\n|–– images/\n|–– imdb/\n|–– labels/\n|–– split_zhou_DescribableTextures.json\n```\n\n### EuroSAT\n- Create a folder named `eurosat/` under `$DATA`.\n- Download the dataset from http://madm.dfki.de/files/sentinel/EuroSAT.zip and extract it to `$DATA/eurosat/`.\n- Download `split_zhou_EuroSAT.json` from [here](https://drive.google.com/file/d/1Ip7yaCWFi0eaOFUGga0lUdVi_DDQth1o/view?usp=sharing).\n\nThe directory structure should look like\n```\neurosat/\n|–– 2750/\n|–– split_zhou_EuroSAT.json\n```\n\n### UCF101\n- Create a folder named `ucf101/` under `$DATA`.\n- Download the zip file `UCF-101-midframes.zip` from [here](https://drive.google.com/file/d/10Jqome3vtUA2keJkNanAiFpgbyC9Hc2O/view?usp=sharing) and extract it to `$DATA/ucf101/`. This zip file contains the extracted middle video frames.\n- Download `split_zhou_UCF101.json` from this [link](https://drive.google.com/file/d/1I0S0q91hJfsV9Gf4xDIjgDq4AqBNJb1y/view?usp=sharing).\n\nThe directory structure should look like\n```\nucf101/\n|–– UCF-101-midframes/\n|–– split_zhou_UCF101.json\n```\n\n### ImageNetV2\n- Create a folder named `imagenetv2/` under `$DATA`.\n- Go to the GitHub repo https://github.com/modestyachts/ImageNetV2.\n- Download the matched-frequency dataset from https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz and extract it to `$DATA/imagenetv2/`.\n- Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenetv2/`.\n\nThe directory structure should look like\n```\nimagenetv2/\n|–– imagenetv2-matched-frequency-format-val/\n|–– classnames.txt\n```\n\n### ImageNet-Sketch\n- Download the dataset from https://github.com/HaohanWang/ImageNet-Sketch.\n- Extract the dataset to `$DATA/imagenet-sketch`.\n- Copy `$DATA/imagenet/classnames.txt` 
to `$DATA/imagenet-sketch/`.\n\nThe directory structure should look like\n```\nimagenet-sketch/\n|–– images/ # contains 1,000 folders whose names have the format of n*\n|–– classnames.txt\n```\n\n### ImageNet-A\n- Create a folder named `imagenet-adversarial/` under `$DATA`.\n- Download the dataset from https://github.com/hendrycks/natural-adv-examples and extract it to `$DATA/imagenet-adversarial/`.\n- Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-adversarial/`.\n\nThe directory structure should look like\n```\nimagenet-adversarial/\n|–– imagenet-a/ # contains 200 folders whose names have the format of n*\n|–– classnames.txt\n```\n\n### ImageNet-R\n- Create a folder named `imagenet-rendition/` under `$DATA`.\n- Download the dataset from https://github.com/hendrycks/imagenet-r and extract it to `$DATA/imagenet-rendition/`.\n- Copy `$DATA/imagenet/classnames.txt` to `$DATA/imagenet-rendition/`.\n\nThe directory structure should look like\n```\nimagenet-rendition/\n|–– imagenet-r/ # contains 200 folders whose names have the format of n*\n|–– classnames.txt\n```"
  },
  {
    "path": "ProGrad.public/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 Kaiyang Zhou\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "ProGrad.public/README.md",
    "content": "# How to Run\n\n## GPU memory needed\n\nAll experiments can run on a single graphics card. However, **to get results on ImageNet, the graphics card must have more than 24 GB of memory.** Around 12 GB is enough for the other datasets.\n\n\n## How to Install\nThis code is built on top of the toolbox [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch), which we have modified, so please install the provided Dassl.ProGrad.pytorch instead. Go to the folder Dassl.ProGrad.pytorch provided in the appendix and prepare the environment as follows:\n\n```\n# Create a conda environment\nconda create -n dassl python=3.7\n\n# Activate the environment\nconda activate dassl\n\n# Install dependencies\npip install -r requirements.txt\n\n# Install torch (version >= 1.7.1) and torchvision\n# Make sure to install the GPU version; the CPU version is much slower.\n# For example:\nconda install pytorch torchvision cudatoolkit=10.1 -c pytorch\n\n# Install this library (no need to re-build if the source code is modified)\npython setup.py develop\n```\n\nAfter that, run `pip install -r requirements.txt` under `ProGrad.public/` to install a few more packages required by [CLIP](https://github.com/openai/CLIP) (do this while the `dassl` environment is activated). 
Then, you are ready to go.\n\nFollow [DATASETS.md](DATASETS.md) to install the datasets.\n\n## Few-shot setting on 11 datasets\n\nBasic format:\n```\nbash main.sh ${DATASET_NAME} ${CONFIG_NAME} end ${CONTEXT_TOKENS_NUMBER} ${SHOTS} False\n```\n\nFor example, to run 1, 2, 4, 8, and 16 shots on stanford_cars:\n\n**CLIP + CoOp (M=16, end)**:\n\n- 1 shot: `bash main.sh stanford_cars rn50_ep50 end 16 1 False`\n- 2 shots: `bash main.sh stanford_cars rn50_ep100 end 16 2 False`\n- 4 shots: `bash main.sh stanford_cars rn50_ep100 end 16 4 False`\n- 8 shots: `bash main.sh stanford_cars rn50 end 16 8 False`\n- 16 shots: `bash main.sh stanford_cars rn50 end 16 16 False`\n\n**CLIP + CoOp + ProGrad**:\n\n**Please note that the 8-shot and 16-shot results on Flowers102, DTD, and EuroSAT were obtained with lambda = 0.8.** To reproduce the results in our paper, change the variable LAMBDA in prograd.sh from 1.0 to 0.8.\n\n- 1 shot: `bash prograd.sh stanford_cars rn50_ep50 end 16 1 False`\n- 2 shots: `bash prograd.sh stanford_cars rn50_ep100 end 16 2 False`\n- 4 shots: `bash prograd.sh stanford_cars rn50_ep100 end 16 4 False`\n- 8 shots: `bash prograd.sh stanford_cars rn50 end 16 8 False`\n- 16 shots: `bash prograd.sh stanford_cars rn50 end 16 16 False`\n\nAfter the experiments finish, you can use `parse_test_res.py` to calculate the average results instead of reading the log files manually. Suppose the structure of `output/` is\n\n```\noutput\n|–– caltech101/\n|   |–– CoOp/\n|   |   |–– rn50_16shots/\n|   |   |   |–– nctx16_cscFalse_ctpend/\n|   |   |   |   |–– seed1/\n|   |   |   |   |–– seed2/\n|   |   |   |   |–– seed3/\n|   |   |–– rn50_8shots/\n|   |   |   |–– nctx16_cscFalse_ctpend/\n|   |   |   |   |–– seed1/\n|   |   |   |   |–– seed2/\n|   |   |   |   |–– seed3/\n```\n\nTo calculate the average results for the folder `rn50_16shots/nctx16_cscFalse_ctpend/`, run\n\n```bash\npython parse_test_res.py output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend\n```\n\nYou will then see something like this in your terminal:\n\n```bash\nParsing files in output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend\nfile: 
output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed1/log.txt. accuracy: 91.81%. error: 8.19%.\nfile: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed2/log.txt. accuracy: 92.01%. error: 7.99%.\nfile: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend/seed3/log.txt. accuracy: 92.17%. error: 7.83%.\n===\nSummary of directory: output/caltech101/CoOp/rn50_16shots/nctx16_cscFalse_ctpend\n* accuracy: 92.00% +- 0.15%\n* error: 8.00% +- 0.15%\n===\n```\n\n**How to visualize nearest words for the learned context tokens?** All you need is `interpret_prompt.py`. Suppose the learned tokens are saved in `a/b/c/prompt_learner/model.pth.tar` and you want to see the top-3 nearest words for each token; then run `python interpret_prompt.py a/b/c/prompt_learner/model.pth.tar 3`.\n\n## Robustness to Distribution Shift\nTo reproduce the robustness experiments, load the models learned on ImageNet and evaluate them on the following datasets: `imagenetv2`, `imagenet-sketch`, `imagenet-a`, and `imagenet-r`.\n\nThe command is provided in `scripts/eval.sh`. The key arguments are `--model-dir`, `--load-epoch`, and `--eval-only`. `--model-dir` indicates the directory where the models are saved (i.e., the entire folder containing `log.txt`, the tensorboard file, and `prompt_learner/`). `--load-epoch` tells the code to load the model saved at a specific epoch, e.g., `--load-epoch 50` for ImageNet.\n\nFor example, to evaluate `CLIP + CoOp (M=16, end)` on ImageNetV2, run\n\n```bash\n# No need to use rn50_ep50 here as no training is performed\nbash eval.sh imagenetv2 rn50\n```\n\nTo get the results of our method, change TRAINER in the script to `ProGrad`.\n\nThe default setting is `SHOTS=4`. Feel free to modify the script.\n\nAgain, you can use `parse_test_res.py` to automate the calculation of average performance. 
This time you should append `--test-log`, e.g., `python parse_test_res.py directory --test-log`.\n\n## Zero-Shot CLIP\nSee `CoOp/scripts/zeroshot.sh`.\n\n## Generalization From Base to New Classes\n\nYou will need `base2new_train_main.sh`, `base2new_test_main.sh`, `base2new_train_prograd.sh`, and `base2new_test_prograd.sh`. The scripts with the prefix `base2new_train` train a model on base classes, while the ones with the prefix `base2new_test` evaluate the trained model on new classes. Both kinds of scripts take a single input argument, `DATASET`, which is a dataset name such as `imagenet` or `caltech101`. The valid names are the file names in `CoOp/configs/datasets/`.\n\nThe scripts with the suffix `prograd.sh` are used for our proposed method, while those with the suffix `main.sh` are used for CoOp.\n\nBelow we provide an example of how to train and evaluate the model on Stanford Cars.\n\n```bash\nbash base2new_train_prograd.sh stanford_cars\nbash base2new_test_prograd.sh stanford_cars\n```\n\n**To test on ImageNet, remember to change CFG from \"rn50_ep100\" to \"rn50_ep50\" and LOADEP from 100 to 50 in the corresponding scripts.**\n\nWhen the evaluation is done, you can use `parse_test_res.py` to automatically calculate the average results. 
For instance, after you finish the evaluation using the aforementioned commands, you would get\n\n```\noutput\n|–– base2new/\n|   |–– test_new/\n|   |   |–– stanford_cars/\n|   |   |   |–– shots_16/\n|   |   |   |   |–– CoCoOp/\n|   |   |   |   |   |–– rn50_ep100/\n|   |   |   |   |   |   |–– seed1/\n|   |   |   |   |   |   |–– seed2/\n|   |   |   |   |   |   |–– seed3/\n|   |–– train_base/\n|   |   |–– stanford_cars/\n|   |   |   |–– shots_16/\n|   |   |   |   |–– CoCoOp/\n|   |   |   |   |   |–– rn50_ep100/\n|   |   |   |   |   |   |–– seed1/\n|   |   |   |   |   |   |–– seed2/\n|   |   |   |   |   |   |–– seed3/\n```\n\nThen, to get the average performance on the base classes, run\n\n```bash\npython parse_test_res.py output/base2new/train_base/stanford_cars/shots_16/CoCoOp/rn50_ep100\n```\n\nTo get the average performance on the new classes, run\n\n```bash\npython parse_test_res.py output/base2new/test_new/stanford_cars/shots_16/CoCoOp/rn50_ep100 --test-log\n```\n\n"
  },
  {
    "path": "ProGrad.public/clip/__init__.py",
    "content": "from .clip import *\n"
  },
  {
    "path": "ProGrad.public/clip/clip.py",
    "content": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom typing import Union, List\n\nimport torch\nfrom PIL import Image\nfrom torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\nfrom tqdm import tqdm\n\nfrom .model import build_model\nfrom .simple_tokenizer import SimpleTokenizer as _Tokenizer\n\ntry:\n    from torchvision.transforms import InterpolationMode\n    BICUBIC = InterpolationMode.BICUBIC\nexcept ImportError:\n    BICUBIC = Image.BICUBIC\n\nif torch.__version__.split(\".\") < [\"1\", \"7\", \"1\"]:\n    warnings.warn(\"PyTorch version 1.7.1 or higher is recommended\")\n\n__all__ = [\"available_models\", \"load\", \"tokenize\"]\n_tokenizer = _Tokenizer()\n\n_MODELS = {\n    \"RN50\":\n    \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n    \"RN101\":\n    \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n    \"RN50x4\":\n    \"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\",\n    \"RN50x16\":\n    \"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\",\n    \"ViT-B/32\":\n    \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\",\n    \"ViT-B/16\":\n    \"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\",\n}\n\n\ndef _download(url: str, root: str = os.path.expanduser(\"~/.cache/clip\")):\n    os.makedirs(root, exist_ok=True)\n    filename = os.path.basename(url)\n\n    expected_sha256 = url.split(\"/\")[-2]\n    download_target = os.path.join(root, filename)\n\n    if os.path.exists(download_target) and not os.path.isfile(download_target):\n        raise 
RuntimeError(\n            f\"{download_target} exists and is not a regular file\")\n\n    if os.path.isfile(download_target):\n        if hashlib.sha256(open(download_target,\n                               \"rb\").read()).hexdigest() == expected_sha256:\n            return download_target\n        else:\n            warnings.warn(\n                f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\"\n            )\n\n    with urllib.request.urlopen(url) as source, open(download_target,\n                                                     \"wb\") as output:\n        with tqdm(total=int(source.info().get(\"Content-Length\")),\n                  ncols=80,\n                  unit='iB',\n                  unit_scale=True) as loop:\n            while True:\n                buffer = source.read(8192)\n                if not buffer:\n                    break\n\n                output.write(buffer)\n                loop.update(len(buffer))\n\n    if hashlib.sha256(open(download_target,\n                           \"rb\").read()).hexdigest() != expected_sha256:\n        raise RuntimeError(\n            \"Model has been downloaded but the SHA256 checksum does not match\"\n        )\n\n    return download_target\n\n\ndef _transform(n_px):\n    return Compose([\n        Resize(n_px, interpolation=BICUBIC),\n        CenterCrop(n_px),\n        lambda image: image.convert(\"RGB\"),\n        ToTensor(),\n        Normalize((0.48145466, 0.4578275, 0.40821073),\n                  (0.26862954, 0.26130258, 0.27577711)),\n    ])\n\n\ndef available_models() -> List[str]:\n    \"\"\"Returns the names of available CLIP models\"\"\"\n    return list(_MODELS.keys())\n\n\ndef load(name: str,\n         device: Union[str, torch.device] = \"cuda\"\n         if torch.cuda.is_available() else \"cpu\",\n         jit=False):\n    \"\"\"Load a CLIP model\n\n    Parameters\n    ----------\n    name : str\n        A model name listed by 
`clip.available_models()`, or the path to a model checkpoint containing the state_dict\n\n    device : Union[str, torch.device]\n        The device to put the loaded model\n\n    jit : bool\n        Whether to load the optimized JIT model or more hackable non-JIT model (default).\n\n    Returns\n    -------\n    model : torch.nn.Module\n        The CLIP model\n\n    preprocess : Callable[[PIL.Image], torch.Tensor]\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\n    \"\"\"\n    if name in _MODELS:\n        model_path = _download(_MODELS[name])\n    elif os.path.isfile(name):\n        model_path = name\n    else:\n        raise RuntimeError(\n            f\"Model {name} not found; available models = {available_models()}\")\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path,\n                               map_location=device if jit else \"cpu\").eval()\n        state_dict = None\n    except RuntimeError:\n        # loading saved state dict\n        if jit:\n            warnings.warn(\n                f\"File {model_path} is not a JIT archive. 
Loading as a state dict instead\"\n            )\n            jit = False\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    if not jit:\n        model = build_model(state_dict or model.state_dict()).to(device)\n        if str(device) == \"cpu\":\n            model.float()\n        return model, _transform(model.visual.input_resolution)\n\n    # patch the device names\n    device_holder = torch.jit.trace(\n        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])\n    device_node = [\n        n for n in device_holder.graph.findAllNodes(\"prim::Constant\")\n        if \"Device\" in repr(n)\n    ][-1]\n\n    def patch_device(module):\n        try:\n            graphs = [module.graph] if hasattr(module, \"graph\") else []\n        except RuntimeError:\n            graphs = []\n\n        if hasattr(module, \"forward1\"):\n            graphs.append(module.forward1.graph)\n\n        for graph in graphs:\n            for node in graph.findAllNodes(\"prim::Constant\"):\n                if \"value\" in node.attributeNames() and str(\n                        node[\"value\"]).startswith(\"cuda\"):\n                    node.copyAttributes(device_node)\n\n    model.apply(patch_device)\n    patch_device(model.encode_image)\n    patch_device(model.encode_text)\n\n    # patch dtype to float32 on CPU\n    if str(device) == \"cpu\":\n        float_holder = torch.jit.trace(lambda: torch.ones([]).float(),\n                                       example_inputs=[])\n        float_input = list(float_holder.graph.findNode(\"aten::to\").inputs())[1]\n        float_node = float_input.node()\n\n        def patch_float(module):\n            try:\n                graphs = [module.graph] if hasattr(module, \"graph\") else []\n            except RuntimeError:\n                graphs = []\n\n            if hasattr(module, \"forward1\"):\n                graphs.append(module.forward1.graph)\n\n            for graph in graphs:\n                for node in 
graph.findAllNodes(\"aten::to\"):\n                    inputs = list(node.inputs())\n                    for i in [\n                            1, 2\n                    ]:  # dtype can be the second or third argument to aten::to()\n                        if inputs[i].node()[\"value\"] == 5:\n                            inputs[i].node().copyAttributes(float_node)\n\n        model.apply(patch_float)\n        patch_float(model.encode_image)\n        patch_float(model.encode_text)\n\n        model.float()\n\n    return model, _transform(model.input_resolution.item())\n\n\ndef tokenize(texts: Union[str, List[str]],\n             context_length: int = 77,\n             truncate: bool = False) -> torch.LongTensor:\n    \"\"\"\n    Returns the tokenized representation of given input string(s)\n\n    Parameters\n    ----------\n    texts : Union[str, List[str]]\n        An input string or a list of input strings to tokenize\n\n    context_length : int\n        The context length to use; all CLIP models use 77 as the context length\n\n    truncate: bool\n        Whether to truncate the text in case its encoding is longer than the context length\n\n    Returns\n    -------\n    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]\n    \"\"\"\n    if isinstance(texts, str):\n        texts = [texts]\n\n    sot_token = _tokenizer.encoder[\"<|startoftext|>\"]\n    eot_token = _tokenizer.encoder[\"<|endoftext|>\"]\n    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]\n                  for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        if len(tokens) > context_length:\n            if truncate:\n                tokens = tokens[:context_length]\n                tokens[-1] = eot_token\n            else:\n                raise RuntimeError(\n                    f\"Input {texts[i]} is too long for context length 
{context_length}\"\n                )\n        result[i, :len(tokens)] = torch.tensor(tokens)\n\n    return result\n"
  },
  {
    "path": "ProGrad.public/clip/model.py",
    "content": "from collections import OrderedDict\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n\n        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1\n        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n\n        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n\n        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = None\n        self.stride = stride\n\n        if stride > 1 or inplanes != planes * Bottleneck.expansion:\n            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n            self.downsample = nn.Sequential(\n                OrderedDict([(\"-1\", nn.AvgPool2d(stride)),\n                             (\"0\",\n                              nn.Conv2d(inplanes,\n                                        planes * self.expansion,\n                                        1,\n                                        stride=1,\n                                        bias=False)),\n                             (\"1\", nn.BatchNorm2d(planes * self.expansion))]))\n\n    def forward(self, x: torch.Tensor):\n        identity = x\n\n        out = self.relu(self.bn1(self.conv1(x)))\n        out = self.relu(self.bn2(self.conv2(out)))\n        out = self.avgpool(out)\n        out = self.bn3(self.conv3(out))\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out 
+= identity\n        out = self.relu(out)\n        return out\n\n\nclass AttentionPool2d(nn.Module):\n    def __init__(self,\n                 spacial_dim: int,\n                 embed_dim: int,\n                 num_heads: int,\n                 output_dim: int = None):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(\n            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n        self.num_heads = num_heads\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], x.shape[1],\n                      x.shape[2] * x.shape[3]).permute(2, 0,\n                                                       1)  # NCHW -> (HW)NC\n        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC\n        x, _ = F.multi_head_attention_forward(\n            query=x,\n            key=x,\n            value=x,\n            embed_dim_to_check=x.shape[-1],\n            num_heads=self.num_heads,\n            q_proj_weight=self.q_proj.weight,\n            k_proj_weight=self.k_proj.weight,\n            v_proj_weight=self.v_proj.weight,\n            in_proj_weight=None,\n            in_proj_bias=torch.cat(\n                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n            bias_k=None,\n            bias_v=None,\n            add_zero_attn=False,\n            dropout_p=0,\n            out_proj_weight=self.c_proj.weight,\n            out_proj_bias=self.c_proj.bias,\n            use_separate_proj_weight=True,\n            training=self.training,\n            need_weights=False)\n\n        return x[0]\n\n\nclass ModifiedResNet(nn.Module):\n    \"\"\"\n    A ResNet class that is similar to torchvision's but 
contains the following changes:\n    - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n    - The final pooling layer is a QKV attention instead of an average pool\n    \"\"\"\n    def __init__(self,\n                 layers,\n                 output_dim,\n                 heads,\n                 input_resolution=224,\n                 width=64):\n        super().__init__()\n        self.output_dim = output_dim\n        self.input_resolution = input_resolution\n\n        # the 3-layer stem\n        self.conv1 = nn.Conv2d(3,\n                               width // 2,\n                               kernel_size=3,\n                               stride=2,\n                               padding=1,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(width // 2)\n        self.conv2 = nn.Conv2d(width // 2,\n                               width // 2,\n                               kernel_size=3,\n                               padding=1,\n                               bias=False)\n        self.bn2 = nn.BatchNorm2d(width // 2)\n        self.conv3 = nn.Conv2d(width // 2,\n                               width,\n                               kernel_size=3,\n                               padding=1,\n                               bias=False)\n        self.bn3 = nn.BatchNorm2d(width)\n        self.avgpool = nn.AvgPool2d(2)\n        self.relu = nn.ReLU(inplace=True)\n\n        # residual layers\n        self._inplanes = width  # this is a *mutable* variable used during construction\n        self.layer1 = self._make_layer(width, layers[0])\n        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n\n        embed_dim = width * 32 
 # the ResNet feature dimension\n        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,\n                                        heads, output_dim)\n\n    def _make_layer(self, planes, blocks, stride=1):\n        layers = [Bottleneck(self._inplanes, planes, stride)]\n\n        self._inplanes = planes * Bottleneck.expansion\n        for _ in range(1, blocks):\n            layers.append(Bottleneck(self._inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        def stem(x):\n            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),\n                             (self.conv3, self.bn3)]:\n                x = self.relu(bn(conv(x)))\n            x = self.avgpool(x)\n            return x\n\n        x = x.type(self.conv1.weight.dtype)\n        x = stem(x)\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = self.attnpool(x)\n\n        return x\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16.\"\"\"\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        ret = super().forward(x.type(torch.float32))\n        return ret.type(orig_type)\n\n\nclass QuickGELU(nn.Module):\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(self,\n                 d_model: int,\n                 n_head: int,\n                 attn_mask: torch.Tensor = None):\n        super().__init__()\n\n        self.attn = nn.MultiheadAttention(d_model, n_head)\n        self.ln_1 = LayerNorm(d_model)\n        self.mlp = nn.Sequential(\n            OrderedDict([(\"c_fc\", nn.Linear(d_model, d_model * 4)),\n                         (\"gelu\", QuickGELU()),\n                         (\"c_proj\", nn.Linear(d_model * 4, d_model))]))\n        self.ln_2 = LayerNorm(d_model)\n        self.attn_mask = attn_mask\n\n  
  def attention(self, x: torch.Tensor):\n        self.attn_mask = self.attn_mask.to(\n            dtype=x.dtype,\n            device=x.device) if self.attn_mask is not None else None\n        return self.attn(x, x, x, need_weights=False,\n                         attn_mask=self.attn_mask)[0]\n\n    def forward(self, x: torch.Tensor):\n        x = x + self.attention(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass Transformer(nn.Module):\n    def __init__(self,\n                 width: int,\n                 layers: int,\n                 heads: int,\n                 attn_mask: torch.Tensor = None):\n        super().__init__()\n        self.width = width\n        self.layers = layers\n        self.resblocks = nn.Sequential(*[\n            ResidualAttentionBlock(width, heads, attn_mask)\n            for _ in range(layers)\n        ])\n\n    def forward(self, x: torch.Tensor):\n        return self.resblocks(x)\n\n\nclass VisionTransformer(nn.Module):\n    def __init__(self, input_resolution: int, patch_size: int, width: int,\n                 layers: int, heads: int, output_dim: int):\n        super().__init__()\n        self.input_resolution = input_resolution\n        self.output_dim = output_dim\n        self.conv1 = nn.Conv2d(in_channels=3,\n                               out_channels=width,\n                               kernel_size=patch_size,\n                               stride=patch_size,\n                               bias=False)\n\n        scale = width**-0.5\n        self.class_embedding = nn.Parameter(scale * torch.randn(width))\n        self.positional_embedding = nn.Parameter(scale * torch.randn(\n            (input_resolution // patch_size)**2 + 1, width))\n        self.ln_pre = LayerNorm(width)\n\n        self.transformer = Transformer(width, layers, heads)\n\n        self.ln_post = LayerNorm(width)\n        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))\n\n    def forward(self, x: 
torch.Tensor):\n        x = self.conv1(x)  # shape = [*, width, grid, grid]\n        x = x.reshape(x.shape[0], x.shape[1],\n                      -1)  # shape = [*, width, grid ** 2]\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n        x = torch.cat([\n            self.class_embedding.to(x.dtype) + torch.zeros(\n                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x\n        ],\n                      dim=1)  # shape = [*, grid ** 2 + 1, width]\n        x = x + self.positional_embedding.to(x.dtype)\n        x = self.ln_pre(x)\n\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n\n        x = self.ln_post(x[:, 0, :])\n\n        if self.proj is not None:\n            x = x @ self.proj\n\n        return x\n\n\nclass CLIP(nn.Module):\n    def __init__(\n        self,\n        embed_dim: int,\n        # vision\n        image_resolution: int,\n        vision_layers: Union[Tuple[int, int, int, int], int],\n        vision_width: int,\n        vision_patch_size: int,\n        # text\n        context_length: int,\n        vocab_size: int,\n        transformer_width: int,\n        transformer_heads: int,\n        transformer_layers: int):\n        super().__init__()\n\n        self.context_length = context_length\n\n        if isinstance(vision_layers, (tuple, list)):\n            vision_heads = vision_width * 32 // 64\n            self.visual = ModifiedResNet(layers=vision_layers,\n                                         output_dim=embed_dim,\n                                         heads=vision_heads,\n                                         input_resolution=image_resolution,\n                                         width=vision_width)\n        else:\n            vision_heads = vision_width // 64\n            self.visual = VisionTransformer(input_resolution=image_resolution,\n                                            patch_size=vision_patch_size,\n    
                                        width=vision_width,\n                                            layers=vision_layers,\n                                            heads=vision_heads,\n                                            output_dim=embed_dim)\n\n        self.transformer = Transformer(width=transformer_width,\n                                       layers=transformer_layers,\n                                       heads=transformer_heads,\n                                       attn_mask=self.build_attention_mask())\n\n        self.vocab_size = vocab_size\n        self.token_embedding = nn.Embedding(vocab_size, transformer_width)\n        self.positional_embedding = nn.Parameter(\n            torch.empty(self.context_length, transformer_width))\n        self.ln_final = LayerNorm(transformer_width)\n\n        self.text_projection = nn.Parameter(\n            torch.empty(transformer_width, embed_dim))\n        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n\n        self.initialize_parameters()\n\n    def initialize_parameters(self):\n        nn.init.normal_(self.token_embedding.weight, std=0.02)\n        nn.init.normal_(self.positional_embedding, std=0.01)\n\n        if isinstance(self.visual, ModifiedResNet):\n            if self.visual.attnpool is not None:\n                std = self.visual.attnpool.c_proj.in_features**-0.5\n                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)\n                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)\n\n            for resnet_block in [\n                    self.visual.layer1, self.visual.layer2, self.visual.layer3,\n                    self.visual.layer4\n            ]:\n                for name, param in resnet_block.named_parameters():\n                    if name.endswith(\"bn3.weight\"):\n     
                   nn.init.zeros_(param)\n\n        proj_std = (self.transformer.width**-0.5) * (\n            (2 * self.transformer.layers)**-0.5)\n        attn_std = self.transformer.width**-0.5\n        fc_std = (2 * self.transformer.width)**-0.5\n        for block in self.transformer.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            nn.init.normal_(self.text_projection,\n                            std=self.transformer.width**-0.5)\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    @property\n    def dtype(self):\n        return self.visual.conv1.weight.dtype\n\n    def encode_image(self, image):\n        return self.visual(image.type(self.dtype))\n\n    def encode_text(self, text):\n        x = self.token_embedding(text).type(\n            self.dtype)  # [batch_size, n_ctx, d_model]\n\n        x = x + self.positional_embedding.type(self.dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x).type(self.dtype)\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = x[torch.arange(x.shape[0]),\n              text.argmax(dim=-1)] @ self.text_projection\n\n        return x\n\n    def forward(self, image, text):\n        
image_features = self.encode_image(image)\n        text_features = self.encode_text(text)\n\n        # normalized features\n        image_features = image_features / image_features.norm(dim=-1,\n                                                              keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1,\n                                                           keepdim=True)\n\n        # cosine similarity as logits\n        logit_scale = self.logit_scale.exp()\n        logits_per_image = logit_scale * image_features @ text_features.t()\n        logits_per_text = logit_scale * text_features @ image_features.t()\n\n        # shape = [global_batch_size, global_batch_size]\n        return logits_per_image, logits_per_text\n\n\ndef convert_weights(model: nn.Module):\n    \"\"\"Convert applicable model parameters to fp16\"\"\"\n    def _convert_weights_to_fp16(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.half()\n            if l.bias is not None:\n                l.bias.data = l.bias.data.half()\n\n        if isinstance(l, nn.MultiheadAttention):\n            for attr in [\n                    *[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]],\n                    \"in_proj_bias\", \"bias_k\", \"bias_v\"\n            ]:\n                tensor = getattr(l, attr)\n                if tensor is not None:\n                    tensor.data = tensor.data.half()\n\n        for name in [\"text_projection\", \"proj\"]:\n            if hasattr(l, name):\n                attr = getattr(l, name)\n                if attr is not None:\n                    attr.data = attr.data.half()\n\n    model.apply(_convert_weights_to_fp16)\n\n\ndef build_model(state_dict: dict):\n    vit = \"visual.proj\" in state_dict\n\n    if vit:\n        vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n        vision_layers = len([\n            k for k in state_dict.keys()\n            
if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")\n        ])\n        vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n        grid_size = round(\n            (state_dict[\"visual.positional_embedding\"].shape[0] - 1)**0.5)\n        image_resolution = vision_patch_size * grid_size\n    else:\n        counts: list = [\n            len(\n                set(\n                    k.split(\".\")[2] for k in state_dict\n                    if k.startswith(f\"visual.layer{b}\")))\n            for b in [1, 2, 3, 4]\n        ]\n        vision_layers = tuple(counts)\n        vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n        output_width = round(\n            (state_dict[\"visual.attnpool.positional_embedding\"].shape[0] -\n             1)**0.5)\n        vision_patch_size = None\n        assert output_width**2 + 1 == state_dict[\n            \"visual.attnpool.positional_embedding\"].shape[0]\n        image_resolution = output_width * 32\n\n    embed_dim = state_dict[\"text_projection\"].shape[1]\n    context_length = state_dict[\"positional_embedding\"].shape[0]\n    vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n    transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n    transformer_heads = transformer_width // 64\n    transformer_layers = len(\n        set(\n            k.split(\".\")[2] for k in state_dict\n            if k.startswith(f\"transformer.resblocks\")))\n\n    model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,\n                 vision_patch_size, context_length, vocab_size,\n                 transformer_width, transformer_heads, transformer_layers)\n\n    for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n        if key in state_dict:\n            del state_dict[key]\n\n    convert_weights(model)\n    model.load_state_dict(state_dict)\n    return model.eval()\n"
  },
  {
    "path": "ProGrad.public/clip/simple_tokenizer.py",
    "content": "import gzip\nimport html\nimport os\nfrom functools import lru_cache\n\nimport ftfy\nimport regex as re\n\n\n@lru_cache()\ndef default_bpe():\n    return os.path.join(os.path.dirname(os.path.abspath(__file__)),\n                        \"bpe_simple_vocab_16e6.txt.gz\")\n\n\n@lru_cache()\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a signficant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"),\n                    ord(\"~\") + 1)) + list(range(\n                        ord(\"¡\"),\n                        ord(\"¬\") + 1)) + list(range(ord(\"®\"),\n                                                    ord(\"ÿ\") + 1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8 + n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n    \"\"\"Return set of symbol pairs in a word.\n    Word is represented as tuple of symbols (symbols being variable-length strings).\n    \"\"\"\n    pairs = set()\n    prev_char = word[0]\n    for char in word[1:]:\n        pairs.add((prev_char, char))\n        prev_char = char\n    return pairs\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = re.sub(r'\\s+', ' ', text)\n    text = text.strip()\n    return text\n\n\nclass 
SimpleTokenizer(object):\n    def __init__(self, bpe_path: str = default_bpe()):\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n        merges = gzip.open(bpe_path).read().decode(\"utf-8\").split('\\n')\n        merges = merges[1:49152 - 256 - 2 + 1]\n        merges = [tuple(merge.split()) for merge in merges]\n        vocab = list(bytes_to_unicode().values())\n        vocab = vocab + [v + '</w>' for v in vocab]\n        for merge in merges:\n            vocab.append(''.join(merge))\n        vocab.extend(['<|startoftext|>', '<|endoftext|>'])\n        self.encoder = dict(zip(vocab, range(len(vocab))))\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.bpe_ranks = dict(zip(merges, range(len(merges))))\n        self.cache = {\n            '<|startoftext|>': '<|startoftext|>',\n            '<|endoftext|>': '<|endoftext|>'\n        }\n        self.pat = re.compile(\n            r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\",\n            re.IGNORECASE)\n\n    def bpe(self, token):\n        if token in self.cache:\n            return self.cache[token]\n        word = tuple(token[:-1]) + (token[-1] + '</w>', )\n        pairs = get_pairs(word)\n\n        if not pairs:\n            return token + '</w>'\n\n        while True:\n            bigram = min(\n                pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))\n            if bigram not in self.bpe_ranks:\n                break\n            first, second = bigram\n            new_word = []\n            i = 0\n            while i < len(word):\n                try:\n                    j = word.index(first, i)\n                    new_word.extend(word[i:j])\n                    i = j\n                except:\n                    new_word.extend(word[i:])\n                    break\n\n                if word[i] == first and i < len(word) 
- 1 and word[\n                        i + 1] == second:\n                    new_word.append(first + second)\n                    i += 2\n                else:\n                    new_word.append(word[i])\n                    i += 1\n            new_word = tuple(new_word)\n            word = new_word\n            if len(word) == 1:\n                break\n            else:\n                pairs = get_pairs(word)\n        word = ' '.join(word)\n        self.cache[token] = word\n        return word\n\n    def encode(self, text):\n        bpe_tokens = []\n        text = whitespace_clean(basic_clean(text)).lower()\n        for token in re.findall(self.pat, text):\n            token = ''.join(self.byte_encoder[b]\n                            for b in token.encode('utf-8'))\n            bpe_tokens.extend(self.encoder[bpe_token]\n                              for bpe_token in self.bpe(token).split(' '))\n        return bpe_tokens\n\n    def decode(self, tokens):\n        text = ''.join([self.decoder[token] for token in tokens])\n        text = bytearray([self.byte_decoder[c] for c in text\n                          ]).decode('utf-8',\n                                    errors=\"replace\").replace('</w>', ' ')\n        return text\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/caltech101.yaml",
    "content": "DATASET:\n  NAME: \"Caltech101\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/dtd.yaml",
    "content": "DATASET:\n  NAME: \"DescribableTextures\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/eurosat.yaml",
    "content": "DATASET:\n  NAME: \"EuroSAT\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/fgvc_aircraft.yaml",
    "content": "DATASET:\n  NAME: \"FGVCAircraft\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/food101.yaml",
    "content": "DATASET:\n  NAME: \"Food101\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet.yaml",
    "content": "DATASET:\n  NAME: \"ImageNet\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet_a.yaml",
    "content": "DATASET:\n  NAME: \"ImageNetA\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet_r.yaml",
    "content": "DATASET:\n  NAME: \"ImageNetR\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet_sketch.yaml",
    "content": "DATASET:\n  NAME: \"ImageNetSketch\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenetv2.yaml",
    "content": "DATASET:\n  NAME: \"ImageNetV2\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/oxford_flowers.yaml",
    "content": "DATASET:\n  NAME: \"OxfordFlowers\""
  },
  {
    "path": "ProGrad.public/configs/datasets/oxford_pets.yaml",
    "content": "DATASET:\n  NAME: \"OxfordPets\""
  },
  {
    "path": "ProGrad.public/configs/datasets/stanford_cars.yaml",
    "content": "DATASET:\n  NAME: \"StanfordCars\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/sun397.yaml",
    "content": "DATASET:\n  NAME: \"SUN397\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/ucf101.yaml",
    "content": "DATASET:\n  NAME: \"UCF101\"\n"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/rn50_c4_ep10_batch1_ctxv1.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 10\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 20\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\"\n\nTRAINER:\n  COCOOP:\n    N_CTX: 4\n    CTX_INIT: True\n    PREC: \"fp16\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/rn50_ep100_init.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 100\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 20\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\"\n\nTRAINER:\n  COCOOP:\n    N_CTX: 16\n    CTX_INIT: True\n    PREC: \"fp16\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/rn50_ep50.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 50\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 20\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\"\n\nTRAINER:\n  COCOOP:\n    N_CTX: 16\n    CTX_INIT: True\n    PREC: \"fp16\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 10\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 20\n\nMODEL:\n  BACKBONE:\n    NAME: \"ViT-B/16\"\n\nTRAINER:\n  COCOOP:\n    N_CTX: 16\n    CTX_INIT: \"\"\n    PREC: \"fp16\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 10\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 20\n\nMODEL:\n  BACKBONE:\n    NAME: \"ViT-B/16\"\n\nTRAINER:\n  COCOOP:\n    N_CTX: 4\n    CTX_INIT: \"\"\n    PREC: \"fp16\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 10\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 20\n\nMODEL:\n  BACKBONE:\n    NAME: \"ViT-B/16\"\n\nTRAINER:\n  COCOOP:\n    N_CTX: 4\n    CTX_INIT: \"a photo of a\"\n    PREC: \"fp16\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 10\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 20\n\nMODEL:\n  BACKBONE:\n    NAME: \"ViT-B/16\"\n\nTRAINER:\n  COCOOP:\n    N_CTX: 8\n    CTX_INIT: \"\"\n    PREC: \"fp16\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 200\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50_ep100.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 100\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\"\n"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50_ep50.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 50\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nTRAIN:\n  PRINT_FREQ: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\""
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50_val.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 32\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\""
  },
  {
    "path": "ProGrad.public/configs/trainers/ProGrad/rn50.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 200 \n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nLOSS:\n  NAME: \"prograd\"\n  T: 1.0\n\nTRAIN:\n  PRINT_FREQ: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\"\n\nTRAINER:\n  COOP:\n    CTX_INIT: True\n"
  },
  {
    "path": "ProGrad.public/configs/trainers/ProGrad/rn50_ep100.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 100 \n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nLOSS:\n  NAME: \"prograd\"\n  T: 1.0\n\nTRAIN:\n  PRINT_FREQ: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\"\n\nTRAINER:\n  COOP:\n    CTX_INIT: True\n"
  },
  {
    "path": "ProGrad.public/configs/trainers/ProGrad/rn50_ep50.yaml",
    "content": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTERPOLATION: \"bicubic\"\n  PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]\n  PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]\n  TRANSFORMS: [\"random_resized_crop\", \"random_flip\", \"normalize\"]\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  MAX_EPOCH: 50\n  LR_SCHEDULER: \"cosine\"\n  WARMUP_EPOCH: 1\n  WARMUP_TYPE: \"constant\"\n  WARMUP_CONS_LR: 1e-5\n\nLOSS:\n  NAME: \"prograd\"\n  T: 1.0\n\nTRAIN:\n  PRINT_FREQ: 5\n\nMODEL:\n  BACKBONE:\n    NAME: \"RN50\"\n\nTRAINER:\n  COOP:\n    CTX_INIT: True"
  },
  {
    "path": "ProGrad.public/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "ProGrad.public/datasets/caltech101.py",
    "content": "import os\nimport pickle\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\nfrom .dtd import DescribableTextures as DTD\n\nIGNORED = [\"BACKGROUND_Google\", \"Faces_easy\"]\nNEW_CNAMES = {\n    \"airplanes\": \"airplane\",\n    \"Faces\": \"face\",\n    \"Leopards\": \"leopard\",\n    \"Motorbikes\": \"motorbike\",\n}\n\n\n@DATASET_REGISTRY.register()\nclass Caltech101(DatasetBase):\n\n    dataset_dir = \"caltech-101\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"101_ObjectCategories\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_Caltech101.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.image_dir)\n        else:\n            train, val, test = DTD.read_and_split_data(self.image_dir,\n                                                       ignored=IGNORED,\n                                                       new_cnames=NEW_CNAMES)\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    
f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n"
  },
  {
    "path": "ProGrad.public/datasets/dtd.py",
    "content": "import os\nimport pickle\nimport random\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import listdir_nohidden, mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\n\n\n@DATASET_REGISTRY.register()\nclass DescribableTextures(DatasetBase):\n\n    dataset_dir = \"dtd\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"images\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_DescribableTextures.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.image_dir)\n        else:\n            train, val, test = self.read_and_split_data(self.image_dir)\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      
num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    @staticmethod\n    def read_and_split_data(image_dir,\n                            p_trn=0.5,\n                            p_val=0.2,\n                            ignored=[],\n                            new_cnames=None):\n        # The data are supposed to be organized into the following structure\n        # =============\n        # images/\n        #     dog/\n        #     cat/\n        #     horse/\n        # =============\n        categories = listdir_nohidden(image_dir)\n        categories = [c for c in categories if c not in ignored]\n        categories.sort()\n\n        p_tst = 1 - p_trn - p_val\n        print(\n            f\"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test\"\n        )\n\n        def _collate(ims, y, c):\n            items = []\n            for im in ims:\n                item = Datum(impath=im, label=y,\n                             classname=c)  # is already 0-based\n                items.append(item)\n            return items\n\n        train, val, test = [], [], []\n        for label, category in enumerate(categories):\n            category_dir 
= os.path.join(image_dir, category)\n            images = listdir_nohidden(category_dir)\n            images = [os.path.join(category_dir, im) for im in images]\n            random.shuffle(images)\n            n_total = len(images)\n            n_train = round(n_total * p_trn)\n            n_val = round(n_total * p_val)\n            n_test = n_total - n_train - n_val\n            assert n_train > 0 and n_val > 0 and n_test > 0\n\n            if new_cnames is not None and category in new_cnames:\n                category = new_cnames[category]\n\n            train.extend(_collate(images[:n_train], label, category))\n            val.extend(\n                _collate(images[n_train:n_train + n_val], label, category))\n            test.extend(_collate(images[n_train + n_val:], label, category))\n\n        return train, val, test\n"
  },
  {
    "path": "ProGrad.public/datasets/eurosat.py",
    "content": "import os\nimport pickle\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\nfrom .dtd import DescribableTextures as DTD\n\nNEW_CNAMES = {\n    \"AnnualCrop\": \"Annual Crop Land\",\n    \"Forest\": \"Forest\",\n    \"HerbaceousVegetation\": \"Herbaceous Vegetation Land\",\n    \"Highway\": \"Highway or Road\",\n    \"Industrial\": \"Industrial Buildings\",\n    \"Pasture\": \"Pasture Land\",\n    \"PermanentCrop\": \"Permanent Crop Land\",\n    \"Residential\": \"Residential Buildings\",\n    \"River\": \"River\",\n    \"SeaLake\": \"Sea or Lake\",\n}\n\n\n@DATASET_REGISTRY.register()\nclass EuroSAT(DatasetBase):\n\n    dataset_dir = \"eurosat\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"2750\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_EuroSAT.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.image_dir)\n        else:\n            train, val, test = DTD.read_and_split_data(self.image_dir,\n                                                       new_cnames=NEW_CNAMES)\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        
f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def update_classname(self, dataset_old):\n        dataset_new = []\n        for item_old in dataset_old:\n            cname_old = item_old.classname\n            cname_new = NEW_CNAMES[cname_old]\n            item_new = Datum(impath=item_old.impath,\n                             label=item_old.label,\n                             classname=cname_new)\n            dataset_new.append(item_new)\n        return dataset_new\n"
  },
  {
    "path": "ProGrad.public/datasets/fgvc_aircraft.py",
    "content": "import os\nimport pickle\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\n\n\n@DATASET_REGISTRY.register()\nclass FGVCAircraft(DatasetBase):\n\n    dataset_dir = \"fgvc_aircraft\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"images\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        classnames = []\n        with open(os.path.join(self.dataset_dir, \"variants.txt\"), \"r\") as f:\n            lines = f.readlines()\n            for line in lines:\n                classnames.append(line.strip())\n        cname2lab = {c: i for i, c in enumerate(classnames)}\n\n        train = self.read_data(cname2lab, \"images_variant_train.txt\")\n        val = self.read_data(cname2lab, \"images_variant_val.txt\")\n        test = self.read_data(cname2lab, \"images_variant_test.txt\")\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = 
self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def read_data(self, cname2lab, split_file):\n        filepath = os.path.join(self.dataset_dir, split_file)\n        items = []\n\n        with open(filepath, \"r\") as f:\n            lines = f.readlines()\n            for line in lines:\n                line = line.strip().split(\" \")\n                imname = line[0] + \".jpg\"\n                classname = \" \".join(line[1:])\n                impath = os.path.join(self.image_dir, imname)\n                label = cname2lab[classname]\n                item = Datum(impath=impath, label=label, classname=classname)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/food101.py",
    "content": "import os\nimport pickle\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\nfrom .dtd import DescribableTextures as DTD\n\n\n@DATASET_REGISTRY.register()\nclass Food101(DatasetBase):\n\n    dataset_dir = \"food-101\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"images\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_Food101.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.image_dir)\n        else:\n            train, val, test = DTD.read_and_split_data(self.image_dir)\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      
num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n"
  },
  {
    "path": "ProGrad.public/datasets/imagenet.py",
    "content": "import os\nimport pickle\nfrom collections import OrderedDict\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import listdir_nohidden, mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\n\n\n@DATASET_REGISTRY.register()\nclass ImageNet(DatasetBase):\n\n    dataset_dir = \"imagenet\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"images\")\n        self.preprocessed = os.path.join(self.dataset_dir, \"preprocessed.pkl\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.preprocessed):\n            with open(self.preprocessed, \"rb\") as f:\n                preprocessed = pickle.load(f)\n                train = preprocessed[\"train\"]\n                test = preprocessed[\"test\"]\n        else:\n            text_file = os.path.join(self.dataset_dir, \"classnames.txt\")\n            classnames = self.read_classnames(text_file)\n            train = self.read_data(classnames, \"train\")\n            # Follow standard practice to perform evaluation on the val set\n            # Also used as the val set (so evaluate the last-step model)\n            test = self.read_data(classnames, \"val\")\n\n            preprocessed = {\"train\": train, \"test\": test}\n            with open(self.preprocessed, \"wb\") as f:\n                pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if 
os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train = data[\"train\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                data = {\"train\": train}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, test = OxfordPets.subsample_classes(train,\n                                                   test,\n                                                   subsample=subsample)\n\n        super().__init__(train_x=train, val=test, test=test)\n\n    @staticmethod\n    def read_classnames(text_file):\n        \"\"\"Return a dictionary containing\n        key-value pairs of <folder name>: <class name>.\n        \"\"\"\n        classnames = OrderedDict()\n        with open(text_file, \"r\") as f:\n            lines = f.readlines()\n            for line in lines:\n                line = line.strip().split(\" \")\n                folder = line[0]\n                classname = \" \".join(line[1:])\n                classnames[folder] = classname\n        return classnames\n\n    def read_data(self, classnames, split_dir):\n        split_dir = os.path.join(self.image_dir, split_dir)\n        folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())\n        items = []\n\n        for label, folder in enumerate(folders):\n            imnames = listdir_nohidden(os.path.join(split_dir, folder))\n            classname = classnames[folder]\n            for imname in imnames:\n                impath = 
os.path.join(split_dir, folder, imname)\n                item = Datum(impath=impath, label=label, classname=classname)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/imagenet_a.py",
    "content": "import os\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import listdir_nohidden\n\nfrom .imagenet import ImageNet\n\nTO_BE_IGNORED = [\"README.txt\"]\n\n\n@DATASET_REGISTRY.register()\nclass ImageNetA(DatasetBase):\n    \"\"\"ImageNet-A(dversarial).\n\n    This dataset is used for testing only.\n    \"\"\"\n\n    dataset_dir = \"imagenet-adversarial\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"imagenet-a\")\n\n        text_file = os.path.join(self.dataset_dir, \"classnames.txt\")\n        classnames = ImageNet.read_classnames(text_file)\n\n        data = self.read_data(classnames)\n\n        super().__init__(train_x=data, test=data)\n\n    def read_data(self, classnames):\n        image_dir = self.image_dir\n        folders = listdir_nohidden(image_dir, sort=True)\n        folders = [f for f in folders if f not in TO_BE_IGNORED]\n        items = []\n\n        for label, folder in enumerate(folders):\n            imnames = listdir_nohidden(os.path.join(image_dir, folder))\n            classname = classnames[folder]\n            for imname in imnames:\n                impath = os.path.join(image_dir, folder, imname)\n                item = Datum(impath=impath, label=label, classname=classname)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/imagenet_r.py",
    "content": "import os\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import listdir_nohidden\n\nfrom .imagenet import ImageNet\n\nTO_BE_IGNORED = [\"README.txt\"]\n\n\n@DATASET_REGISTRY.register()\nclass ImageNetR(DatasetBase):\n    \"\"\"ImageNet-R(endition).\n\n    This dataset is used for testing only.\n    \"\"\"\n\n    dataset_dir = \"imagenet-rendition\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"imagenet-r\")\n\n        text_file = os.path.join(self.dataset_dir, \"classnames.txt\")\n        classnames = ImageNet.read_classnames(text_file)\n\n        data = self.read_data(classnames)\n\n        super().__init__(train_x=data, test=data)\n\n    def read_data(self, classnames):\n        image_dir = self.image_dir\n        folders = listdir_nohidden(image_dir, sort=True)\n        folders = [f for f in folders if f not in TO_BE_IGNORED]\n        items = []\n\n        for label, folder in enumerate(folders):\n            imnames = listdir_nohidden(os.path.join(image_dir, folder))\n            classname = classnames[folder]\n            for imname in imnames:\n                impath = os.path.join(image_dir, folder, imname)\n                item = Datum(impath=impath, label=label, classname=classname)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/imagenet_sketch.py",
    "content": "import os\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import listdir_nohidden\n\nfrom .imagenet import ImageNet\n\n\n@DATASET_REGISTRY.register()\nclass ImageNetSketch(DatasetBase):\n    \"\"\"ImageNet-Sketch.\n\n    This dataset is used for testing only.\n    \"\"\"\n\n    dataset_dir = \"imagenet-sketch\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"images\")\n\n        text_file = os.path.join(self.dataset_dir, \"classnames.txt\")\n        classnames = ImageNet.read_classnames(text_file)\n\n        data = self.read_data(classnames)\n\n        super().__init__(train_x=data, test=data)\n\n    def read_data(self, classnames):\n        image_dir = self.image_dir\n        folders = listdir_nohidden(image_dir, sort=True)\n        items = []\n\n        for label, folder in enumerate(folders):\n            imnames = listdir_nohidden(os.path.join(image_dir, folder))\n            classname = classnames[folder]\n            for imname in imnames:\n                impath = os.path.join(image_dir, folder, imname)\n                item = Datum(impath=impath, label=label, classname=classname)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/imagenetv2.py",
    "content": "import os\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import listdir_nohidden\n\nfrom .imagenet import ImageNet\n\n\n@DATASET_REGISTRY.register()\nclass ImageNetV2(DatasetBase):\n    \"\"\"ImageNetV2.\n\n    This dataset is used for testing only.\n    \"\"\"\n\n    dataset_dir = \"imagenetv2\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        image_dir = \"imagenetv2-matched-frequency-format-val\"\n        self.image_dir = os.path.join(self.dataset_dir, image_dir)\n\n        text_file = os.path.join(self.dataset_dir, \"classnames.txt\")\n        classnames = ImageNet.read_classnames(text_file)\n\n        data = self.read_data(classnames)\n\n        super().__init__(train_x=data, test=data)\n\n    def read_data(self, classnames):\n        image_dir = self.image_dir\n        folders = list(classnames.keys())\n        items = []\n\n        for label in range(1000):\n            class_dir = os.path.join(image_dir, str(label))\n            imnames = listdir_nohidden(class_dir)\n            folder = folders[label]\n            classname = classnames[folder]\n            for imname in imnames:\n                impath = os.path.join(class_dir, imname)\n                item = Datum(impath=impath, label=label, classname=classname)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/oxford_flowers.py",
    "content": "import os\nimport pickle\nimport random\nfrom scipy.io import loadmat\nfrom collections import defaultdict\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import read_json, mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\n\n\n@DATASET_REGISTRY.register()\nclass OxfordFlowers(DatasetBase):\n\n    dataset_dir = \"oxford_flowers\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"jpg\")\n        self.label_file = os.path.join(self.dataset_dir, \"imagelabels.mat\")\n        self.lab2cname_file = os.path.join(self.dataset_dir,\n                                           \"cat_to_name.json\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_OxfordFlowers.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.image_dir)\n        else:\n            train, val, test = self.read_data()\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                 
   data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def read_data(self):\n        tracker = defaultdict(list)\n        label_file = loadmat(self.label_file)[\"labels\"][0]\n        for i, label in enumerate(label_file):\n            imname = f\"image_{str(i + 1).zfill(5)}.jpg\"\n            impath = os.path.join(self.image_dir, imname)\n            label = int(label)\n            tracker[label].append(impath)\n\n        print(\"Splitting data into 50% train, 20% val, and 30% test\")\n\n        def _collate(ims, y, c):\n            items = []\n            for im in ims:\n                item = Datum(impath=im, label=y - 1,\n                             classname=c)  # convert to 0-based label\n                items.append(item)\n            return items\n\n        lab2cname = read_json(self.lab2cname_file)\n        train, val, test = [], [], []\n        for label, impaths in tracker.items():\n            
random.shuffle(impaths)\n            n_total = len(impaths)\n            n_train = round(n_total * 0.5)\n            n_val = round(n_total * 0.2)\n            n_test = n_total - n_train - n_val\n            assert n_train > 0 and n_val > 0 and n_test > 0\n            cname = lab2cname[str(label)]\n            train.extend(_collate(impaths[:n_train], label, cname))\n            val.extend(_collate(impaths[n_train:n_train + n_val], label,\n                                cname))\n            test.extend(_collate(impaths[n_train + n_val:], label, cname))\n\n        return train, val, test\n"
  },
  {
    "path": "ProGrad.public/datasets/oxford_pets.py",
    "content": "import os\nimport pickle\nimport math\nimport random\nfrom collections import defaultdict\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import read_json, write_json, mkdir_if_missing\n\n\n@DATASET_REGISTRY.register()\nclass OxfordPets(DatasetBase):\n\n    dataset_dir = \"oxford_pets\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"images\")\n        self.anno_dir = os.path.join(self.dataset_dir, \"annotations\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_OxfordPets.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = self.read_split(self.split_path, self.image_dir)\n        else:\n            trainval = self.read_data(split_file=\"trainval.txt\")\n            test = self.read_data(split_file=\"test.txt\")\n            train, val = self.split_trainval(trainval)\n            self.save_split(train, val, test, self.split_path, self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = 
self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = self.subsample_classes(train,\n                                                  val,\n                                                  test,\n                                                  subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def read_data(self, split_file):\n        filepath = os.path.join(self.anno_dir, split_file)\n        items = []\n\n        with open(filepath, \"r\") as f:\n            lines = f.readlines()\n            for line in lines:\n                line = line.strip()\n                imname, label, species, _ = line.split(\" \")\n                breed = imname.split(\"_\")[:-1]\n                breed = \"_\".join(breed)\n                breed = breed.lower()\n                imname += \".jpg\"\n                impath = os.path.join(self.image_dir, imname)\n                label = int(label) - 1  # convert to 0-based index\n                item = Datum(impath=impath, label=label, classname=breed)\n                items.append(item)\n\n        return items\n\n    @staticmethod\n    def split_trainval(trainval, p_val=0.2):\n        p_trn = 1 - p_val\n        print(f\"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val\")\n        tracker = defaultdict(list)\n        for idx, item in enumerate(trainval):\n            label = 
item.label\n            tracker[label].append(idx)\n\n        train, val = [], []\n        for label, idxs in tracker.items():\n            n_val = round(len(idxs) * p_val)\n            assert n_val > 0\n            random.shuffle(idxs)\n            for n, idx in enumerate(idxs):\n                item = trainval[idx]\n                if n < n_val:\n                    val.append(item)\n                else:\n                    train.append(item)\n\n        return train, val\n\n    @staticmethod\n    def save_split(train, val, test, filepath, path_prefix):\n        def _extract(items):\n            out = []\n            for item in items:\n                impath = item.impath\n                label = item.label\n                classname = item.classname\n                impath = impath.replace(path_prefix, \"\")\n                if impath.startswith(\"/\"):\n                    impath = impath[1:]\n                out.append((impath, label, classname))\n            return out\n\n        train = _extract(train)\n        val = _extract(val)\n        test = _extract(test)\n\n        split = {\"train\": train, \"val\": val, \"test\": test}\n\n        write_json(split, filepath)\n        print(f\"Saved split to {filepath}\")\n\n    @staticmethod\n    def read_split(filepath, path_prefix):\n        def _convert(items):\n            out = []\n            for impath, label, classname in items:\n                impath = os.path.join(path_prefix, impath)\n                item = Datum(impath=impath,\n                             label=int(label),\n                             classname=classname)\n                out.append(item)\n            return out\n\n        print(f\"Reading split from {filepath}\")\n        split = read_json(filepath)\n        train = _convert(split[\"train\"])\n        val = _convert(split[\"val\"])\n        test = _convert(split[\"test\"])\n\n        return train, val, test\n\n    @staticmethod\n    def subsample_classes(*args, subsample=\"all\"):\n 
       \"\"\"Divide classes into two groups. The first group\n        represents base classes while the second group represents\n        new classes.\n\n        Args:\n            args: a list of datasets, e.g. train, val and test.\n            subsample (str): what classes to subsample.\n        \"\"\"\n        assert subsample in [\"all\", \"base\", \"new\"]\n\n        if subsample == \"all\":\n            return args\n\n        dataset = args[0]\n        labels = set()\n        for item in dataset:\n            labels.add(item.label)\n        labels = list(labels)\n        labels.sort()\n        n = len(labels)\n        # Divide classes into two halves\n        m = math.ceil(n / 2)\n\n        print(f\"SUBSAMPLE {subsample.upper()} CLASSES!\")\n        if subsample == \"base\":\n            selected = labels[:m]  # take the first half\n        else:\n            selected = labels[m:]  # take the second half\n        relabeler = {y: y_new for y_new, y in enumerate(selected)}\n\n        output = []\n        for dataset in args:\n            dataset_new = []\n            for item in dataset:\n                if item.label not in selected:\n                    continue\n                item_new = Datum(impath=item.impath,\n                                 label=relabeler[item.label],\n                                 classname=item.classname)\n                dataset_new.append(item_new)\n            output.append(dataset_new)\n\n        return output\n"
  },
  {
    "path": "ProGrad.public/datasets/stanford_cars.py",
    "content": "import os\nimport pickle\nfrom scipy.io import loadmat\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\n\n\n@DATASET_REGISTRY.register()\nclass StanfordCars(DatasetBase):\n\n    dataset_dir = \"stanford_cars\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_StanfordCars.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.dataset_dir)\n        else:\n            trainval_file = os.path.join(self.dataset_dir, \"devkit\",\n                                         \"cars_train_annos.mat\")\n            test_file = os.path.join(self.dataset_dir,\n                                     \"cars_test_annos_withlabels.mat\")\n            meta_file = os.path.join(self.dataset_dir, \"devkit\",\n                                     \"cars_meta.mat\")\n            trainval = self.read_data(\"cars_train\", trainval_file, meta_file)\n            test = self.read_data(\"cars_test\", test_file, meta_file)\n            train, val = OxfordPets.split_trainval(trainval)\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.dataset_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n 
           if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def read_data(self, image_dir, anno_file, meta_file):\n        anno_file = loadmat(anno_file)[\"annotations\"][0]\n        meta_file = loadmat(meta_file)[\"class_names\"][0]\n        items = []\n\n        for i in range(len(anno_file)):\n            imname = anno_file[i][\"fname\"][0]\n            impath = os.path.join(self.dataset_dir, image_dir, imname)\n            label = anno_file[i][\"class\"][0, 0]\n            label = int(label) - 1  # convert to 0-based index\n            classname = meta_file[label][0]\n            names = classname.split(\" \")\n            year = names.pop(-1)\n            names.insert(0, year)\n            classname = \" 
\".join(names)\n            item = Datum(impath=impath, label=label, classname=classname)\n            items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/sun397.py",
    "content": "import os\nimport pickle\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\n\n\n@DATASET_REGISTRY.register()\nclass SUN397(DatasetBase):\n\n    dataset_dir = \"sun397\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"SUN397\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_SUN397.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.image_dir)\n        else:\n            classnames = []\n            with open(os.path.join(self.dataset_dir, \"ClassName.txt\"),\n                      \"r\") as f:\n                lines = f.readlines()\n                for line in lines:\n                    line = line.strip()[1:]  # remove /\n                    classnames.append(line)\n            cname2lab = {c: i for i, c in enumerate(classnames)}\n            trainval = self.read_data(cname2lab, \"Training_01.txt\")\n            test = self.read_data(cname2lab, \"Testing_01.txt\")\n            train, val = OxfordPets.split_trainval(trainval)\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            
if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def read_data(self, cname2lab, text_file):\n        text_file = os.path.join(self.dataset_dir, text_file)\n        items = []\n\n        with open(text_file, \"r\") as f:\n            lines = f.readlines()\n            for line in lines:\n                imname = line.strip()[1:]  # remove /\n                classname = os.path.dirname(imname)\n                label = cname2lab[classname]\n                impath = os.path.join(self.image_dir, imname)\n\n                names = classname.split(\"/\")[1:]  # remove 1st letter\n                names = names[::-1]  # put words like indoor/outdoor at first\n                classname = \" \".join(names)\n\n             
   item = Datum(impath=impath, label=label, classname=classname)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/datasets/ucf101.py",
    "content": "import os\nimport pickle\nimport re\n\nfrom dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase\nfrom dassl.utils import mkdir_if_missing\n\nfrom .oxford_pets import OxfordPets\n\n\n@DATASET_REGISTRY.register()\nclass UCF101(DatasetBase):\n\n    dataset_dir = \"ucf101\"\n\n    def __init__(self, cfg):\n        root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))\n        self.dataset_dir = os.path.join(root, self.dataset_dir)\n        self.image_dir = os.path.join(self.dataset_dir, \"UCF-101-midframes\")\n        self.split_path = os.path.join(self.dataset_dir,\n                                       \"split_zhou_UCF101.json\")\n        self.split_fewshot_dir = os.path.join(self.dataset_dir,\n                                              \"split_fewshot\")\n        mkdir_if_missing(self.split_fewshot_dir)\n\n        if os.path.exists(self.split_path):\n            train, val, test = OxfordPets.read_split(self.split_path,\n                                                     self.image_dir)\n        else:\n            cname2lab = {}\n            filepath = os.path.join(self.dataset_dir,\n                                    \"ucfTrainTestlist/classInd.txt\")\n            with open(filepath, \"r\") as f:\n                lines = f.readlines()\n                for line in lines:\n                    label, classname = line.strip().split(\" \")\n                    label = int(label) - 1  # convert to 0-based index\n                    cname2lab[classname] = label\n\n            trainval = self.read_data(cname2lab,\n                                      \"ucfTrainTestlist/trainlist01.txt\")\n            test = self.read_data(cname2lab, \"ucfTrainTestlist/testlist01.txt\")\n            train, val = OxfordPets.split_trainval(trainval)\n            OxfordPets.save_split(train, val, test, self.split_path,\n                                  self.image_dir)\n\n        num_shots = cfg.DATASET.NUM_SHOTS\n        if num_shots >= 1:\n            
seed = cfg.SEED\n            preprocessed = os.path.join(self.split_fewshot_dir,\n                                        f\"shot_{num_shots}-seed_{seed}.pkl\")\n\n            if os.path.exists(preprocessed):\n                print(\n                    f\"Loading preprocessed few-shot data from {preprocessed}\")\n                with open(preprocessed, \"rb\") as file:\n                    data = pickle.load(file)\n                    train, val = data[\"train\"], data[\"val\"]\n            else:\n                train = self.generate_fewshot_dataset(train,\n                                                      num_shots=num_shots)\n                val = self.generate_fewshot_dataset(val,\n                                                    num_shots=min(\n                                                        num_shots, 4))\n                data = {\"train\": train, \"val\": val}\n                print(f\"Saving preprocessed few-shot data to {preprocessed}\")\n                with open(preprocessed, \"wb\") as file:\n                    pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n        subsample = cfg.DATASET.SUBSAMPLE_CLASSES\n        train, val, test = OxfordPets.subsample_classes(train,\n                                                        val,\n                                                        test,\n                                                        subsample=subsample)\n\n        super().__init__(train_x=train, val=val, test=test)\n\n    def read_data(self, cname2lab, text_file):\n        text_file = os.path.join(self.dataset_dir, text_file)\n        items = []\n\n        with open(text_file, \"r\") as f:\n            lines = f.readlines()\n            for line in lines:\n                line = line.strip().split(\" \")[0]  # trainlist: filename, label\n                action, filename = line.split(\"/\")\n                label = cname2lab[action]\n\n                elements = re.findall(\"[A-Z][^A-Z]*\", action)\n          
      renamed_action = \"_\".join(elements)\n\n                filename = filename.replace(\".avi\", \".jpg\")\n                impath = os.path.join(self.image_dir, renamed_action, filename)\n\n                item = Datum(impath=impath,\n                             label=label,\n                             classname=renamed_action)\n                items.append(item)\n\n        return items\n"
  },
  {
    "path": "ProGrad.public/interpret_prompt.py",
    "content": "import os\nimport sys\nimport argparse\nimport torch\n\nfrom clip.simple_tokenizer import SimpleTokenizer\nfrom clip import clip\n\n\ndef load_clip_to_cpu(backbone_name=\"RN50\"):\n    url = clip._MODELS[backbone_name]\n    model_path = clip._download(url)\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=\"cpu\").eval()\n        state_dict = None\n\n    except RuntimeError:\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    model = clip.build_model(state_dict or model.state_dict())\n\n    return model\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"fpath\", type=str, help=\"Path to the learned prompt\")\nparser.add_argument(\"topk\", type=int, help=\"Select top-k similar words\")\nargs = parser.parse_args()\n\nfpath = args.fpath\ntopk = args.topk\n\nassert os.path.exists(fpath)\n\nprint(f\"Return the top-{topk} matched words\")\n\ntokenizer = SimpleTokenizer()\nclip_model = load_clip_to_cpu()\ntoken_embedding = clip_model.token_embedding.weight\nprint(f\"Size of token embedding: {token_embedding.shape}\")\n\nprompt_learner = torch.load(fpath, map_location=\"cpu\")[\"state_dict\"]\nctx = prompt_learner[\"ctx\"]\nctx = ctx.float()\nprint(f\"Size of context: {ctx.shape}\")\n\nif ctx.dim() == 2:\n    # Generic context\n    distance = torch.cdist(ctx, token_embedding)\n    print(f\"Size of distance matrix: {distance.shape}\")\n    sorted_idxs = torch.argsort(distance, dim=1)\n    sorted_idxs = sorted_idxs[:, :topk]\n\n    for m, idxs in enumerate(sorted_idxs):\n        words = [tokenizer.decoder[idx.item()] for idx in idxs]\n        dist = [f\"{distance[m, idx].item():.4f}\" for idx in idxs]\n        print(f\"{m+1}: {words} {dist}\")\n\nelif ctx.dim() == 3:\n    # Class-specific context\n    raise NotImplementedError\n"
  },
  {
    "path": "ProGrad.public/lpclip/README.md",
    "content": "# Linear Probe CLIP\n\nTo run linear probe baselines, make sure that your current working directory is `lpclip/`.\n\nStep 1: Extract Features using the CLIP Image Encoder\n```bash\nsh feat_extractor.sh\n```\n\nStep 2: Train few-shot linear probe\n```bash\nsh linear_probe.sh\n```\n\nWe follow the instructions stated in the Appendix A3 (pp.38) of [the original CLIP paper](https://arxiv.org/pdf/2103.00020.pdf), with a careful hyperparameter sweep.\n\nNote: please pull the latest Dassl (version >= `606a2c6`).\n"
  },
  {
    "path": "ProGrad.public/lpclip/feat_extractor.py",
    "content": "import os, argparse\nimport numpy as np\nimport torch\nimport sys\n\nsys.path.append(os.path.abspath(\"..\"))\n\nfrom datasets.oxford_pets import OxfordPets\nfrom datasets.oxford_flowers import OxfordFlowers\nfrom datasets.fgvc_aircraft import FGVCAircraft\nfrom datasets.dtd import DescribableTextures\nfrom datasets.eurosat import EuroSAT\nfrom datasets.stanford_cars import StanfordCars\nfrom datasets.food101 import Food101\nfrom datasets.sun397 import SUN397\nfrom datasets.caltech101 import Caltech101\nfrom datasets.ucf101 import UCF101\nfrom datasets.imagenet import ImageNet\nfrom datasets.imagenetv2 import ImageNetV2\nfrom datasets.imagenet_sketch import ImageNetSketch\nfrom datasets.imagenet_a import ImageNetA\nfrom datasets.imagenet_r import ImageNetR\n\nfrom dassl.utils import setup_logger, set_random_seed, collect_env_info\nfrom dassl.config import get_cfg_default\nfrom dassl.data.transforms import build_transform\nfrom dassl.data import DatasetWrapper\n\nimport clip\n\n# import pdb; pdb.set_trace()\n\n\ndef print_args(args, cfg):\n    print(\"***************\")\n    print(\"** Arguments **\")\n    print(\"***************\")\n    optkeys = list(args.__dict__.keys())\n    optkeys.sort()\n    for key in optkeys:\n        print(\"{}: {}\".format(key, args.__dict__[key]))\n    print(\"************\")\n    print(\"** Config **\")\n    print(\"************\")\n    print(cfg)\n\n\ndef reset_cfg(cfg, args):\n    if args.root:\n        cfg.DATASET.ROOT = args.root\n\n    if args.output_dir:\n        cfg.OUTPUT_DIR = args.output_dir\n\n    if args.trainer:\n        cfg.TRAINER.NAME = args.trainer\n\n    if args.backbone:\n        cfg.MODEL.BACKBONE.NAME = args.backbone\n\n    if args.head:\n        cfg.MODEL.HEAD.NAME = args.head\n\n\ndef extend_cfg(cfg):\n    \"\"\"\n    Add new config variables.\n\n    E.g.\n        from yacs.config import CfgNode as CN\n        cfg.TRAINER.MY_MODEL = CN()\n        cfg.TRAINER.MY_MODEL.PARAM_A = 1.\n        
cfg.TRAINER.MY_MODEL.PARAM_B = 0.5\n        cfg.TRAINER.MY_MODEL.PARAM_C = False\n    \"\"\"\n    from yacs.config import CfgNode as CN\n\n    cfg.TRAINER.OURS = CN()\n    cfg.TRAINER.OURS.N_CTX = 10  # number of context vectors\n    cfg.TRAINER.OURS.CSC = False  # class-specific context\n    cfg.TRAINER.OURS.CTX_INIT = \"\"  # initialize context vectors with given words\n    cfg.TRAINER.OURS.WEIGHT_U = 0.1  # weight for the unsupervised loss\n    cfg.DATASET.SUBSAMPLE_CLASSES = \"all\"  # all, base or new\n\n\ndef setup_cfg(args):\n    cfg = get_cfg_default()\n    extend_cfg(cfg)\n\n    # 1. From the dataset config file\n    if args.dataset_config_file:\n        cfg.merge_from_file(args.dataset_config_file)\n\n    # 2. From the method config file\n    if args.config_file:\n        cfg.merge_from_file(args.config_file)\n\n    # 3. From input arguments\n    reset_cfg(cfg, args)\n\n    cfg.freeze()\n\n    return cfg\n\n\ndef main(args):\n    cfg = setup_cfg(args)\n    if cfg.SEED >= 0:\n        print(\"Setting fixed seed: {}\".format(cfg.SEED))\n        set_random_seed(cfg.SEED)\n    setup_logger(cfg.OUTPUT_DIR)\n\n    if torch.cuda.is_available() and cfg.USE_CUDA:\n        torch.backends.cudnn.benchmark = True\n\n    print_args(args, cfg)\n    print(\"Collecting env info ...\")\n    print(\"** System info **\\n{}\\n\".format(collect_env_info()))\n\n    ######################################\n    #   Setup DataLoader\n    ######################################\n    dataset = eval(cfg.DATASET.NAME)(cfg)\n\n    if args.split == \"train\":\n        dataset_input = dataset.train_x\n    elif args.split == \"val\":\n        dataset_input = dataset.val\n    else:\n        dataset_input = dataset.test\n\n    tfm_train = build_transform(cfg, is_train=False)\n    data_loader = torch.utils.data.DataLoader(\n        DatasetWrapper(cfg, dataset_input, transform=tfm_train,\n                       is_train=False),\n        batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,\n        
sampler=None,\n        shuffle=False,\n        num_workers=cfg.DATALOADER.NUM_WORKERS,\n        drop_last=False,\n        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),\n    )\n\n    ########################################\n    #   Setup Network\n    ########################################\n    clip_model, _ = clip.load(\"RN50\", \"cuda\", jit=False)\n    clip_model.eval()\n    ###################################################################################################################\n    # Start Feature Extractor\n    feature_list = []\n    label_list = []\n    train_dataiter = iter(data_loader)\n    for train_step in range(1, len(train_dataiter) + 1):\n        batch = next(train_dataiter)\n        data = batch[\"img\"].cuda()\n        feature = clip_model.visual(data)\n        feature = feature.cpu()\n        for idx in range(len(data)):\n            feature_list.append(feature[idx].tolist())\n        label_list.extend(batch[\"label\"].tolist())\n    save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.DATASET.NAME)\n    os.makedirs(save_dir, exist_ok=True)\n    save_filename = f\"{args.split}\"\n    np.savez(\n        os.path.join(save_dir, save_filename),\n        feature_list=feature_list,\n        label_list=label_list,\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--root\", type=str, default=\"\", help=\"path to dataset\")\n    parser.add_argument(\"--output-dir\",\n                        type=str,\n                        default=\"\",\n                        help=\"output directory\")\n    parser.add_argument(\"--config-file\",\n                        type=str,\n                        default=\"\",\n                        help=\"path to config file\")\n    parser.add_argument(\n        \"--dataset-config-file\",\n        type=str,\n        default=\"\",\n        help=\"path to config file for dataset setup\",\n    )\n    parser.add_argument(\"--num-shot\",\n                       
 type=int,\n                        default=1,\n                        help=\"number of shots\")\n    parser.add_argument(\"--split\",\n                        type=str,\n                        choices=[\"train\", \"val\", \"test\"],\n                        help=\"which split\")\n    parser.add_argument(\"--trainer\",\n                        type=str,\n                        default=\"\",\n                        help=\"name of trainer\")\n    parser.add_argument(\"--backbone\",\n                        type=str,\n                        default=\"\",\n                        help=\"name of CNN backbone\")\n    parser.add_argument(\"--head\", type=str, default=\"\", help=\"name of head\")\n    parser.add_argument(\"--seed\",\n                        type=int,\n                        default=-1,\n                        help=\"only positive value enables a fixed seed\")\n    parser.add_argument(\"--eval-only\",\n                        action=\"store_true\",\n                        help=\"evaluation only\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "ProGrad.public/lpclip/feat_extractor.sh",
    "content": "# sh feat_extractor.sh\nDATA=/data1/CoOpData\nOUTPUT='/data1/CoOpData/clip_feat/'\nSEED=1\n\nGPULIST=(0 1 2 3)\nGPUIDX=0\n\n# oxford_pets oxford_flowers fgvc_aircraft dtd eurosat stanford_cars food101 sun397 caltech101 ucf101 imagenet\n# imagenet oxford_pets oxford_flowers stanford_cars food101 caltech101\nfor DATASET in imagenetv2 imagenet_sketch imagenet_a imagenet_r\ndo\n    for SPLIT in train val test\n    do\n        while true \n        do \n            sleep 10\n            let STATIDX=GPULIST[GPUIDX]+2\n            stat=$(gpustat | awk '{print $11}' | sed -n ${STATIDX}'p')\n            if [ \"$stat\" -lt 20 ]\n            then\n                break\n            fi \n            let GPUIDX=(GPUIDX+1)%${#GPULIST[@]}\n            echo $GPUIDX'N'\n        done\n        CUDA_VISIBLE_DEVICES=${GPULIST[${GPUIDX}]} python feat_extractor.py \\\n        --split ${SPLIT} \\\n        --root ${DATA} \\\n        --seed ${SEED} \\\n        --dataset-config-file ../configs/datasets/${DATASET}.yaml \\\n        --config-file ../configs/trainers/CoOp/rn50_val.yaml \\\n        --output-dir ${OUTPUT} \\\n        --eval-only &\n        sleep 10\n    done\ndone\n"
  },
  {
    "path": "ProGrad.public/lpclip/linear_probe.py",
    "content": "import numpy as np\nimport os\nfrom sklearn.linear_model import LogisticRegression\nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--dataset\", type=str, default=\"\", help=\"path to dataset\")\nparser.add_argument(\"--num_step\", type=int, default=8, help=\"number of steps\")\nparser.add_argument(\"--num_run\", type=int, default=10, help=\"number of runs\")\nparser.add_argument(\"--feature_dir\",\n                    type=str,\n                    default=\"clip_feat\",\n                    help=\"feature dir path\")\nargs = parser.parse_args()\n\ndataset = args.dataset\ndataset_path = os.path.join(f\"{args.feature_dir}\", dataset)\n\ntrain_file = np.load(os.path.join(dataset_path, \"train.npz\"))\ntrain_feature, train_label = train_file[\"feature_list\"], train_file[\n    \"label_list\"]\nval_file = np.load(os.path.join(dataset_path, \"val.npz\"))\nval_feature, val_label = val_file[\"feature_list\"], val_file[\"label_list\"]\ntest_file = np.load(os.path.join(dataset_path, \"test.npz\"))\ntest_feature, test_label = test_file[\"feature_list\"], test_file[\"label_list\"]\n\nos.makedirs(\"report\", exist_ok=True)\n\nval_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}\n\n# for num_shot in [1, 2, 4, 8, 16]:\nfor num_shot in [4, 16]:\n    test_acc_step_list = np.zeros([args.num_run, args.num_step])\n    for seed in range(1, args.num_run + 1):\n        np.random.seed(seed)\n        print(\n            f\"-- Seed: {seed} --------------------------------------------------------------\"\n        )\n        # Sampling\n        all_label_list = np.unique(train_label)\n        selected_idx_list = []\n        for label in all_label_list:\n            label_collection = np.where(train_label == label)[0]\n            selected_idx = np.random.choice(label_collection,\n                                            size=num_shot,\n                                            replace=False)\n            selected_idx_list.extend(selected_idx)\n\n 
       fewshot_train_feature = train_feature[selected_idx_list]\n        fewshot_train_label = train_label[selected_idx_list]\n\n        val_num_shot = val_shot_list[num_shot]\n        val_selected_idx_list = []\n        for label in all_label_list:\n            label_collection = np.where(val_label == label)[0]\n            selected_idx = np.random.choice(label_collection,\n                                            size=val_num_shot,\n                                            replace=False)\n            val_selected_idx_list.extend(selected_idx)\n\n        fewshot_val_feature = val_feature[val_selected_idx_list]\n        fewshot_val_label = val_label[val_selected_idx_list]\n\n        # search initialization\n        search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]\n        acc_list = []\n        for c_weight in search_list:\n            clf = LogisticRegression(solver=\"lbfgs\",\n                                     max_iter=1000,\n                                     penalty=\"l2\",\n                                     C=c_weight).fit(fewshot_train_feature,\n                                                     fewshot_train_label)\n            pred = clf.predict(fewshot_val_feature)\n            acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)\n            acc_list.append(acc_val)\n\n        print(acc_list, flush=True)\n\n        # binary search\n        peak_idx = np.argmax(acc_list)\n        c_peak = search_list[peak_idx]\n        c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak\n\n        def binary_search(c_left, c_right, seed, step, test_acc_step_list):\n            clf_left = LogisticRegression(solver=\"lbfgs\",\n                                          max_iter=1000,\n                                          penalty=\"l2\",\n                                          C=c_left).fit(\n                                              fewshot_train_feature,\n                                              fewshot_train_label)\n            
pred_left = clf_left.predict(fewshot_val_feature)\n            acc_left = sum(\n                pred_left == fewshot_val_label) / len(fewshot_val_label)\n            print(\"Val accuracy (Left): {:.2f}\".format(100 * acc_left),\n                  flush=True)\n\n            clf_right = LogisticRegression(solver=\"lbfgs\",\n                                           max_iter=1000,\n                                           penalty=\"l2\",\n                                           C=c_right).fit(\n                                               fewshot_train_feature,\n                                               fewshot_train_label)\n            pred_right = clf_right.predict(fewshot_val_feature)\n            acc_right = sum(\n                pred_right == fewshot_val_label) / len(fewshot_val_label)\n            print(\"Val accuracy (Right): {:.2f}\".format(100 * acc_right),\n                  flush=True)\n\n            # find maximum and update ranges\n            if acc_left < acc_right:\n                c_final = c_right\n                clf_final = clf_right\n                # range for the next step\n                c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))\n                c_right = np.log10(c_right)\n            else:\n                c_final = c_left\n                clf_final = clf_left\n                # range for the next step\n                c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))\n                c_left = np.log10(c_left)\n\n            pred = clf_final.predict(test_feature)\n            test_acc = 100 * sum(pred == test_label) / len(pred)\n            print(\"Test Accuracy: {:.2f}\".format(test_acc), flush=True)\n            test_acc_step_list[seed - 1, step] = test_acc\n\n            saveline = \"{}, seed {}, {} shot, weight {}, test_acc {:.2f}\\n\".format(\n                dataset, seed, num_shot, c_final, test_acc)\n            with open(\n                    \"./report/{}_s{}r{}_details.txt\".format(\n                  
      'clip_feat', args.num_step, args.num_run),\n                    \"a+\",\n            ) as writer:\n                writer.write(saveline)\n            return (\n                np.power(10, c_left),\n                np.power(10, c_right),\n                seed,\n                step,\n                test_acc_step_list,\n            )\n\n        for step in range(args.num_step):\n            print(\n                f\"{dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}\",\n                flush=True,\n            )\n            c_left, c_right, seed, step, test_acc_step_list = binary_search(\n                c_left, c_right, seed, step, test_acc_step_list)\n    # save results of last step\n    test_acc_list = test_acc_step_list[:, -1]\n    acc_mean = np.mean(test_acc_list)\n    acc_std = np.std(test_acc_list)\n    save_line = \"{}, {} Shot, Test acc stat: {:.2f} ({:.2f})\\n\".format(\n        dataset, num_shot, acc_mean, acc_std)\n    print(save_line, flush=True)\n    with open(\n            \"./report/{}_s{}r{}.txt\".format('clip_feat', args.num_step,\n                                            args.num_run),\n            \"a+\",\n    ) as writer:\n        writer.write(save_line)\n"
  },
  {
    "path": "ProGrad.public/lpclip/linear_probe.sh",
    "content": "feature_dir=/data1/CoOpData/clip_feat/\n# ImageNet OxfordPets OxfordFlowers StanfordCars Food101 Caltech101\nfor DATASET in ImageNet\ndo\n    python linear_probe.py \\\n    --dataset ${DATASET} \\\n    --feature_dir ${feature_dir} \\\n    --num_step 8 \\\n    --num_run 3\ndone\n"
  },
  {
    "path": "ProGrad.public/lpclip/linear_probe_transfer.py",
    "content": "import numpy as np\nimport os\nfrom sklearn.linear_model import LogisticRegression\nimport argparse\n\nparser = argparse.ArgumentParser()\n# parser.add_argument(\"--train_dataset\",\n#                     type=str,\n#                     default=\"\",\n#                     help=\"path to train dataset\")\n# parser.add_argument(\"--test_dataset\",\n#                     type=str,\n#                     default=\"\",\n#                     help=\"path to test dataset\")\nparser.add_argument(\"--num_step\", type=int, default=8, help=\"number of steps\")\nparser.add_argument(\"--num_run\", type=int, default=10, help=\"number of runs\")\nparser.add_argument(\"--feature_dir\",\n                    type=str,\n                    default=\"/data1/CoOpData/clip_feat/\",\n                    help=\"feature dir path\")\nargs = parser.parse_args()\n\ntrain_dataset = 'ImageNet'\ntrain_dataset_path = os.path.join(f\"{args.feature_dir}\", train_dataset)\ntest_datasets = ['ImageNetV2', 'ImageNetSketch', 'ImageNetR', 'ImageNetA']\ntest_dataset_paths = [\n    os.path.join(f\"{args.feature_dir}\", test_dataset)\n    for test_dataset in test_datasets\n]\n\ntrain_file = np.load(os.path.join(train_dataset_path, \"train.npz\"))\ntrain_feature, train_label = train_file[\"feature_list\"], train_file[\n    \"label_list\"]\nval_file = np.load(os.path.join(train_dataset_path, \"val.npz\"))\nval_feature, val_label = val_file[\"feature_list\"], val_file[\"label_list\"]\n\ntest_files = [\n    np.load(os.path.join(test_dataset_path, \"test.npz\"))\n    for test_dataset_path in test_dataset_paths\n]\ntest_features, test_labels = [\n    test_file[\"feature_list\"] for test_file in test_files\n], [test_file[\"label_list\"] for test_file in test_files]\n\nos.makedirs(\"report\", exist_ok=True)\n\nval_shot_list = {1: 1, 2: 2, 4: 4, 8: 4, 16: 4}\n\n# for num_shot in [1, 2, 4, 8, 16]:\nfor num_shot in [16]:\n    test_acc_step_list = np.zeros(\n        [len(test_datasets), args.num_run, 
args.num_step])\n    for seed in range(1, args.num_run + 1):\n        np.random.seed(seed)\n        print(\n            f\"-- Seed: {seed} --------------------------------------------------------------\"\n        )\n        # Sampling\n        all_label_list = np.unique(train_label)\n        selected_idx_list = []\n        for label in all_label_list:\n            label_collection = np.where(train_label == label)[0]\n            selected_idx = np.random.choice(label_collection,\n                                            size=num_shot,\n                                            replace=False)\n            selected_idx_list.extend(selected_idx)\n\n        fewshot_train_feature = train_feature[selected_idx_list]\n        fewshot_train_label = train_label[selected_idx_list]\n\n        val_num_shot = val_shot_list[num_shot]\n        val_selected_idx_list = []\n        for label in all_label_list:\n            label_collection = np.where(val_label == label)[0]\n            selected_idx = np.random.choice(label_collection,\n                                            size=val_num_shot,\n                                            replace=False)\n            val_selected_idx_list.extend(selected_idx)\n\n        fewshot_val_feature = val_feature[val_selected_idx_list]\n        fewshot_val_label = val_label[val_selected_idx_list]\n\n        # search initialization\n        search_list = [1e6, 1e4, 1e2, 1, 1e-2, 1e-4, 1e-6]\n        acc_list = []\n        for c_weight in search_list:\n            clf = LogisticRegression(solver=\"lbfgs\",\n                                     max_iter=1000,\n                                     penalty=\"l2\",\n                                     C=c_weight).fit(fewshot_train_feature,\n                                                     fewshot_train_label)\n            pred = clf.predict(fewshot_val_feature)\n            acc_val = sum(pred == fewshot_val_label) / len(fewshot_val_label)\n            acc_list.append(acc_val)\n\n        
print(acc_list, flush=True)\n\n        # binary search\n        peak_idx = np.argmax(acc_list)\n        c_peak = search_list[peak_idx]\n        c_left, c_right = 1e-1 * c_peak, 1e1 * c_peak\n\n        def binary_search(c_left, c_right, seed, step, test_acc_step_list):\n            clf_left = LogisticRegression(solver=\"lbfgs\",\n                                          max_iter=1000,\n                                          penalty=\"l2\",\n                                          C=c_left).fit(\n                                              fewshot_train_feature,\n                                              fewshot_train_label)\n            pred_left = clf_left.predict(fewshot_val_feature)\n            acc_left = sum(\n                pred_left == fewshot_val_label) / len(fewshot_val_label)\n            print(\"Val accuracy (Left): {:.2f}\".format(100 * acc_left),\n                  flush=True)\n\n            clf_right = LogisticRegression(solver=\"lbfgs\",\n                                           max_iter=1000,\n                                           penalty=\"l2\",\n                                           C=c_right).fit(\n                                               fewshot_train_feature,\n                                               fewshot_train_label)\n            pred_right = clf_right.predict(fewshot_val_feature)\n            acc_right = sum(\n                pred_right == fewshot_val_label) / len(fewshot_val_label)\n            print(\"Val accuracy (Right): {:.2f}\".format(100 * acc_right),\n                  flush=True)\n\n            # find maximum and update ranges\n            if acc_left < acc_right:\n                c_final = c_right\n                clf_final = clf_right\n                # range for the next step\n                c_left = 0.5 * (np.log10(c_right) + np.log10(c_left))\n                c_right = np.log10(c_right)\n            else:\n                c_final = c_left\n                clf_final = clf_left\n             
   # range for the next step\n                c_right = 0.5 * (np.log10(c_right) + np.log10(c_left))\n                c_left = np.log10(c_left)\n\n            for i, (test_feature, test_label, test_dataset) in enumerate(\n                    zip(test_features, test_labels, test_datasets)):\n                pred = clf_final.predict(test_feature)\n                test_acc = 100 * sum(pred == test_label) / len(pred)\n                print(\"Test Accuracy: {:.2f}\".format(test_acc), flush=True)\n                test_acc_step_list[i, seed - 1, step] = test_acc\n\n                saveline = \"{}, {}, seed {}, {} shot, weight {}, test_acc {:.2f}\\n\".format(\n                    train_dataset, test_dataset, seed, num_shot, c_final,\n                    test_acc)\n                with open(\n                        \"./report/{}_s{}r{}_details.txt\".format(\n                            'clip_feat', args.num_step, args.num_run),\n                        \"a+\",\n                ) as writer:\n                    writer.write(saveline)\n            return (\n                np.power(10, c_left),\n                np.power(10, c_right),\n                seed,\n                step,\n                test_acc_step_list,\n            )\n\n        for step in range(args.num_step):\n            print(\n                f\"{train_dataset}, {num_shot} Shot, Round {step}: {c_left}/{c_right}\",\n                flush=True,\n            )\n            c_left, c_right, seed, step, test_acc_step_list = binary_search(\n                c_left, c_right, seed, step, test_acc_step_list)\n    # save results of last step (mean and std over seeds, per test dataset)\n    test_acc_list = test_acc_step_list[:, :, -1]\n    acc_mean = np.mean(test_acc_list, axis=-1)\n    acc_std = np.std(test_acc_list, axis=-1)\n    for i in range(len(test_datasets)):\n        save_line = \"{}, {}, {} Shot, Test acc stat: {:.2f} ({:.2f})\\n\".format(\n            train_dataset, test_datasets[i], num_shot, acc_mean[i], acc_std[i])\n        print(save_line, flush=True)\n
        with open(\n                \"./report/{}_s{}r{}.txt\".format('clip_feat', args.num_step,\n                                                args.num_run),\n                \"a+\",\n        ) as writer:\n            writer.write(save_line)\n"
  },
  {
    "path": "ProGrad.public/parse_test_res.py",
    "content": "\"\"\"\nGoal\n---\n1. Read test results from log.txt files\n2. Compute mean and std across different folders (seeds)\n\nUsage\n---\nAssume the output files are saved under output/my_experiment,\nwhich contains results of different seeds, e.g.,\n\nmy_experiment/\n    seed1/\n        log.txt\n    seed2/\n        log.txt\n    seed3/\n        log.txt\n\nRun the following command from the root directory:\n\n$ python tools/parse_test_res.py output/my_experiment\n\nAdd --ci95 to the argument if you wanna get 95% confidence\ninterval instead of standard deviation:\n\n$ python tools/parse_test_res.py output/my_experiment --ci95\n\nIf my_experiment/ has the following structure,\n\nmy_experiment/\n    exp-1/\n        seed1/\n            log.txt\n            ...\n        seed2/\n            log.txt\n            ...\n        seed3/\n            log.txt\n            ...\n    exp-2/\n        ...\n    exp-3/\n        ...\n\nRun\n\n$ python tools/parse_test_res.py output/my_experiment --multi-exp\n\"\"\"\nimport re\nimport numpy as np\nimport os.path as osp\nimport argparse\nfrom collections import OrderedDict, defaultdict\n\nfrom dassl.utils import check_isfile, listdir_nohidden\n\n\ndef compute_ci95(res):\n    return 1.96 * np.std(res) / np.sqrt(len(res))\n\n\ndef parse_function(*metrics, directory=\"\", args=None, end_signal=None):\n    print(f\"Parsing files in {directory}\")\n    subdirs = listdir_nohidden(directory, sort=True)\n\n    outputs = []\n\n    for subdir in subdirs:\n        fpath = osp.join(directory, subdir, \"log.txt\")\n        assert check_isfile(fpath)\n        good_to_go = False\n        output = OrderedDict()\n\n        with open(fpath, \"r\") as f:\n            lines = f.readlines()\n\n            for line in lines:\n                line = line.strip()\n\n                if line == end_signal:\n                    good_to_go = True\n\n                for metric in metrics:\n                    match = metric[\"regex\"].search(line)\n         
           if match and good_to_go:\n                        if \"file\" not in output:\n                            output[\"file\"] = fpath\n                        num = float(match.group(1))\n                        name = metric[\"name\"]\n                        output[name] = num\n\n        if output:\n            outputs.append(output)\n\n    assert len(outputs) > 0, f\"Nothing found in {directory}\"\n\n    metrics_results = defaultdict(list)\n\n    for output in outputs:\n        msg = \"\"\n        for key, value in output.items():\n            if isinstance(value, float):\n                msg += f\"{key}: {value:.2f}%. \"\n            else:\n                msg += f\"{key}: {value}. \"\n            if key != \"file\":\n                metrics_results[key].append(value)\n        print(msg)\n\n    output_results = OrderedDict()\n\n    print(\"===\")\n    print(f\"Summary of directory: {directory}\")\n    for key, values in metrics_results.items():\n        avg = np.mean(values)\n        std = compute_ci95(values) if args.ci95 else np.std(values)\n        print(f\"* {key}: {avg:.2f}% +- {std:.2f}%\")\n        output_results[key] = avg\n    print(\"===\")\n\n    return output_results\n\n\ndef main(args, end_signal):\n    metric = {\n        \"name\": args.keyword,\n        \"regex\": re.compile(fr\"\\* {args.keyword}: ([\\.\\deE+-]+)%\"),\n    }\n\n    if args.multi_exp:\n        final_results = defaultdict(list)\n\n        for directory in listdir_nohidden(args.directory, sort=True):\n            directory = osp.join(args.directory, directory)\n            results = parse_function(metric,\n                                     directory=directory,\n                                     args=args,\n                                     end_signal=end_signal)\n\n            for key, value in results.items():\n                final_results[key].append(value)\n\n        print(\"Average performance\")\n        for key, values in final_results.items():\n            
avg = np.mean(values)\n            print(f\"* {key}: {avg:.2f}%\")\n\n    else:\n        parse_function(metric,\n                       directory=args.directory,\n                       args=args,\n                       end_signal=end_signal)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"directory\", type=str, help=\"path to directory\")\n    parser.add_argument(\"--ci95\",\n                        action=\"store_true\",\n                        help=\"compute 95%% confidence interval\")\n    parser.add_argument(\"--test-log\",\n                        action=\"store_true\",\n                        help=\"parse test-only logs\")\n    parser.add_argument(\"--multi-exp\",\n                        action=\"store_true\",\n                        help=\"parse multiple experiments\")\n    parser.add_argument(\"--keyword\",\n                        default=\"accuracy\",\n                        type=str,\n                        help=\"which keyword to extract\")\n    args = parser.parse_args()\n\n    end_signal = \"Finished training\"\n    if args.test_log:\n        end_signal = \"=> result\"\n\n    main(args, end_signal)\n"
  },
  {
    "path": "ProGrad.public/requirements.txt",
    "content": "ftfy\nregex\ntqdm\n"
  },
  {
    "path": "ProGrad.public/scripts/base2new_test_main.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/data1/CoOpData/\nTRAINER=CoOp\n\nDATASET=$1\nCFG=rn50_ep100  # config file\nCTP=end  # class token position (end or middle)\nNCTX=16  # number of context tokens\nSHOTS=4  # number of shots (1, 2, 4, 8, 16)\nCSC=False  # class-specific context (False or True)\n\nLOADEP=100\nSUB=new\n\nfor SEED in 1 2 3\ndo\n    COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}\n    MODEL_DIR=output/base2new/train_base/${COMMON_DIR}\n    DIR=output/base2new/test_${SUB}/${COMMON_DIR}\n    if [ -d \"$DIR\" ]; then\n        echo \"Results are available in ${DIR}. Skip this job\"\n    else\n        echo \"Run this job and save the output to ${DIR}\"\n        python train.py \\\n        --root ${DATA} \\\n        --seed ${SEED} \\\n        --trainer ${TRAINER} \\\n        --dataset-config-file configs/datasets/${DATASET}.yaml \\\n        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \\\n        --output-dir ${DIR} \\\n        --model-dir ${MODEL_DIR} \\\n        --load-epoch ${LOADEP} \\\n        --eval-only \\\n        TRAINER.COOP.N_CTX ${NCTX} \\\n        TRAINER.COOP.CSC ${CSC} \\\n        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \\\n        DATASET.NUM_SHOTS ${SHOTS} \\\n        DATASET.SUBSAMPLE_CLASSES ${SUB}\n    fi\ndone\n"
  },
  {
    "path": "ProGrad.public/scripts/base2new_test_prograd.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/data1/CoOpData/\nTRAINER=ProGrad\n\nDATASET=$1\nCFG=rn50_ep100  # config file\nCTP=end  # class token position (end or middle)\nNCTX=16  # number of context tokens\nSHOTS=4  # number of shots (1, 2, 4, 8, 16)\nCSC=False  # class-specific context (False or True)\nLAMBDA=1.0\n\nLOADEP=100\nSUB=new\n\nfor SEED in 1 2 3\ndo\n    COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}\n    MODEL_DIR=output/base2new/train_base/${COMMON_DIR}\n    DIR=output/base2new/test_${SUB}/${COMMON_DIR}\n    if [ -d \"$DIR\" ]; then\n        echo \"Results are available in ${DIR}. Skip this job\"\n    else\n        echo \"Run this job and save the output to ${DIR}\"\n        python train.py \\\n        --root ${DATA} \\\n        --seed ${SEED} \\\n        --trainer ${TRAINER} \\\n        --dataset-config-file configs/datasets/${DATASET}.yaml \\\n        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \\\n        --output-dir ${DIR} \\\n        --model-dir ${MODEL_DIR} \\\n        --load-epoch ${LOADEP} \\\n        --eval-only \\\n        LOSS.LAMBDA ${LAMBDA} \\\n        TRAINER.COOP.N_CTX ${NCTX} \\\n        TRAINER.COOP.CSC ${CSC} \\\n        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \\\n        DATASET.NUM_SHOTS ${SHOTS} \\\n        DATASET.SUBSAMPLE_CLASSES ${SUB}\n    fi\ndone\n"
  },
  {
    "path": "ProGrad.public/scripts/base2new_train_main.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/data1/CoOpData/\nTRAINER=CoOp\n\nDATASET=$1\nCFG=rn50_ep100  # config file\nCTP=end  # class token position (end or middle)\nNCTX=16  # number of context tokens\nSHOTS=4  # number of shots (1, 2, 4, 8, 16)\nCSC=False  # class-specific context (False or True)\n\nfor SEED in 1 2 3\ndo\n    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}\n    if [ -d \"$DIR\" ]; then\n        echo \"Results are available in ${DIR}. Skip this job\"\n    else\n        echo \"Run this job and save the output to ${DIR}\"\n        python train.py \\\n        --root ${DATA} \\\n        --seed ${SEED} \\\n        --trainer ${TRAINER} \\\n        --dataset-config-file configs/datasets/${DATASET}.yaml \\\n        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \\\n        --output-dir ${DIR} \\\n        TRAINER.COOP.N_CTX ${NCTX} \\\n        TRAINER.COOP.CSC ${CSC} \\\n        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \\\n        DATASET.NUM_SHOTS ${SHOTS} \\\n        DATASET.SUBSAMPLE_CLASSES base\n    fi\ndone\n"
  },
  {
    "path": "ProGrad.public/scripts/base2new_train_prograd.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/data1/CoOpData/\nTRAINER=ProGrad\n\nDATASET=$1\nCFG=rn50_ep100  # config file\nCTP=end  # class token position (end or middle)\nNCTX=16  # number of context tokens\nSHOTS=4  # number of shots (1, 2, 4, 8, 16)\nCSC=False  # class-specific context (False or True)\nLAMBDA=1.0\n\nfor SEED in 1 2 3\ndo\n    DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}\n    if [ -d \"$DIR\" ]; then\n        echo \"Results are available in ${DIR}. Skip this job\"\n    else\n        echo \"Run this job and save the output to ${DIR}\"\n        python train.py \\\n        --root ${DATA} \\\n        --seed ${SEED} \\\n        --trainer ${TRAINER} \\\n        --dataset-config-file configs/datasets/${DATASET}.yaml \\\n        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \\\n        --output-dir ${DIR} \\\n        LOSS.LAMBDA ${LAMBDA} \\\n        TRAINER.COOP.N_CTX ${NCTX} \\\n        TRAINER.COOP.CSC ${CSC} \\\n        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \\\n        DATASET.NUM_SHOTS ${SHOTS} \\\n        DATASET.SUBSAMPLE_CLASSES base\n    fi\ndone\n"
  },
  {
    "path": "ProGrad.public/scripts/eval.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/path/to/datasets\nTRAINER=CoOp\nSHOTS=4\nNCTX=16\nCSC=False\nCTP=end\n\nDATASET=$1\nCFG=$2\n\nfor SEED in 1 2 3\ndo\n    python train.py \\\n    --root ${DATA} \\\n    --seed ${SEED} \\\n    --trainer ${TRAINER} \\\n    --dataset-config-file configs/datasets/${DATASET}.yaml \\\n    --config-file configs/trainers/${TRAINER}/${CFG}.yaml \\\n    --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \\\n    --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \\\n    --load-epoch 50 \\\n    --eval-only \\\n    TRAINER.COOP.N_CTX ${NCTX} \\\n    TRAINER.COOP.CSC ${CSC} \\\n    TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP}\ndone"
  },
  {
    "path": "ProGrad.public/scripts/main.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/data1/CoOpData/\nTRAINER=CoOp\n\nDATASET=$1\nCFG=$2  # config file\nCTP=$3  # class token position (end or middle)\nNCTX=$4  # number of context tokens\nSHOTS=$5  # number of shots (1, 2, 4, 8, 16)\nCSC=$6  # class-specific context (False or True)\n\nfor SEED in 1 2 3\ndo\n    DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}\n    if [ -d \"$DIR\" ]; then\n        echo \"Results are available in ${DIR}. Skip this job\"\n    else\n        echo \"Run this job and save the output to ${DIR}\"\n        python train.py \\\n        --root ${DATA} \\\n        --seed ${SEED} \\\n        --trainer ${TRAINER} \\\n        --dataset-config-file configs/datasets/${DATASET}.yaml \\\n        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \\\n        --output-dir ${DIR} \\\n        TRAINER.COOP.N_CTX ${NCTX} \\\n        TRAINER.COOP.CSC ${CSC} \\\n        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \\\n        DATASET.NUM_SHOTS ${SHOTS}\n    fi\ndone\n"
  },
  {
    "path": "ProGrad.public/scripts/prograd.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/data1/CoOpData/\nTRAINER=ProGrad\n\nDATASET=$1\nCFG=$2  # config file\nCTP=$3  # class token position (end or middle)\nNCTX=$4  # number of context tokens\nSHOTS=$5  # number of shots (1, 2, 4, 8, 16)\nCSC=$6  # class-specific context (False or True)\nLAMBDA=1.0\n\nfor SEED in 1 2 3\ndo\n    DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}\n    if [ -d \"$DIR\" ]; then\n        echo \"Results are available in ${DIR}. Skip this job\"\n    else\n        echo \"Run this job and save the output to ${DIR}\"\n        python train.py \\\n        --root ${DATA} \\\n        --seed ${SEED} \\\n        --trainer ${TRAINER} \\\n        --dataset-config-file configs/datasets/${DATASET}.yaml \\\n        --config-file configs/trainers/${TRAINER}/${CFG}.yaml \\\n        --output-dir ${DIR} \\\n        LOSS.LAMBDA ${LAMBDA} \\\n        TRAINER.COOP.N_CTX ${NCTX} \\\n        TRAINER.COOP.CSC ${CSC} \\\n        TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \\\n        DATASET.NUM_SHOTS ${SHOTS}\n    fi\ndone\n"
  },
  {
    "path": "ProGrad.public/scripts/zeroshot.sh",
    "content": "#!/bin/bash\n\ncd ..\n\n# custom config\nDATA=/data1/CoOpData\nTRAINER=ZeroshotCLIP\nDATASET=$1\nCFG=$2  # rn50, rn101, vit_b32 or vit_b16\n\npython train.py \\\n--root ${DATA} \\\n--trainer ${TRAINER} \\\n--dataset-config-file configs/datasets/${DATASET}.yaml \\\n--config-file configs/trainers/CoOp/${CFG}.yaml \\\n--output-dir output/${TRAINER}/${CFG}/${DATASET} \\\n--eval-only"
  },
  {
    "path": "ProGrad.public/train.py",
    "content": "import argparse\nimport torch\nimport time\nimport os\n\nfrom dassl.utils import setup_logger, set_random_seed, collect_env_info\nfrom dassl.config import get_cfg_default\nfrom dassl.engine import build_trainer\n\n# custom\nimport datasets.oxford_pets\nimport datasets.oxford_flowers\nimport datasets.fgvc_aircraft\nimport datasets.dtd\nimport datasets.eurosat\nimport datasets.stanford_cars\nimport datasets.food101\nimport datasets.sun397\nimport datasets.caltech101\nimport datasets.ucf101\nimport datasets.imagenet\n\nimport datasets.imagenet_sketch\nimport datasets.imagenetv2\nimport datasets.imagenet_a\nimport datasets.imagenet_r\n\nimport trainers.coop\nimport trainers.cocoop\nimport trainers.zsclip\nimport trainers.prograd\n\n\ndef print_args(args, cfg):\n    print(\"***************\")\n    print(\"** Arguments **\")\n    print(\"***************\")\n    optkeys = list(args.__dict__.keys())\n    optkeys.sort()\n    for key in optkeys:\n        print(\"{}: {}\".format(key, args.__dict__[key]))\n    print(\"************\")\n    print(\"** Config **\")\n    print(\"************\")\n    print(cfg)\n\n\ndef reset_cfg(cfg, args):\n    if args.root:\n        cfg.DATASET.ROOT = args.root\n\n    if args.output_dir:\n        cfg.OUTPUT_DIR = args.output_dir\n\n    if args.resume:\n        cfg.RESUME = args.resume\n\n    if args.seed:\n        cfg.SEED = args.seed\n\n    if args.source_domains:\n        cfg.DATASET.SOURCE_DOMAINS = args.source_domains\n\n    if args.target_domains:\n        cfg.DATASET.TARGET_DOMAINS = args.target_domains\n\n    if args.transforms:\n        cfg.INPUT.TRANSFORMS = args.transforms\n\n    if args.trainer:\n        cfg.TRAINER.NAME = args.trainer\n\n    if args.backbone:\n        cfg.MODEL.BACKBONE.NAME = args.backbone\n\n    if args.head:\n        cfg.MODEL.HEAD.NAME = args.head\n\n\ndef extend_cfg(cfg):\n    \"\"\"\n    Add new config variables.\n\n    E.g.\n        from yacs.config import CfgNode as CN\n        
cfg.TRAINER.MY_MODEL = CN()\n        cfg.TRAINER.MY_MODEL.PARAM_A = 1.\n        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5\n        cfg.TRAINER.MY_MODEL.PARAM_C = False\n    \"\"\"\n    from yacs.config import CfgNode as CN\n\n    cfg.TRAINER.COOP = CN()\n    cfg.TRAINER.COOP.ALPHA = 1.0\n    cfg.TRAINER.COOP.N_CTX = 16  # number of context vectors\n    cfg.TRAINER.COOP.CSC = False  # class-specific context\n    cfg.TRAINER.COOP.CTX_INIT = False  # initialization words\n    cfg.TRAINER.COOP.PREC = \"fp16\"  # fp16, fp32, amp\n    cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = \"end\"  # 'middle' or 'end' or 'front'\n\n    cfg.TRAINER.COCOOP = CN()\n    cfg.TRAINER.COCOOP.N_CTX = 16  # number of context vectors\n    cfg.TRAINER.COCOOP.CTX_INIT = False  # initialization words\n    cfg.TRAINER.COCOOP.PREC = \"fp16\"  # fp16, fp32, amp\n\n    cfg.DATASET.SUBSAMPLE_CLASSES = \"all\"  # all, base or new\n    \"\"\"\n    Add new config\n    \"\"\"\n    cfg.LOSS = CN()\n    cfg.LOSS.GM = False\n    cfg.LOSS.NAME = \"\"\n    cfg.LOSS.ALPHA = 0.\n    cfg.LOSS.T = 1.\n    cfg.LOSS.LAMBDA = 1.\n\n\ndef setup_cfg(args):\n    cfg = get_cfg_default()\n    extend_cfg(cfg)\n\n    # 1. From the dataset config file\n    if args.dataset_config_file:\n        cfg.merge_from_file(args.dataset_config_file)\n\n    # 2. From the method config file\n    if args.config_file:\n        cfg.merge_from_file(args.config_file)\n\n    # 3. From input arguments\n    reset_cfg(cfg, args)\n\n    # 4. 
From optional input arguments\n    cfg.merge_from_list(args.opts)\n\n    cfg.freeze()\n\n    return cfg\n\n\ndef main(args):\n    cfg = setup_cfg(args)\n    if cfg.SEED >= 0:\n        print(\"Setting fixed seed: {}\".format(cfg.SEED))\n        set_random_seed(cfg.SEED)\n    setup_logger(cfg.OUTPUT_DIR)\n\n    if torch.cuda.is_available() and cfg.USE_CUDA:\n        torch.backends.cudnn.benchmark = True\n\n    print_args(args, cfg)\n    print(\"Collecting env info ...\")\n    print(\"** System info **\\n{}\\n\".format(collect_env_info()))\n\n    trainer = build_trainer(cfg)\n\n    if args.eval_only:\n        trainer.load_model(args.model_dir, epoch=args.load_epoch)\n        trainer.test()\n        return\n\n    if not args.no_train:\n        trainer.train()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--root\", type=str, default=\"\", help=\"path to dataset\")\n    parser.add_argument(\"--output-dir\",\n                        type=str,\n                        default=\"\",\n                        help=\"output directory\")\n    parser.add_argument(\n        \"--resume\",\n        type=str,\n        default=\"\",\n        help=\"checkpoint directory (from which the training resumes)\",\n    )\n    parser.add_argument(\"--seed\",\n                        type=int,\n                        default=-1,\n                        help=\"only positive value enables a fixed seed\")\n    parser.add_argument(\"--source-domains\",\n                        type=str,\n                        nargs=\"+\",\n                        help=\"source domains for DA/DG\")\n    parser.add_argument(\"--target-domains\",\n                        type=str,\n                        nargs=\"+\",\n                        help=\"target domains for DA/DG\")\n    parser.add_argument(\"--transforms\",\n                        type=str,\n                        nargs=\"+\",\n                        help=\"data augmentation methods\")\n    
parser.add_argument(\"--config-file\",\n                        type=str,\n                        default=\"\",\n                        help=\"path to config file\")\n    parser.add_argument(\n        \"--dataset-config-file\",\n        type=str,\n        default=\"\",\n        help=\"path to config file for dataset setup\",\n    )\n    parser.add_argument(\"--trainer\",\n                        type=str,\n                        default=\"\",\n                        help=\"name of trainer\")\n    parser.add_argument(\"--backbone\",\n                        type=str,\n                        default=\"\",\n                        help=\"name of CNN backbone\")\n    parser.add_argument(\"--head\", type=str, default=\"\", help=\"name of head\")\n    parser.add_argument(\"--eval-only\",\n                        action=\"store_true\",\n                        help=\"evaluation only\")\n    parser.add_argument(\n        \"--model-dir\",\n        type=str,\n        default=\"\",\n        help=\"load model from this directory for eval-only mode\",\n    )\n    parser.add_argument(\"--load-epoch\",\n                        type=int,\n                        help=\"load model weights at this epoch for evaluation\")\n    parser.add_argument(\"--no-train\",\n                        action=\"store_true\",\n                        help=\"do not call trainer.train()\")\n    parser.add_argument(\n        \"opts\",\n        default=None,\n        nargs=argparse.REMAINDER,\n        help=\"modify config options using the command-line\",\n    )\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "ProGrad.public/trainers/__init__.py",
    "content": ""
  },
  {
    "path": "ProGrad.public/trainers/cocoop.py",
    "content": "import os.path as osp\nfrom collections import OrderedDict\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torch.cuda.amp import GradScaler, autocast\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.metrics import compute_accuracy\nfrom dassl.utils import load_pretrained_weights, load_checkpoint\nfrom dassl.optim import build_optimizer, build_lr_scheduler\n\nfrom clip import clip\nfrom clip.simple_tokenizer import SimpleTokenizer as _Tokenizer\n\n_tokenizer = _Tokenizer()\n\n\ndef load_clip_to_cpu(cfg):\n    backbone_name = cfg.MODEL.BACKBONE.NAME\n    url = clip._MODELS[backbone_name]\n    model_path = clip._download(url)\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=\"cpu\").eval()\n        state_dict = None\n\n    except RuntimeError:\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    model = clip.build_model(state_dict or model.state_dict())\n\n    return model\n\n\nclass TextEncoder(nn.Module):\n    def __init__(self, clip_model):\n        super().__init__()\n        self.transformer = clip_model.transformer\n        self.positional_embedding = clip_model.positional_embedding\n        self.ln_final = clip_model.ln_final\n        self.text_projection = clip_model.text_projection\n        self.dtype = clip_model.dtype\n\n    def forward(self, prompts, tokenized_prompts):\n        x = prompts + self.positional_embedding.type(self.dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x).type(self.dtype)\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = x[torch.arange(x.shape[0]),\n              tokenized_prompts.argmax(dim=-1)] @ self.text_projection\n\n        return x\n\n\nclass 
PromptLearner(nn.Module):\n    def __init__(self, cfg, classnames, clip_model):\n        super().__init__()\n        n_cls = len(classnames)\n        n_ctx = cfg.TRAINER.COCOOP.N_CTX\n        ctx_init = cfg.TRAINER.COCOOP.CTX_INIT\n        dtype = clip_model.dtype\n        ctx_dim = clip_model.ln_final.weight.shape[0]\n        vis_dim = clip_model.visual.output_dim\n        clip_imsize = clip_model.visual.input_resolution\n        cfg_imsize = cfg.INPUT.SIZE[0]\n        assert cfg_imsize == clip_imsize, f\"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})\"\n\n        if ctx_init:\n            ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]\n            ctx_init = ctx_init.replace(\" {}.\", \"\")\n            ctx_init = ctx_init.replace(\"_\", \" \")\n            prompt_n_ctx = len(ctx_init.split(\" \"))\n\n            assert n_ctx >= prompt_n_ctx, f\"#tokens ({n_ctx}) should be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})\"\n\n            prompt = clip.tokenize(ctx_init)\n            with torch.no_grad():\n                embedding = clip_model.token_embedding(prompt).type(dtype)\n\n            ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)\n\n            ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 +\n                                                              prompt_n_ctx, :]\n            prompt_prefix = \" \".join([\"X\"] * (n_ctx - prompt_n_ctx))\n            prompt_prefix = f\"{prompt_prefix} {ctx_init}\"\n        else:\n            # random initialization\n            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)\n            nn.init.normal_(ctx_vectors, std=0.02)\n            prompt_prefix = \" \".join([\"X\"] * n_ctx)\n\n        print(f'Initial context: \"{prompt_prefix}\"')\n        print(f\"Number of context words (tokens): {n_ctx}\")\n\n        self.ctx = nn.Parameter(ctx_vectors)\n\n        self.meta_net = nn.Sequential(\n            OrderedDict([(\"linear1\", nn.Linear(vis_dim, 
vis_dim // 16)),\n                         (\"relu\", nn.ReLU(inplace=True)),\n                         (\"linear2\", nn.Linear(vis_dim // 16, ctx_dim))]))\n\n        if cfg.TRAINER.COCOOP.PREC == \"fp16\":\n            self.meta_net.half()\n\n        classnames = [name.replace(\"_\", \" \") for name in classnames]\n        name_lens = [len(_tokenizer.encode(name)) for name in classnames]\n        prompts = [prompt_prefix + \" \" + name + \".\" for name in classnames]\n\n        tokenized_prompts = torch.cat([clip.tokenize(p)\n                                       for p in prompts])  # (n_cls, n_tkn)\n        with torch.no_grad():\n            embedding = clip_model.token_embedding(tokenized_prompts).type(\n                dtype)\n\n        # These token vectors will be saved when in save_model(),\n        # but they should be ignored in load_model() as we want to use\n        # those computed using the current class names\n        self.register_buffer(\"token_prefix\", embedding[:, :1, :])  # SOS\n        self.register_buffer(\"token_suffix\",\n                             embedding[:, 1 + n_ctx:, :])  # CLS, EOS\n\n        self.n_cls = n_cls\n        self.n_ctx = n_ctx\n        self.tokenized_prompts = tokenized_prompts  # torch.Tensor\n        self.name_lens = name_lens\n\n    def construct_prompts(self, ctx, prefix, suffix, label=None):\n        # dim0 is either batch_size (during training) or n_cls (during testing)\n        # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)\n        # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)\n        # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)\n\n        if label is not None:\n            prefix = prefix[label]\n            suffix = suffix[label]\n\n        prompts = torch.cat(\n            [\n                prefix,  # (dim0, 1, dim)\n                ctx,  # (dim0, n_ctx, dim)\n                suffix,  # (dim0, *, dim)\n            ],\n            dim=1,\n        )\n\n        return 
prompts\n\n    def forward(self, im_features):\n        prefix = self.token_prefix\n        suffix = self.token_suffix\n        ctx = self.ctx  # (n_ctx, ctx_dim)\n        bias = self.meta_net(im_features)  # (batch, ctx_dim)\n        bias = bias.unsqueeze(1)  # (batch, 1, ctx_dim)\n        ctx = ctx.unsqueeze(0)  # (1, n_ctx, ctx_dim)\n        ctx_shifted = ctx + bias  # (batch, n_ctx, ctx_dim)\n\n        # Use instance-conditioned context tokens for all classes\n        prompts = []\n        for ctx_shifted_i in ctx_shifted:\n            ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)\n            pts_i = self.construct_prompts(ctx_i, prefix,\n                                           suffix)  # (n_cls, n_tkn, ctx_dim)\n            prompts.append(pts_i)\n        prompts = torch.stack(prompts)\n\n        return prompts\n\n\nCUSTOM_TEMPLATES = {\n    # \"OxfordPets\": \"a photo of a {}, a type of pet.\",\n    \"OxfordPets\": \"a type of pet, a photo of a {}.\",\n    # \"OxfordFlowers\": \"a photo of a {}, a type of flower.\",\n    \"OxfordFlowers\": \"a type of flower, a photo of a {}.\",\n    \"FGVCAircraft\": \"a type of aircraft, a photo of a {}.\",\n    \"DescribableTextures\": \"a texture of {}.\",\n    \"EuroSAT\": \"a centered satellite photo of {}.\",\n    \"StanfordCars\": \"a photo of a {}.\",\n    # \"Food101\": \"a photo of {}, a type of food.\",\n    \"Food101\": \"a type of food, a photo of {}.\",\n    \"SUN397\": \"a photo of a {}.\",\n    \"Caltech101\": \"a photo of a {}.\",\n    \"UCF101\": \"a photo of a person doing {}.\",\n    \"ImageNet\": \"a photo of a {}.\",\n    \"ImageNetSketch\": \"a photo of a {}.\",\n    \"ImageNetV2\": \"a photo of a {}.\",\n    \"ImageNetA\": \"a photo of a {}.\",\n    \"ImageNetR\": \"a photo of a {}.\",\n}\n\n\nclass CustomCLIP(nn.Module):\n    def __init__(self, cfg, classnames, clip_model):\n        super().__init__()\n        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)\n       
 self.tokenized_prompts = self.prompt_learner.tokenized_prompts\n        self.image_encoder = clip_model.visual\n        self.text_encoder = TextEncoder(clip_model)\n        self.logit_scale = clip_model.logit_scale\n        self.dtype = clip_model.dtype\n\n    def forward(self, image, label=None):\n        tokenized_prompts = self.tokenized_prompts\n        logit_scale = self.logit_scale.exp()\n\n        image_features = self.image_encoder(image.type(self.dtype))\n        image_features = image_features / image_features.norm(dim=-1,\n                                                              keepdim=True)\n\n        prompts = self.prompt_learner(image_features)\n\n        logits = []\n        for pts_i, imf_i in zip(prompts, image_features):\n            text_features = self.text_encoder(pts_i, tokenized_prompts)\n            text_features = text_features / text_features.norm(dim=-1,\n                                                               keepdim=True)\n            l_i = logit_scale * imf_i @ text_features.t()\n            logits.append(l_i)\n        logits = torch.stack(logits)\n\n        if self.prompt_learner.training:\n            return F.cross_entropy(logits, label)\n\n        return logits\n\n\n@TRAINER_REGISTRY.register()\nclass CoCoOp(TrainerX):\n    def check_cfg(self, cfg):\n        assert cfg.TRAINER.COCOOP.PREC in [\"fp16\", \"fp32\", \"amp\"]\n\n    def build_model(self):\n        cfg = self.cfg\n        classnames = self.dm.dataset.classnames\n\n        print(f\"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})\")\n        clip_model = load_clip_to_cpu(cfg)\n\n        if cfg.TRAINER.COCOOP.PREC == \"fp32\" or cfg.TRAINER.COCOOP.PREC == \"amp\":\n            # CLIP's default precision is fp16\n            clip_model.float()\n\n        print(\"Building custom CLIP\")\n        self.model = CustomCLIP(cfg, classnames, clip_model)\n\n        print(\"Turning off gradients in both the image and the text encoder\")\n        name_to_update = 
\"prompt_learner\"\n\n        for name, param in self.model.named_parameters():\n            if name_to_update not in name:\n                param.requires_grad_(False)\n\n        # Double check\n        enabled = set()\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                enabled.add(name)\n        print(f\"Parameters to be updated: {enabled}\")\n\n        if cfg.MODEL.INIT_WEIGHTS:\n            load_pretrained_weights(self.model.prompt_learner,\n                                    cfg.MODEL.INIT_WEIGHTS)\n\n        self.model.to(self.device)\n        # NOTE: only give prompt_learner to the optimizer\n        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)\n        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)\n        self.register_model(\"prompt_learner\", self.model.prompt_learner,\n                            self.optim, self.sched)\n\n        self.scaler = GradScaler(\n        ) if cfg.TRAINER.COCOOP.PREC == \"amp\" else None\n\n        # Note that multi-gpu training could be slow because CLIP's size is\n        # big, which slows down the copy operation in DataParallel\n        device_count = torch.cuda.device_count()\n        if device_count > 1:\n            print(\n                f\"Multiple GPUs detected (n_gpus={device_count}), use all of them!\"\n            )\n            self.model = nn.DataParallel(self.model)\n\n    def forward_backward(self, batch):\n        image, label = self.parse_batch_train(batch)\n\n        model = self.model\n        optim = self.optim\n        scaler = self.scaler\n\n        prec = self.cfg.TRAINER.COCOOP.PREC\n        if prec == \"amp\":\n            with autocast():\n                loss = model(image, label)\n            optim.zero_grad()\n            scaler.scale(loss).backward()\n            scaler.step(optim)\n            scaler.update()\n        else:\n            loss = model(image, label)\n            optim.zero_grad()\n        
    loss.backward()\n            optim.step()\n\n        loss_summary = {\"loss\": loss.item()}\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch):\n        input = batch[\"img\"]\n        label = batch[\"label\"]\n        input = input.to(self.device)\n        label = label.to(self.device)\n        return input, label\n\n    def load_model(self, directory, epoch=None):\n        if not directory:\n            print(\n                \"Note that load_model() is skipped as no pretrained model is given\"\n            )\n            return\n\n        names = self.get_model_names()\n\n        # By default, the best model is loaded\n        model_file = \"model-best.pth.tar\"\n\n        if epoch is not None:\n            model_file = \"model.pth.tar-\" + str(epoch)\n\n        for name in names:\n            model_path = osp.join(directory, name, model_file)\n\n            if not osp.exists(model_path):\n                raise FileNotFoundError(\n                    'Model not found at \"{}\"'.format(model_path))\n\n            checkpoint = load_checkpoint(model_path)\n            state_dict = checkpoint[\"state_dict\"]\n            epoch = checkpoint[\"epoch\"]\n\n            # Ignore fixed token vectors\n            if \"token_prefix\" in state_dict:\n                del state_dict[\"token_prefix\"]\n\n            if \"token_suffix\" in state_dict:\n                del state_dict[\"token_suffix\"]\n\n            print(\"Loading weights to {} \"\n                  'from \"{}\" (epoch = {})'.format(name, model_path, epoch))\n            # set strict=False\n            self._models[name].load_state_dict(state_dict, strict=False)\n"
  },
  {
    "path": "ProGrad.public/trainers/coop.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torch.cuda.amp import GradScaler, autocast\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.metrics import compute_accuracy\nfrom dassl.utils import load_pretrained_weights, load_checkpoint\nfrom dassl.optim import build_optimizer, build_lr_scheduler\n\nfrom clip import clip\nfrom clip.simple_tokenizer import SimpleTokenizer as _Tokenizer\n\n_tokenizer = _Tokenizer()\n\n\ndef load_clip_to_cpu(cfg):\n    backbone_name = cfg.MODEL.BACKBONE.NAME\n    url = clip._MODELS[backbone_name]\n    model_path = clip._download(url)\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=\"cpu\").eval()\n        state_dict = None\n\n    except RuntimeError:\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    model = clip.build_model(state_dict or model.state_dict())\n\n    return model\n\n\nclass TextEncoder(nn.Module):\n    def __init__(self, clip_model):\n        super().__init__()\n        self.transformer = clip_model.transformer\n        self.positional_embedding = clip_model.positional_embedding\n        self.ln_final = clip_model.ln_final\n        self.text_projection = clip_model.text_projection\n        self.dtype = clip_model.dtype\n\n    def forward(self, prompts, tokenized_prompts):\n        x = prompts + self.positional_embedding.type(self.dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x).type(self.dtype)\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = x[torch.arange(x.shape[0]),\n              tokenized_prompts.argmax(dim=-1)] @ self.text_projection\n\n        return x\n\n\nclass PromptLearner(nn.Module):\n    def __init__(self, 
cfg, classnames, clip_model):\n        super().__init__()\n        n_cls = len(classnames)\n        n_ctx = cfg.TRAINER.COOP.N_CTX\n        ctx_init = cfg.TRAINER.COOP.CTX_INIT\n        dtype = clip_model.dtype\n        ctx_dim = clip_model.ln_final.weight.shape[0]\n        clip_imsize = clip_model.visual.input_resolution\n        cfg_imsize = cfg.INPUT.SIZE[0]\n        assert cfg_imsize == clip_imsize, f\"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})\"\n\n        # if ctx_init:\n        #     # use given words to initialize context vectors\n        #     ctx_init = ctx_init.replace(\"_\", \" \")\n        #     n_ctx = len(ctx_init.split(\" \"))\n        #     prompt = clip.tokenize(ctx_init)\n        #     with torch.no_grad():\n        #         embedding = clip_model.token_embedding(prompt).type(dtype)\n        #     ctx_vectors = embedding[0, 1:1 + n_ctx, :]\n        #     prompt_prefix = ctx_init\n        if ctx_init:\n            ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]\n            ctx_init = ctx_init.replace(\" {}.\", \"\")\n            ctx_init = ctx_init.replace(\"_\", \" \")\n            prompt_n_ctx = len(ctx_init.split(\" \"))\n\n            assert n_ctx >= prompt_n_ctx, f\"#tokens ({n_ctx}) should be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})\"\n\n            prompt = clip.tokenize(ctx_init)\n            with torch.no_grad():\n                embedding = clip_model.token_embedding(prompt).type(dtype)\n\n            ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)\n\n            ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 +\n                                                              prompt_n_ctx, :]\n            prompt_prefix = \" \".join([\"X\"] * (n_ctx - prompt_n_ctx))\n            prompt_prefix = f\"{prompt_prefix} {ctx_init}\"\n        else:\n            # random initialization\n            if cfg.TRAINER.COOP.CSC:\n                print(\"Initializing class-specific 
contexts\")\n                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)\n            else:\n                print(\"Initializing a generic context\")\n                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)\n            nn.init.normal_(ctx_vectors, std=0.02)\n            prompt_prefix = \" \".join([\"X\"] * n_ctx)\n\n        print(f'Initial context: \"{prompt_prefix}\"')\n        print(f\"Number of context words (tokens): {n_ctx}\")\n\n        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized\n\n        classnames = [name.replace(\"_\", \" \") for name in classnames]\n        name_lens = [len(_tokenizer.encode(name)) for name in classnames]\n        prompts = [prompt_prefix + \" \" + name + \".\" for name in classnames]\n\n        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])\n        with torch.no_grad():\n            embedding = clip_model.token_embedding(tokenized_prompts).type(\n                dtype)\n\n        # These token vectors will be saved when in save_model(),\n        # but they should be ignored in load_model() as we want to use\n        # those computed using the current class names\n        self.register_buffer(\"token_prefix\", embedding[:, :1, :])  # SOS\n        self.register_buffer(\"token_suffix\",\n                             embedding[:, 1 + n_ctx:, :])  # CLS, EOS\n\n        self.n_cls = n_cls\n        self.n_ctx = n_ctx\n        self.tokenized_prompts = tokenized_prompts  # torch.Tensor\n        self.name_lens = name_lens\n        self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION\n\n    def forward(self):\n        ctx = self.ctx\n        if ctx.dim() == 2:\n            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)\n\n        prefix = self.token_prefix\n        suffix = self.token_suffix\n\n        if self.class_token_position == \"end\":\n            prompts = torch.cat(\n                [\n                    prefix,  # (n_cls, 1, dim)\n                    
ctx,  # (n_cls, n_ctx, dim)\n                    suffix,  # (n_cls, *, dim)\n                ],\n                dim=1,\n            )\n\n        elif self.class_token_position == \"middle\":\n            half_n_ctx = self.n_ctx // 2\n            prompts = []\n            for i in range(self.n_cls):\n                name_len = self.name_lens[i]\n                prefix_i = prefix[i:i + 1, :, :]\n                class_i = suffix[i:i + 1, :name_len, :]\n                suffix_i = suffix[i:i + 1, name_len:, :]\n                ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]\n                ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]\n                prompt = torch.cat(\n                    [\n                        prefix_i,  # (1, 1, dim)\n                        ctx_i_half1,  # (1, n_ctx//2, dim)\n                        class_i,  # (1, name_len, dim)\n                        ctx_i_half2,  # (1, n_ctx//2, dim)\n                        suffix_i,  # (1, *, dim)\n                    ],\n                    dim=1,\n                )\n                prompts.append(prompt)\n            prompts = torch.cat(prompts, dim=0)\n\n        elif self.class_token_position == \"front\":\n            prompts = []\n            for i in range(self.n_cls):\n                name_len = self.name_lens[i]\n                prefix_i = prefix[i:i + 1, :, :]\n                class_i = suffix[i:i + 1, :name_len, :]\n                suffix_i = suffix[i:i + 1, name_len:, :]\n                ctx_i = ctx[i:i + 1, :, :]\n                prompt = torch.cat(\n                    [\n                        prefix_i,  # (1, 1, dim)\n                        class_i,  # (1, name_len, dim)\n                        ctx_i,  # (1, n_ctx, dim)\n                        suffix_i,  # (1, *, dim)\n                    ],\n                    dim=1,\n                )\n                prompts.append(prompt)\n            prompts = torch.cat(prompts, dim=0)\n\n        else:\n            raise ValueError\n\n        return 
prompts\n\n\nCUSTOM_TEMPLATES = {\n    # \"OxfordPets\": \"a photo of a {}, a type of pet.\",\n    \"OxfordPets\": \"a type of pet, a photo of a {}.\",\n    # \"OxfordFlowers\": \"a photo of a {}, a type of flower.\",\n    \"OxfordFlowers\": \"a type of flower, a photo of a {}.\",\n    \"FGVCAircraft\": \"a type of aircraft, a photo of a {}.\",\n    \"DescribableTextures\": \"a texture of {}.\",\n    \"EuroSAT\": \"a centered satellite photo of {}.\",\n    \"StanfordCars\": \"a photo of a {}.\",\n    # \"Food101\": \"a photo of {}, a type of food.\",\n    \"Food101\": \"a type of food, a photo of {}.\",\n    \"SUN397\": \"a photo of a {}.\",\n    \"Caltech101\": \"a photo of a {}.\",\n    \"UCF101\": \"a photo of a person doing {}.\",\n    \"ImageNet\": \"a photo of a {}.\",\n    \"ImageNetSketch\": \"a photo of a {}.\",\n    \"ImageNetV2\": \"a photo of a {}.\",\n    \"ImageNetA\": \"a photo of a {}.\",\n    \"ImageNetR\": \"a photo of a {}.\",\n}\n\n\nclass CustomCLIP(nn.Module):\n    def __init__(self, cfg, classnames, clip_model):\n        super().__init__()\n        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)\n        self.tokenized_prompts = self.prompt_learner.tokenized_prompts\n        self.image_encoder = clip_model.visual\n        self.text_encoder = TextEncoder(clip_model)\n        self.logit_scale = clip_model.logit_scale\n        self.dtype = clip_model.dtype\n\n    def forward(self, image):\n        image_features = self.image_encoder(image.type(self.dtype))\n\n        prompts = self.prompt_learner()\n        tokenized_prompts = self.tokenized_prompts\n        text_features = self.text_encoder(prompts, tokenized_prompts)\n\n        image_features = image_features / image_features.norm(dim=-1,\n                                                              keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1,\n                                                           keepdim=True)\n\n        logit_scale 
= self.logit_scale.exp()\n        logits = logit_scale * image_features @ text_features.t()\n\n        return logits\n\n\n@TRAINER_REGISTRY.register()\nclass CoOp(TrainerX):\n    \"\"\"Context Optimization (CoOp).\n\n    Learning to Prompt for Vision-Language Models\n    https://arxiv.org/abs/2109.01134\n    \"\"\"\n    def check_cfg(self, cfg):\n        assert cfg.TRAINER.COOP.PREC in [\"fp16\", \"fp32\", \"amp\"]\n\n    def build_model(self):\n        cfg = self.cfg\n        classnames = self.dm.dataset.classnames\n\n        print(f\"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})\")\n        clip_model = load_clip_to_cpu(cfg)\n\n        if cfg.TRAINER.COOP.PREC == \"fp32\" or cfg.TRAINER.COOP.PREC == \"amp\":\n            # CLIP's default precision is fp16\n            clip_model.float()\n\n        print(\"Building custom CLIP\")\n        self.model = CustomCLIP(cfg, classnames, clip_model)\n\n        print(\"Turning off gradients in both the image and the text encoder\")\n        for name, param in self.model.named_parameters():\n            if \"prompt_learner\" not in name:\n                param.requires_grad_(False)\n\n        if cfg.MODEL.INIT_WEIGHTS:\n            load_pretrained_weights(self.model.prompt_learner,\n                                    cfg.MODEL.INIT_WEIGHTS)\n\n        self.model.to(self.device)\n        # NOTE: only give prompt_learner to the optimizer\n        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)\n        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)\n        self.register_model(\"prompt_learner\", self.model.prompt_learner,\n                            self.optim, self.sched)\n\n        self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == \"amp\" else None\n\n        # Note that multi-gpu training could be slow because CLIP's size is\n        # big, which slows down the copy operation in DataParallel\n        device_count = torch.cuda.device_count()\n        if device_count > 1:\n            
print(\n                f\"Multiple GPUs detected (n_gpus={device_count}), use all of them!\"\n            )\n            self.model = nn.DataParallel(self.model)\n\n    def forward_backward(self, batch):\n        image, label = self.parse_batch_train(batch)\n\n        prec = self.cfg.TRAINER.COOP.PREC\n        if prec == \"amp\":\n            with autocast():\n                output = self.model(image)\n                loss = F.cross_entropy(output, label)\n            self.optim.zero_grad()\n            self.scaler.scale(loss).backward()\n            self.scaler.step(self.optim)\n            self.scaler.update()\n        else:\n            output = self.model(image)\n            loss = F.cross_entropy(output, label)\n            self.model_backward_and_update(loss)\n\n        loss_summary = {\n            \"loss\": loss.item(),\n            \"acc\": compute_accuracy(output, label)[0].item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch):\n        input = batch[\"img\"]\n        label = batch[\"label\"]\n        input = input.to(self.device)\n        label = label.to(self.device)\n        return input, label\n\n    def load_model(self, directory, epoch=None):\n        if not directory:\n            print(\n                \"Note that load_model() is skipped as no pretrained model is given\"\n            )\n            return\n\n        names = self.get_model_names()\n\n        # By default, the best model is loaded\n        model_file = \"model-best.pth.tar\"\n\n        if epoch is not None:\n            model_file = \"model.pth.tar-\" + str(epoch)\n\n        for name in names:\n            model_path = osp.join(directory, name, model_file)\n\n            if not osp.exists(model_path):\n                raise FileNotFoundError(\n                    'Model not found at \"{}\"'.format(model_path))\n\n            checkpoint = 
load_checkpoint(model_path)\n            state_dict = checkpoint[\"state_dict\"]\n            epoch = checkpoint[\"epoch\"]\n\n            # Ignore fixed token vectors\n            if \"token_prefix\" in state_dict:\n                del state_dict[\"token_prefix\"]\n\n            if \"token_suffix\" in state_dict:\n                del state_dict[\"token_suffix\"]\n\n            print(\"Loading weights to {} \"\n                  'from \"{}\" (epoch = {})'.format(name, model_path, epoch))\n            # set strict=False\n            self._models[name].load_state_dict(state_dict, strict=False)\n"
  },
  {
    "path": "ProGrad.public/trainers/imagenet_templates.py",
    "content": "# source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb\n\nIMAGENET_TEMPLATES = [\n    \"a bad photo of a {}.\",\n    \"a photo of many {}.\",\n    \"a sculpture of a {}.\",\n    \"a photo of the hard to see {}.\",\n    \"a low resolution photo of the {}.\",\n    \"a rendering of a {}.\",\n    \"graffiti of a {}.\",\n    \"a bad photo of the {}.\",\n    \"a cropped photo of the {}.\",\n    \"a tattoo of a {}.\",\n    \"the embroidered {}.\",\n    \"a photo of a hard to see {}.\",\n    \"a bright photo of a {}.\",\n    \"a photo of a clean {}.\",\n    \"a photo of a dirty {}.\",\n    \"a dark photo of the {}.\",\n    \"a drawing of a {}.\",\n    \"a photo of my {}.\",\n    \"the plastic {}.\",\n    \"a photo of the cool {}.\",\n    \"a close-up photo of a {}.\",\n    \"a black and white photo of the {}.\",\n    \"a painting of the {}.\",\n    \"a painting of a {}.\",\n    \"a pixelated photo of the {}.\",\n    \"a sculpture of the {}.\",\n    \"a bright photo of the {}.\",\n    \"a cropped photo of a {}.\",\n    \"a plastic {}.\",\n    \"a photo of the dirty {}.\",\n    \"a jpeg corrupted photo of a {}.\",\n    \"a blurry photo of the {}.\",\n    \"a photo of the {}.\",\n    \"a good photo of the {}.\",\n    \"a rendering of the {}.\",\n    \"a {} in a video game.\",\n    \"a photo of one {}.\",\n    \"a doodle of a {}.\",\n    \"a close-up photo of the {}.\",\n    \"a photo of a {}.\",\n    \"the origami {}.\",\n    \"the {} in a video game.\",\n    \"a sketch of a {}.\",\n    \"a doodle of the {}.\",\n    \"a origami {}.\",\n    \"a low resolution photo of a {}.\",\n    \"the toy {}.\",\n    \"a rendition of the {}.\",\n    \"a photo of the clean {}.\",\n    \"a photo of a large {}.\",\n    \"a rendition of a {}.\",\n    \"a photo of a nice {}.\",\n    \"a photo of a weird {}.\",\n    \"a blurry photo of a {}.\",\n    \"a cartoon {}.\",\n    \"art of a {}.\",\n    \"a sketch of the {}.\",\n    \"a 
embroidered {}.\",\n    \"a pixelated photo of a {}.\",\n    \"itap of the {}.\",\n    \"a jpeg corrupted photo of the {}.\",\n    \"a good photo of a {}.\",\n    \"a plushie {}.\",\n    \"a photo of the nice {}.\",\n    \"a photo of the small {}.\",\n    \"a photo of the weird {}.\",\n    \"the cartoon {}.\",\n    \"art of the {}.\",\n    \"a drawing of the {}.\",\n    \"a photo of the large {}.\",\n    \"a black and white photo of a {}.\",\n    \"the plushie {}.\",\n    \"a dark photo of a {}.\",\n    \"itap of a {}.\",\n    \"graffiti of the {}.\",\n    \"a toy {}.\",\n    \"itap of my {}.\",\n    \"a photo of a cool {}.\",\n    \"a photo of a small {}.\",\n    \"a tattoo of the {}.\",\n]\n\nIMAGENET_TEMPLATES_SELECT = [\n    \"itap of a {}.\",\n    \"a bad photo of the {}.\",\n    \"a origami {}.\",\n    \"a photo of the large {}.\",\n    \"a {} in a video game.\",\n    \"art of the {}.\",\n    \"a photo of the small {}.\",\n]\n"
  },
  {
    "path": "ProGrad.public/trainers/prograd.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom torch.cuda.amp import GradScaler, autocast\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.metrics import compute_accuracy\nfrom dassl.utils import load_pretrained_weights, load_checkpoint\nfrom dassl.optim import build_optimizer, build_lr_scheduler\n\nfrom clip import clip\nfrom clip.simple_tokenizer import SimpleTokenizer as _Tokenizer\n\nfrom torch.nn.modules.loss import _Loss\n\nfrom tqdm import tqdm\nimport json\n\n_tokenizer = _Tokenizer()\n\n\ndef load_clip_to_cpu(cfg):\n    backbone_name = cfg.MODEL.BACKBONE.NAME\n    url = clip._MODELS[backbone_name]\n    model_path = clip._download(url)\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=\"cpu\").eval()\n        state_dict = None\n\n    except RuntimeError:\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    model = clip.build_model(state_dict or model.state_dict())\n\n    return model\n\n\nclass TextEncoder(nn.Module):\n    def __init__(self, clip_model):\n        super().__init__()\n        self.transformer = clip_model.transformer\n        self.positional_embedding = clip_model.positional_embedding\n        self.ln_final = clip_model.ln_final\n        self.text_projection = clip_model.text_projection\n        self.dtype = clip_model.dtype\n\n    def forward(self, prompts, tokenized_prompts):\n        x = prompts + self.positional_embedding.type(self.dtype)\n        x = x.permute(1, 0, 2)  # NLD -> LND\n        x = self.transformer(x)\n        x = x.permute(1, 0, 2)  # LND -> NLD\n        x = self.ln_final(x).type(self.dtype)\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        x = x[torch.arange(x.shape[0]),\n              tokenized_prompts.argmax(dim=-1)] @ 
self.text_projection\n\n        return x\n\n\nclass PromptLearner(nn.Module):\n    def __init__(self, cfg, classnames, clip_model):\n        super().__init__()\n        n_cls = len(classnames)\n        n_ctx = cfg.TRAINER.COOP.N_CTX\n        ctx_init = cfg.TRAINER.COOP.CTX_INIT\n        dtype = clip_model.dtype\n        ctx_dim = clip_model.ln_final.weight.shape[0]\n        clip_imsize = clip_model.visual.input_resolution\n        cfg_imsize = cfg.INPUT.SIZE[0]\n        assert cfg_imsize == clip_imsize, f\"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})\"\n\n        if ctx_init:\n            ctx_init = CUSTOM_TEMPLATES[cfg.DATASET.NAME]\n            ctx_init = ctx_init.replace(\" {}.\", \"\")\n            ctx_init = ctx_init.replace(\"_\", \" \")\n            prompt_n_ctx = len(ctx_init.split(\" \"))\n\n            assert n_ctx >= prompt_n_ctx, f\"#tokens ({n_ctx}) should be greater than or equal to #initial prompt tokens ({prompt_n_ctx}, {ctx_init})\"\n\n            prompt = clip.tokenize(ctx_init)\n            with torch.no_grad():\n                embedding = clip_model.token_embedding(prompt).type(dtype)\n\n            ctx_vectors = torch.zeros(n_ctx, ctx_dim, dtype=dtype)\n\n            ctx_vectors[n_ctx - prompt_n_ctx:, :] = embedding[0, 1:1 +\n                                                              prompt_n_ctx, :]\n            prompt_prefix = \" \".join([\"X\"] * (n_ctx - prompt_n_ctx))\n            prompt_prefix = f\"{prompt_prefix} {ctx_init}\"\n        else:\n            # random initialization\n            if cfg.TRAINER.COOP.CSC:\n                print(\"Initializing class-specific contexts\")\n                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)\n            else:\n                print(\"Initializing a generic context\")\n                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)\n            nn.init.normal_(ctx_vectors, std=0.02)\n            prompt_prefix = \" \".join([\"X\"] * n_ctx)\n\n        
print(f'Initial context: \"{prompt_prefix}\"')\n        print(f\"Number of context words (tokens): {n_ctx}\")\n\n        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized\n\n        classnames = [name.replace(\"_\", \" \") for name in classnames]\n        name_lens = [len(_tokenizer.encode(name)) for name in classnames]\n        prompts = [prompt_prefix + \" \" + name + \".\" for name in classnames]\n\n        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])\n        with torch.no_grad():\n            embedding = clip_model.token_embedding(tokenized_prompts).type(\n                dtype)\n\n        # These token vectors will be saved in save_model(),\n        # but they should be ignored in load_model() as we want to use\n        # those computed using the current class names\n        self.register_buffer(\"token_prefix\", embedding[:, :1, :])  # SOS\n        self.register_buffer(\"token_suffix\",\n                             embedding[:, 1 + n_ctx:, :])  # CLS, EOS\n\n        self.n_cls = n_cls\n        self.n_ctx = n_ctx\n        self.tokenized_prompts = tokenized_prompts  # torch.Tensor\n        self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION\n        self.name_lens = name_lens\n\n    def forward(self):\n        ctx = self.ctx\n        if ctx.dim() == 2:\n            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)\n\n        prefix = self.token_prefix\n        suffix = self.token_suffix\n\n        if self.class_token_position == \"end\":\n            prompts = torch.cat(\n                [\n                    prefix,  # (n_cls, 1, dim)\n                    ctx,  # (n_cls, n_ctx, dim)\n                    suffix,  # (n_cls, *, dim)\n                ],\n                dim=1,\n            )\n\n        elif self.class_token_position == \"middle\":\n            half_n_ctx = self.n_ctx // 2\n            prompts = []\n            for i in range(self.n_cls):\n                name_len = self.name_lens[i]\n                
prefix_i = prefix[i:i + 1, :, :]\n                class_i = suffix[i:i + 1, :name_len, :]\n                suffix_i = suffix[i:i + 1, name_len:, :]\n                ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]\n                ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]\n                prompt = torch.cat(\n                    [\n                        prefix_i,  # (1, 1, dim)\n                        ctx_i_half1,  # (1, n_ctx//2, dim)\n                        class_i,  # (1, name_len, dim)\n                        ctx_i_half2,  # (1, n_ctx//2, dim)\n                        suffix_i,  # (1, *, dim)\n                    ],\n                    dim=1,\n                )\n                prompts.append(prompt)\n            prompts = torch.cat(prompts, dim=0)\n\n        elif self.class_token_position == \"front\":\n            prompts = []\n            for i in range(self.n_cls):\n                name_len = self.name_lens[i]\n                prefix_i = prefix[i:i + 1, :, :]\n                class_i = suffix[i:i + 1, :name_len, :]\n                suffix_i = suffix[i:i + 1, name_len:, :]\n                ctx_i = ctx[i:i + 1, :, :]\n                prompt = torch.cat(\n                    [\n                        prefix_i,  # (1, 1, dim)\n                        class_i,  # (1, name_len, dim)\n                        ctx_i,  # (1, n_ctx, dim)\n                        suffix_i,  # (1, *, dim)\n                    ],\n                    dim=1,\n                )\n                prompts.append(prompt)\n            prompts = torch.cat(prompts, dim=0)\n\n        else:\n            raise ValueError\n\n        return prompts\n\n\nCUSTOM_TEMPLATES = {\n    \"OxfordPets\": \"a type of pet, a photo of a {}.\",\n    \"OxfordFlowers\": \"a type of flower, a photo of a {}.\",\n    \"FGVCAircraft\": \"a type of aircraft, a photo of a {}.\",\n    \"DescribableTextures\": \"a texture of {}.\",\n    \"EuroSAT\": \"a centered satellite photo of {}.\",\n    \"StanfordCars\": \"a photo of 
a {}.\",\n    \"Food101\": \"a type of food, a photo of {}.\",\n    \"SUN397\": \"a photo of a {}.\",\n    \"Caltech101\": \"a photo of a {}.\",\n    \"UCF101\": \"a photo of a person doing {}.\",\n    \"ImageNet\": \"a photo of a {}.\",\n    \"ImageNetSketch\": \"a photo of a {}.\",\n    \"ImageNetV2\": \"a photo of a {}.\",\n    \"ImageNetA\": \"a photo of a {}.\",\n    \"ImageNetR\": \"a photo of a {}.\",\n}\n\n\nclass CLIP(nn.Module):\n    def __init__(self, cfg, classnames):\n        super().__init__()\n        print(f\"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})\")\n        clip_model = load_clip_to_cpu(cfg)\n        clip_model.float()\n\n        temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]\n        prompts = [temp.format(c.replace(\"_\", \" \")) for c in classnames]\n        print(f\"Prompts: {prompts}\")\n        prompts = torch.cat([clip.tokenize(p) for p in prompts])\n\n        with torch.no_grad():\n            text_features = clip_model.encode_text(prompts)\n            text_features = text_features / text_features.norm(dim=-1,\n                                                               keepdim=True)\n\n        self.text_features = text_features\n        self.clip_model = clip_model\n\n    def forward(self, image):\n        image_features = self.clip_model.encode_image(image)\n        image_features = image_features / image_features.norm(dim=-1,\n                                                              keepdim=True)\n        logit_scale = self.clip_model.logit_scale.exp()\n\n        text_features = self.text_features\n        text_features = text_features.to(image_features.device)\n        logits = logit_scale * image_features @ text_features.t()\n        return logits\n\n\nclass CustomCLIP(nn.Module):\n    def __init__(self, cfg, classnames, clip_model):\n        super().__init__()\n        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)\n        self.tokenized_prompts = self.prompt_learner.tokenized_prompts\n        
self.image_encoder = clip_model.visual\n        self.text_encoder = TextEncoder(clip_model)\n        self.logit_scale = clip_model.logit_scale\n        self.dtype = clip_model.dtype\n\n    def forward(self, image):\n        image_features = self.image_encoder(image.type(self.dtype))\n\n        prompts = self.prompt_learner()\n        tokenized_prompts = self.tokenized_prompts\n        text_features = self.text_encoder(prompts, tokenized_prompts)\n\n        image_features = image_features / image_features.norm(dim=-1,\n                                                              keepdim=True)\n        text_features = text_features / text_features.norm(dim=-1,\n                                                           keepdim=True)\n\n        logit_scale = self.logit_scale.exp()\n        logits = logit_scale * image_features @ text_features.t()\n\n        return logits\n\n\nclass ProGradLoss(_Loss):\n    def __init__(self, T):\n        super(ProGradLoss, self).__init__()\n        self.T = T\n\n    def forward(self, stu_logits, tea_logits, label):\n        xe_loss = F.cross_entropy(stu_logits, label)\n\n        # Distillation term towards the zero-shot CLIP (teacher) predictions.\n        # This is the soft-target cross-entropy, which differs from\n        # KL(tea || stu) only by the teacher entropy (constant w.r.t.\n        # stu_logits), so both give the same gradient. The T * T factor keeps\n        # the gradient magnitude comparable across temperatures.\n        tea_prob = F.softmax(tea_logits / self.T, dim=-1)\n        kl_loss = -tea_prob * F.log_softmax(stu_logits / self.T,\n                                            -1) * self.T * self.T\n        kl_loss = kl_loss.sum(1).mean()\n\n        return xe_loss, kl_loss\n\n\n@TRAINER_REGISTRY.register()\nclass ProGrad(TrainerX):\n    \"\"\"Prompt-aligned Gradient (ProGrad) for few-shot CLIP prompt tuning.\"\"\"\n    def check_cfg(self, cfg):\n        assert cfg.TRAINER.COOP.PREC in [\"fp16\", \"fp32\", \"amp\"]\n\n    def build_model(self):\n        cfg = self.cfg\n        classnames = self.dm.dataset.classnames\n\n        print(f\"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})\")\n        clip_model = load_clip_to_cpu(cfg)\n\n        if cfg.TRAINER.COOP.PREC == \"fp32\" or cfg.TRAINER.COOP.PREC == \"amp\":\n            # CLIP's default precision is fp16\n            clip_model.float()\n\n        print(\"Building zeroshot CLIP\")\n        self.zs_clip = CLIP(cfg, classnames)\n\n        print(\"Building custom CLIP\")\n        self.model = CustomCLIP(cfg, classnames, clip_model)\n\n        print(\"Turning off gradients in ZS Clip model\")\n        for name, param in self.zs_clip.named_parameters():\n            param.requires_grad_(False)\n\n        print(\"Turning off gradients in CoOp model\")\n        for name, param in self.model.named_parameters():\n            if \"prompt_learner\" not in name:\n                param.requires_grad_(False)\n\n        if cfg.MODEL.INIT_WEIGHTS:\n            load_pretrained_weights(self.model.prompt_learner,\n                                    cfg.MODEL.INIT_WEIGHTS)\n\n        self.model.to(self.device)\n        self.zs_clip.to(self.device)\n\n        # NOTE: only give prompt_learner to the optimizer\n        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)\n        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)\n        self.register_model(\"prompt_learner\", self.model.prompt_learner,\n                            self.optim, self.sched)\n\n        self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == \"amp\" else None\n\n        # Note that multi-gpu training could be slow because CLIP's size is\n        # big, which slows down the copy operation in DataParallel\n        device_count = torch.cuda.device_count()\n        if device_count > 1:\n            print(\n                f\"Multiple GPUs detected (n_gpus={device_count}), use all of them!\"\n            )\n            self.model = nn.DataParallel(self.model)\n            self.zs_clip = nn.DataParallel(self.zs_clip)\n\n        # build criterion\n        if cfg.LOSS.NAME == \"prograd\":\n            self.criterion = ProGradLoss(T=cfg.LOSS.T)\n        else:\n            raise NotImplementedError\n\n    def forward_backward(self, batch):\n        image, label = self.parse_batch_train(batch)\n\n        prec = 
self.cfg.TRAINER.COOP.PREC\n        if prec == \"amp\":\n            # ProGradLoss returns (xe_loss, kl_loss), and the projected update\n            # is performed by prograd_backward_and_update(); no AMP/GradScaler\n            # variant of that update is implemented here, so fail early.\n            raise NotImplementedError(\n                \"ProGrad does not support PREC='amp'; use 'fp16' or 'fp32'\")\n        else:\n            output = self.model(image)\n            with torch.no_grad():\n                zs_clip_output = self.zs_clip(image)\n\n            xe_loss, kl_loss = self.criterion(output,\n                                              zs_clip_output.detach(),\n                                              label)\n            self.prograd_backward_and_update(xe_loss, kl_loss,\n                                             self.cfg.LOSS.LAMBDA)\n\n        loss_summary = {\n            \"xe_loss\": xe_loss.item(),\n            \"kl_loss\": kl_loss.item(),\n            \"acc\": compute_accuracy(output, label)[0].item(),\n        }\n\n        if (self.batch_idx + 1) == self.num_batches:\n            self.update_lr()\n\n        return loss_summary\n\n    def parse_batch_train(self, batch):\n        input = batch[\"img\"]\n        label = batch[\"label\"]\n        input = input.to(self.device)\n        label = label.to(self.device)\n        return input, label\n\n    def load_model(self, directory, epoch=None):\n        if not directory:\n            print(\n                \"Note that load_model() is skipped as no pretrained model is given\"\n            )\n            return\n\n        names = self.get_model_names()\n\n        # By default, the best model is loaded\n        model_file = \"model-best.pth.tar\"\n\n        if epoch is not None:\n            model_file = \"model.pth.tar-\" + str(epoch)\n\n        for name in names:\n            model_path = osp.join(directory, name, model_file)\n\n            if not osp.exists(model_path):\n                raise FileNotFoundError(\n                    'Model not found at \"{}\"'.format(model_path))\n\n            checkpoint = load_checkpoint(model_path)\n            state_dict = checkpoint[\"state_dict\"]\n            epoch = checkpoint[\"epoch\"]\n\n            # Ignore fixed token vectors\n            if \"token_prefix\" in state_dict:\n                del state_dict[\"token_prefix\"]\n\n            if \"token_suffix\" in state_dict:\n                del state_dict[\"token_suffix\"]\n\n            print(\"Loading weights to {} \"\n                  'from \"{}\" (epoch = {})'.format(name, model_path, epoch))\n            # strict=False because token_prefix/token_suffix were removed above\n            self._models[name].load_state_dict(state_dict, strict=False)\n"
  },
  {
    "path": "ProGrad.public/trainers/zsclip.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.optim import build_optimizer, build_lr_scheduler\n\nfrom clip import clip\nfrom clip.model import convert_weights\n\nfrom .coop import load_clip_to_cpu\nfrom .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT\n\nCUSTOM_TEMPLATES = {\n    # \"OxfordPets\": \"a photo of a {}, a type of pet.\",\n    \"OxfordPets\": \"a type of pet, a photo of a {}.\",\n    # \"OxfordFlowers\": \"a photo of a {}, a type of flower.\",\n    \"OxfordFlowers\": \"a type of flower, a photo of a {}.\",\n    \"FGVCAircraft\": \"a photo of a {}, a type of aircraft.\",\n    \"DescribableTextures\": \"{} texture.\",\n    \"EuroSAT\": \"a centered satellite photo of {}.\",\n    \"StanfordCars\": \"a photo of a {}.\",\n    # \"Food101\": \"a photo of {}, a type of food.\",\n    \"Food101\": \"a type of food, a photo of {}.\",\n    \"SUN397\": \"a photo of a {}.\",\n    \"Caltech101\": \"a photo of a {}.\",\n    \"UCF101\": \"a photo of a person doing {}.\",\n    \"ImageNet\": \"a photo of a {}.\",\n    \"ImageNetSketch\": \"a photo of a {}.\",\n    \"ImageNetV2\": \"a photo of a {}.\",\n    \"ImageNetA\": \"a photo of a {}.\",\n    \"ImageNetR\": \"a photo of a {}.\",\n}\n\n\n@TRAINER_REGISTRY.register()\nclass ZeroshotCLIP(TrainerX):\n    def build_model(self):\n        cfg = self.cfg\n        classnames = self.dm.dataset.classnames\n\n        print(f\"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})\")\n        clip_model = load_clip_to_cpu(cfg)\n        clip_model.to(self.device)\n\n        temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]\n        prompts = [temp.format(c.replace(\"_\", \" \")) for c in classnames]\n        print(f\"Prompts: {prompts}\")\n        prompts = torch.cat([clip.tokenize(p) for p in prompts])\n        prompts = prompts.to(self.device)\n\n        with torch.no_grad():\n            text_features = clip_model.encode_text(prompts)\n       
     text_features = text_features / text_features.norm(dim=-1,\n                                                               keepdim=True)\n\n        self.text_features = text_features\n        self.clip_model = clip_model\n\n    def model_inference(self, image):\n        image_features = self.clip_model.encode_image(image)\n        image_features = image_features / image_features.norm(dim=-1,\n                                                              keepdim=True)\n        logit_scale = self.clip_model.logit_scale.exp()\n        logits = logit_scale * image_features @ self.text_features.t()\n        return logits\n\n\n@TRAINER_REGISTRY.register()\nclass ZeroshotCLIP2(ZeroshotCLIP):\n    \"\"\"Prompt ensembling.\"\"\"\n\n    # templates = IMAGENET_TEMPLATES\n    templates = IMAGENET_TEMPLATES_SELECT\n\n    def build_model(self):\n        cfg = self.cfg\n        classnames = self.dm.dataset.classnames\n\n        print(f\"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})\")\n        clip_model = load_clip_to_cpu(cfg)\n        clip_model.to(self.device)\n\n        for params in clip_model.parameters():\n            params.requires_grad_(False)\n\n        # add the dataset-specific prompt; build a new list so that the\n        # class-level templates list (IMAGENET_TEMPLATES_SELECT) is not\n        # mutated in place\n        if cfg.DATASET.NAME != \"ImageNet\":\n            self.templates = self.templates + [\n                CUSTOM_TEMPLATES[cfg.DATASET.NAME]\n            ]\n\n        num_temp = len(self.templates)\n        print(f\"Prompt ensembling (n={num_temp})\")\n\n        mean_text_features = 0\n        for temp in self.templates:\n            prompts = [temp.format(c.replace(\"_\", \" \")) for c in classnames]\n            prompts = torch.cat([clip.tokenize(p)\n                                 for p in prompts]).to(self.device)\n            text_features = clip_model.encode_text(prompts)\n            text_features = text_features / text_features.norm(dim=-1,\n                                                               keepdim=True)\n            mean_text_features = mean_text_features + text_features\n        mean_text_features = mean_text_features / num_temp\n        mean_text_features = mean_text_features / mean_text_features.norm(\n            dim=-1, keepdim=True)\n\n        self.text_features = mean_text_features\n        self.clip_model = clip_model\n"
  },
  {
    "path": "readme.md",
    "content": "# [ICCV23] Prompt-aligned Gradient for Prompt Tuning\n\nWe present Prompt-aligned Gradient, dubbed ProGrad, to prevent prompt tuning from forgetting the general knowledge of vision-language models (VLMs). In particular, ProGrad only updates the prompt with gradients that are aligned with (or do not conflict with) the “general direction”, which is represented by the gradient of the KL loss between the model's predictions and those of the pre-defined prompt. Extensive experiments demonstrate that ProGrad achieves stronger few-shot generalization than state-of-the-art prompt tuning methods.\n\n![image](ProGrad.public/Pipeline.png)\n\n[[paper link]](https://doi.org/10.48550/arxiv.2205.14865)\n\nThe code is organized into two folders:\n\n1. [Dassl.ProGrad.pytorch](Dassl.ProGrad.pytorch/) is a modified version of the [Dassl.pytorch](https://github.com/KaiyangZhou/Dassl.pytorch) toolbox.\n2. [ProGrad.public](ProGrad.public/) contains the ProGrad implementation. To reproduce the results in our paper, follow the [README.md](ProGrad.public/README.md) under [ProGrad.public/](ProGrad.public/) to set up the environment.\n\n## Citation\n\nIf you find our paper or this project helpful for your research, please consider citing our paper in your publication.\n\n```\n@inproceedings{https://doi.org/10.48550/arxiv.2205.14865,\n  author = {Zhu, Beier and Niu, Yulei and Han, Yucheng and Wu, Yue and Zhang, Hanwang},\n  title = {Prompt-aligned Gradient for Prompt Tuning},\n  booktitle = {International Conference on Computer Vision},\n  year = {2023},\n}\n```\n\n## Acknowledgement\n\nOur code is built on top of [CoOp](https://github.com/KaiyangZhou/CoOp) and [Dassl](https://github.com/KaiyangZhou/Dassl.pytorch).\n"
  }
]