Full Code of BeierZhu/Prompt-align for AI

Repository: BeierZhu/Prompt-align
Branch: main
Commit: 1b762036cbe9
Files: 244
Total size: 606.4 KB

Directory structure:
gitextract_cxcff6dn/

├── Dassl.ProGrad.pytorch/
│   ├── .flake8
│   ├── .gitignore
│   ├── .isort.cfg
│   ├── .style.yapf
│   ├── DATASETS.md
│   ├── LICENSE
│   ├── README.md
│   ├── configs/
│   │   ├── README.md
│   │   ├── datasets/
│   │   │   ├── da/
│   │   │   │   ├── cifar_stl.yaml
│   │   │   │   ├── digit5.yaml
│   │   │   │   ├── domainnet.yaml
│   │   │   │   ├── mini_domainnet.yaml
│   │   │   │   ├── office31.yaml
│   │   │   │   ├── office_home.yaml
│   │   │   │   └── visda17.yaml
│   │   │   ├── dg/
│   │   │   │   ├── cifar100_c.yaml
│   │   │   │   ├── cifar10_c.yaml
│   │   │   │   ├── digit_single.yaml
│   │   │   │   ├── digits_dg.yaml
│   │   │   │   ├── office_home_dg.yaml
│   │   │   │   ├── pacs.yaml
│   │   │   │   └── vlcs.yaml
│   │   │   └── ssl/
│   │   │       ├── cifar10.yaml
│   │   │       ├── cifar100.yaml
│   │   │       ├── stl10.yaml
│   │   │       └── svhn.yaml
│   │   └── trainers/
│   │       ├── da/
│   │       │   ├── dael/
│   │       │   │   ├── digit5.yaml
│   │       │   │   ├── domainnet.yaml
│   │       │   │   └── mini_domainnet.yaml
│   │       │   ├── m3sda/
│   │       │   │   ├── digit5.yaml
│   │       │   │   ├── domainnet.yaml
│   │       │   │   └── mini_domainnet.yaml
│   │       │   └── source_only/
│   │       │       ├── digit5.yaml
│   │       │       ├── mini_domainnet.yaml
│   │       │       ├── office31.yaml
│   │       │       └── visda17.yaml
│   │       ├── dg/
│   │       │   ├── dael/
│   │       │   │   ├── digits_dg.yaml
│   │       │   │   ├── office_home_dg.yaml
│   │       │   │   └── pacs.yaml
│   │       │   ├── ddaig/
│   │       │   │   ├── digits_dg.yaml
│   │       │   │   ├── office_home_dg.yaml
│   │       │   │   └── pacs.yaml
│   │       │   └── vanilla/
│   │       │       ├── digits_dg.yaml
│   │       │       ├── mini_domainnet.yaml
│   │       │       ├── office_home_dg.yaml
│   │       │       └── pacs.yaml
│   │       └── ssl/
│   │           └── fixmatch/
│   │               └── cifar10.yaml
│   ├── dassl/
│   │   ├── __init__.py
│   │   ├── config/
│   │   │   ├── __init__.py
│   │   │   └── defaults.py
│   │   ├── data/
│   │   │   ├── __init__.py
│   │   │   ├── data_manager.py
│   │   │   ├── datasets/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_dataset.py
│   │   │   │   ├── build.py
│   │   │   │   ├── da/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── cifarstl.py
│   │   │   │   │   ├── digit5.py
│   │   │   │   │   ├── domainnet.py
│   │   │   │   │   ├── mini_domainnet.py
│   │   │   │   │   ├── office31.py
│   │   │   │   │   ├── office_home.py
│   │   │   │   │   └── visda17.py
│   │   │   │   ├── dg/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── cifar_c.py
│   │   │   │   │   ├── digit_single.py
│   │   │   │   │   ├── digits_dg.py
│   │   │   │   │   ├── office_home_dg.py
│   │   │   │   │   ├── pacs.py
│   │   │   │   │   └── vlcs.py
│   │   │   │   └── ssl/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── cifar.py
│   │   │   │       ├── stl10.py
│   │   │   │       └── svhn.py
│   │   │   ├── samplers.py
│   │   │   └── transforms/
│   │   │       ├── __init__.py
│   │   │       ├── autoaugment.py
│   │   │       ├── randaugment.py
│   │   │       └── transforms.py
│   │   ├── engine/
│   │   │   ├── __init__.py
│   │   │   ├── build.py
│   │   │   ├── da/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adabn.py
│   │   │   │   ├── adda.py
│   │   │   │   ├── dael.py
│   │   │   │   ├── dann.py
│   │   │   │   ├── m3sda.py
│   │   │   │   ├── mcd.py
│   │   │   │   ├── mme.py
│   │   │   │   ├── self_ensembling.py
│   │   │   │   └── source_only.py
│   │   │   ├── dg/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── crossgrad.py
│   │   │   │   ├── daeldg.py
│   │   │   │   ├── ddaig.py
│   │   │   │   └── vanilla.py
│   │   │   ├── ssl/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── entmin.py
│   │   │   │   ├── fixmatch.py
│   │   │   │   ├── mean_teacher.py
│   │   │   │   ├── mixmatch.py
│   │   │   │   └── sup_baseline.py
│   │   │   └── trainer.py
│   │   ├── evaluation/
│   │   │   ├── __init__.py
│   │   │   ├── build.py
│   │   │   └── evaluator.py
│   │   ├── metrics/
│   │   │   ├── __init__.py
│   │   │   ├── accuracy.py
│   │   │   └── distance.py
│   │   ├── modeling/
│   │   │   ├── __init__.py
│   │   │   ├── backbone/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── alexnet.py
│   │   │   │   ├── backbone.py
│   │   │   │   ├── build.py
│   │   │   │   ├── cnn_digit5_m3sda.py
│   │   │   │   ├── cnn_digitsdg.py
│   │   │   │   ├── cnn_digitsingle.py
│   │   │   │   ├── efficientnet/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── model.py
│   │   │   │   │   └── utils.py
│   │   │   │   ├── mobilenetv2.py
│   │   │   │   ├── preact_resnet18.py
│   │   │   │   ├── resnet.py
│   │   │   │   ├── shufflenetv2.py
│   │   │   │   ├── vgg.py
│   │   │   │   └── wide_resnet.py
│   │   │   ├── head/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── build.py
│   │   │   │   └── mlp.py
│   │   │   ├── network/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── build.py
│   │   │   │   └── ddaig_fcn.py
│   │   │   └── ops/
│   │   │       ├── __init__.py
│   │   │       ├── cross_entropy.py
│   │   │       ├── dsbn.py
│   │   │       ├── efdmix.py
│   │   │       ├── mixstyle.py
│   │   │       ├── mixup.py
│   │   │       ├── mmd.py
│   │   │       ├── optimal_transport.py
│   │   │       ├── reverse_grad.py
│   │   │       ├── sequential2.py
│   │   │       ├── transnorm.py
│   │   │       └── utils.py
│   │   ├── optim/
│   │   │   ├── __init__.py
│   │   │   ├── lr_scheduler.py
│   │   │   ├── optimizer.py
│   │   │   └── radam.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── logger.py
│   │       ├── meters.py
│   │       ├── registry.py
│   │       ├── tools.py
│   │       └── torchtools.py
│   ├── datasets/
│   │   ├── da/
│   │   │   ├── cifar_stl.py
│   │   │   ├── digit5.py
│   │   │   └── visda17.sh
│   │   ├── dg/
│   │   │   └── cifar_c.py
│   │   └── ssl/
│   │       ├── cifar10_cifar100_svhn.py
│   │       └── stl10.py
│   ├── linter.sh
│   ├── requirements.txt
│   ├── setup.py
│   └── tools/
│       ├── parse_test_res.py
│       ├── replace_text.py
│       └── train.py
├── ProGrad.public/
│   ├── .gitignore
│   ├── DATASETS.md
│   ├── LICENSE
│   ├── README.md
│   ├── clip/
│   │   ├── __init__.py
│   │   ├── clip.py
│   │   ├── model.py
│   │   └── simple_tokenizer.py
│   ├── configs/
│   │   ├── datasets/
│   │   │   ├── caltech101.yaml
│   │   │   ├── dtd.yaml
│   │   │   ├── eurosat.yaml
│   │   │   ├── fgvc_aircraft.yaml
│   │   │   ├── food101.yaml
│   │   │   ├── imagenet.yaml
│   │   │   ├── imagenet_a.yaml
│   │   │   ├── imagenet_r.yaml
│   │   │   ├── imagenet_sketch.yaml
│   │   │   ├── imagenetv2.yaml
│   │   │   ├── oxford_flowers.yaml
│   │   │   ├── oxford_pets.yaml
│   │   │   ├── stanford_cars.yaml
│   │   │   ├── sun397.yaml
│   │   │   └── ucf101.yaml
│   │   └── trainers/
│   │       ├── CoCoOp/
│   │       │   ├── rn50_c4_ep10_batch1_ctxv1.yaml
│   │       │   ├── rn50_ep100_init.yaml
│   │       │   ├── rn50_ep50.yaml
│   │       │   ├── vit_b16_c16_ep10_batch1.yaml
│   │       │   ├── vit_b16_c4_ep10_batch1.yaml
│   │       │   ├── vit_b16_c4_ep10_batch1_ctxv1.yaml
│   │       │   └── vit_b16_c8_ep10_batch1.yaml
│   │       ├── CoOp/
│   │       │   ├── rn50.yaml
│   │       │   ├── rn50_ep100.yaml
│   │       │   ├── rn50_ep50.yaml
│   │       │   └── rn50_val.yaml
│   │       └── ProGrad/
│   │           ├── rn50.yaml
│   │           ├── rn50_ep100.yaml
│   │           └── rn50_ep50.yaml
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── caltech101.py
│   │   ├── dtd.py
│   │   ├── eurosat.py
│   │   ├── fgvc_aircraft.py
│   │   ├── food101.py
│   │   ├── imagenet.py
│   │   ├── imagenet_a.py
│   │   ├── imagenet_r.py
│   │   ├── imagenet_sketch.py
│   │   ├── imagenetv2.py
│   │   ├── oxford_flowers.py
│   │   ├── oxford_pets.py
│   │   ├── stanford_cars.py
│   │   ├── sun397.py
│   │   └── ucf101.py
│   ├── interpret_prompt.py
│   ├── lpclip/
│   │   ├── README.md
│   │   ├── feat_extractor.py
│   │   ├── feat_extractor.sh
│   │   ├── linear_probe.py
│   │   ├── linear_probe.sh
│   │   └── linear_probe_transfer.py
│   ├── parse_test_res.py
│   ├── requirements.txt
│   ├── scripts/
│   │   ├── base2new_test_main.sh
│   │   ├── base2new_test_prograd.sh
│   │   ├── base2new_train_main.sh
│   │   ├── base2new_train_prograd.sh
│   │   ├── eval.sh
│   │   ├── main.sh
│   │   ├── prograd.sh
│   │   └── zeroshot.sh
│   ├── train.py
│   └── trainers/
│       ├── __init__.py
│       ├── cocoop.py
│       ├── coop.py
│       ├── imagenet_templates.py
│       ├── prograd.py
│       └── zsclip.py
└── readme.md

================================================
FILE CONTENTS
================================================

================================================
FILE: Dassl.ProGrad.pytorch/.flake8
================================================
[flake8]
ignore =
    # At least two spaces before inline comment
    E261,
    # Line lengths are recommended to be no greater than 79 characters
    E501,
    # Missing whitespace around arithmetic operator 
    E226,
    # Blank line contains whitespace
    W293,
    # Do not use bare 'except'
    E722,
    # Line break after binary operator
    W504,
    # Too many leading '#' for block comment
    E266,
    # line break before binary operator
    W503,
    # continuation line over-indented for hanging indent
    E126
max-line-length = 79
exclude = __init__.py, build

================================================
FILE: Dassl.ProGrad.pytorch/.gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# OS X
.DS_Store
.Spotlight-V100
.Trashes
._*

# This project
output/
debug.sh
debug.py


================================================
FILE: Dassl.ProGrad.pytorch/.isort.cfg
================================================
[isort]
line_length=79
multi_line_output=6
length_sort=true
known_standard_library=numpy,setuptools
known_myself=dassl
known_third_party=matplotlib,cv2,torch,torchvision,PIL,yacs,scipy,gdown
no_lines_before=STDLIB,THIRDPARTY
sections=FUTURE,STDLIB,THIRDPARTY,myself,FIRSTPARTY,LOCALFOLDER
default_section=FIRSTPARTY

================================================
FILE: Dassl.ProGrad.pytorch/.style.yapf
================================================
[style]
BASED_ON_STYLE = pep8
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
DEDENT_CLOSING_BRACKETS = true
SPACES_BEFORE_COMMENT = 2
ARITHMETIC_PRECEDENCE_INDICATION = true

================================================
FILE: Dassl.ProGrad.pytorch/DATASETS.md
================================================
# How to Install Datasets

`$DATA` denotes the location where datasets are installed, e.g.

```
$DATA/
|–– office31/
|–– office_home/
|–– visda17/
```

[Domain Adaptation](#domain-adaptation)
- [Office-31](#office-31)
- [Office-Home](#office-home)
- [VisDA17](#visda17)
- [CIFAR10-STL10](#cifar10-stl10)
- [Digit-5](#digit-5)
- [DomainNet](#domainnet)
- [miniDomainNet](#minidomainnet)

[Domain Generalization](#domain-generalization)
- [PACS](#pacs)
- [VLCS](#vlcs)
- [Office-Home-DG](#office-home-dg)
- [Digits-DG](#digits-dg)
- [Digit-Single](#digit-single)
- [CIFAR-10-C](#cifar-10-c)
- [CIFAR-100-C](#cifar-100-c)

[Semi-Supervised Learning](#semi-supervised-learning)
- [CIFAR10/100 and SVHN](#cifar10100-and-svhn)
- [STL10](#stl10)

## Domain Adaptation

### Office-31

Download link: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/#datasets_code.

File structure:

```
office31/
|–– amazon/
|   |–– back_pack/
|   |–– bike/
|   |–– ...
|–– dslr/
|   |–– back_pack/
|   |–– bike/
|   |–– ...
|–– webcam/
|   |–– back_pack/
|   |–– bike/
|   |–– ...
```

Note that within each domain folder you need to move all class folders out of the `images/` folder and then delete the `images/` folder.
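
The folder rearrangement described in the note above can be scripted. The following is a minimal sketch, not part of the repository: the helper names `flatten_images_dir` and `flatten_office31` are hypothetical, and it assumes each domain folder contains a single `images/` folder holding the class folders.

```python
import shutil
from pathlib import Path


def flatten_images_dir(domain_dir):
    """Move every entry out of `<domain_dir>/images/`, then delete `images/`.

    Hypothetical helper (not part of Dassl). Returns the moved entry names.
    """
    domain_dir = Path(domain_dir)
    images_dir = domain_dir / "images"
    if not images_dir.is_dir():
        return []  # nothing to do, e.g. the domain was already flattened

    moved = []
    for class_dir in sorted(images_dir.iterdir()):
        target = domain_dir / class_dir.name
        if target.exists():
            # Refuse to overwrite an existing class folder.
            raise FileExistsError(f"{target} already exists")
        shutil.move(str(class_dir), str(target))
        moved.append(class_dir.name)

    images_dir.rmdir()  # only succeeds because the folder is now empty
    return moved


def flatten_office31(root):
    """Apply the flattening to the three Office-31 domain folders."""
    root = Path(root)
    return {d: flatten_images_dir(root / d) for d in ("amazon", "dslr", "webcam")}
```

Calling `flatten_office31` on the `office31/` folder under `$DATA` should leave the layout shown in the file structure above. Running it a second time is a no-op.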

### Office-Home

Download link: http://hemanthdv.org/OfficeHome-Dataset/.

File structure:

```
office_home/
|–– art/
|–– clipart/
|–– product/
|–– real_world/
```

### VisDA17

Download link: http://ai.bu.edu/visda-2017/.

The dataset can also be downloaded using our script at `datasets/da/visda17.sh`. Run the following command in your terminal under `Dassl.pytorch/datasets/da`:

```bash
sh visda17.sh $DATA
```

Once the download is finished, the file structure will look like

```
visda17/
|–– train/
|–– test/
|–– validation/
```

### CIFAR10-STL10

Run the following command in your terminal under `Dassl.pytorch/datasets/da`:

```bash
python cifar_stl.py $DATA/cifar_stl
```

This will create a folder named `cifar_stl` under `$DATA`. The file structure will look like

```
cifar_stl/
|–– cifar/
|   |–– train/
|   |–– test/
|–– stl/
|   |–– train/
|   |–– test/
```

Note that only 9 classes shared by both datasets are kept.
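
To see where the 9 shared classes come from, here is a small sketch. It is an illustration only and does not reproduce `cifar_stl.py`. The class lists are the standard CIFAR-10 and STL-10 class names; the alias table and the helper names `shared_classes` and `build_label_map` are assumptions made for this example.

```python
# Class names as published with each dataset.
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]
STL10_CLASSES = ["airplane", "bird", "car", "cat", "deer",
                 "dog", "horse", "monkey", "ship", "truck"]

# Assumed alias table: the two datasets name the same class differently.
ALIASES = {"automobile": "car"}


def shared_classes(classes_a, classes_b, aliases=ALIASES):
    """Return the sorted class names present in both datasets."""
    a = {aliases.get(c, c) for c in classes_a}
    b = {aliases.get(c, c) for c in classes_b}
    return sorted(a & b)


def build_label_map(classes, kept, aliases=ALIASES):
    """Map original label index -> new label index; dropped classes are absent."""
    new_index = {name: i for i, name in enumerate(kept)}
    return {
        old: new_index[aliases.get(name, name)]
        for old, name in enumerate(classes)
        if aliases.get(name, name) in new_index
    }


kept = shared_classes(CIFAR10_CLASSES, STL10_CLASSES)
cifar_map = build_label_map(CIFAR10_CLASSES, kept)
stl_map = build_label_map(STL10_CLASSES, kept)
```

With these lists, `frog` (CIFAR-10 only) and `monkey` (STL-10 only) have no counterpart and are dropped, which leaves 9 classes with a common label space.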

### Digit-5

Create a folder `$DATA/digit5` and download to this folder the dataset from [here](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download). This should give you

```
digit5/
|–– Digit-Five/
```

Then, run the following command in your terminal under `Dassl.pytorch/datasets/da`:

```bash
python digit5.py $DATA/digit5
```

This will extract the data and organize the file structure as

```
digit5/
|–– Digit-Five/
|–– mnist/
|–– mnist_m/
|–– usps/
|–– svhn/
|–– syn/
```

### DomainNet

Download link: http://ai.bu.edu/M3SDA/. (Please download the cleaned version of split files)

File structure:

```
domainnet/
|–– clipart/
|–– infograph/
|–– painting/
|–– quickdraw/
|–– real/
|–– sketch/
|–– splits/
|   |–– clipart_train.txt
|   |–– clipart_test.txt
|   |–– ...
```

### miniDomainNet

You need to download the DomainNet dataset first. The miniDomainNet's split files can be downloaded at this [google drive](https://drive.google.com/open?id=15rrLDCrzyi6ZY-1vJar3u7plgLe4COL7). After the zip file is extracted, you should have the folder `$DATA/domainnet/splits_mini/`.

## Domain Generalization

### PACS

Download link: [google drive](https://drive.google.com/open?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE).

File structure:

```
pacs/
|–– images/
|–– splits/
```

You do not necessarily have to manually download this dataset. Once you run ``tools/train.py``, the code will detect if the dataset exists or not and automatically download the dataset to ``$DATA`` if missing. This also applies to VLCS, Office-Home-DG, and Digits-DG.

### VLCS

Download link: [google drive](https://drive.google.com/file/d/1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd/view?usp=sharing) (credit to https://github.com/fmcarlucci/JigenDG#vlcs)

File structure:

```
VLCS/
|–– CALTECH/
|–– LABELME/
|–– PASCAL/
|–– SUN/
```

### Office-Home-DG

Download link: [google drive](https://drive.google.com/open?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa).

File structure:

```
office_home_dg/
|–– art/
|–– clipart/
|–– product/
|–– real_world/
```

### Digits-DG

Download link: [google drive](https://drive.google.com/open?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7).

File structure:

```
digits_dg/
|–– mnist/
|–– mnist_m/
|–– svhn/
|–– syn/
```

### Digit-Single
Follow the steps for [Digit-5](#digit-5) to organize the dataset.

### CIFAR-10-C

First download the CIFAR-10-C dataset from https://zenodo.org/record/2535967#.YFxHEWQzb0o to, e.g., `$DATA`, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal:
```bash
python cifar_c.py $DATA/CIFAR-10-C
```
where the first argument denotes the path to the (uncompressed) CIFAR-10-C dataset.

The script will extract images from the `.npy` files and save them to `cifar10_c/` created under `$DATA`. The file structure will look like
```
cifar10_c/
|–– brightness/
|   |–– 1/ # 5 intensity levels in total
|   |–– 2/
|   |–– 3/
|   |–– 4/
|   |–– 5/
|–– ... # 19 corruption types in total
```

Note that `cifar10_c/` only contains the test images. The training images are the normal CIFAR-10 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-10 dataset.
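
The per-level folders in the tree above can be pictured as consecutive blocks of rows in each corruption's `.npy` array. The sketch below is an assumption-laden illustration, not the logic of `cifar_c.py`: it assumes every `.npy` file stacks the 5 intensity levels back to back in equal-sized blocks (e.g. 50,000 rows, 10,000 per level), and the helper names `level_slices` and `level_dirs` are hypothetical.

```python
import os


def level_slices(num_images, num_levels=5):
    """Return {level: (start, stop)} for equal, consecutive blocks of rows.

    Assumes the rows of one corruption file are ordered by intensity level.
    """
    if num_images % num_levels != 0:
        raise ValueError("images do not split evenly across levels")
    per_level = num_images // num_levels
    return {
        level: ((level - 1) * per_level, level * per_level)
        for level in range(1, num_levels + 1)
    }


def level_dirs(root, corruption, num_levels=5):
    """Directories `<root>/<corruption>/<level>/`, matching the tree above."""
    return [
        os.path.join(root, corruption, str(level))
        for level in range(1, num_levels + 1)
    ]


slices = level_slices(50000)
dirs = level_dirs("cifar10_c", "brightness")
```

Under that assumption, the rows in `slices[3]` would be the images written to `cifar10_c/brightness/3/`.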

### CIFAR-100-C

First download the CIFAR-100-C dataset from https://zenodo.org/record/3555552#.YFxpQmQzb0o to, e.g., `$DATA`, and extract the file under the same directory. Then, navigate to `Dassl.pytorch/datasets/dg` and run the following command in your terminal:
```bash
python cifar_c.py $DATA/CIFAR-100-C
```
where the first argument denotes the path to the (uncompressed) CIFAR-100-C dataset.

The script will extract images from the `.npy` files and save them to `cifar100_c/` created under `$DATA`. The file structure will look like
```
cifar100_c/
|–– brightness/
|   |–– 1/ # 5 intensity levels in total
|   |–– 2/
|   |–– 3/
|   |–– 4/
|   |–– 5/
|–– ... # 19 corruption types in total
```

Note that `cifar100_c/` only contains the test images. The training images are the normal CIFAR-100 images. See [CIFAR10/100 and SVHN](#cifar10100-and-svhn) for how to prepare the CIFAR-100 dataset.

## Semi-Supervised Learning

### CIFAR10/100 and SVHN

Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`:

```bash
python cifar10_cifar100_svhn.py $DATA
```

This will create three folders under `$DATA`, i.e.

```
cifar10/
|–– train/
|–– test/
cifar100/
|–– train/
|–– test/
svhn/
|–– train/
|–– test/
```

### STL10

Run the following command in your terminal under `Dassl.pytorch/datasets/ssl`:

```bash
python stl10.py $DATA/stl10
```

This will create a folder named `stl10` under `$DATA` and extract the data into three folders: `train`, `test`, and `unlabeled`. Then download the "Binary files" archive from http://ai.stanford.edu/~acoates/stl10/ and extract it under `stl10`.

The file structure will look like

```
stl10/
|–– train/
|–– test/
|–– unlabeled/
|–– stl10_binary/
```

================================================
FILE: Dassl.ProGrad.pytorch/LICENSE
================================================
MIT License

Copyright (c) 2020 Kaiyang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: Dassl.ProGrad.pytorch/README.md
================================================
# Dassl

## Introduction

Dassl is a [PyTorch](https://pytorch.org) toolbox initially developed for our project [Domain Adaptive Ensemble Learning (DAEL)](https://arxiv.org/abs/2003.07325) to support research in domain adaptation and generalization---since in DAEL we study how to unify these two problems in a single learning framework. Given that domain adaptation is closely related to semi-supervised learning---both study how to exploit unlabeled data---we also incorporate components that support research for the latter.

Why the name "Dassl"? Dassl combines the initials of domain adaptation (DA) and semi-supervised learning (SSL), which sounds natural and informative.

Dassl has a modular design and unified interfaces, allowing fast prototyping of and experimentation with new DA/DG/SSL methods. With Dassl, a new method can be implemented with only a few lines of code. Don't believe it? Take a look at the [engine](https://github.com/KaiyangZhou/Dassl.pytorch/tree/master/dassl/engine) folder, which contains the implementations of many existing methods (then you will come back and star this repo). :-)

Basically, Dassl is perfect for doing research in the following areas:
- Domain adaptation
- Domain generalization
- Semi-supervised learning

BUT, thanks to the neat design, Dassl can also be used as a codebase to develop any deep learning projects, like [this](https://github.com/KaiyangZhou/CoOp). :-)

A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU training (Dassl uses `DataParallel` to wrap a model, which is less efficient than `DistributedDataParallel`).

We don't provide detailed documentation for Dassl, unlike another [project](https://kaiyangzhou.github.io/deep-person-reid/) of ours. This is because Dassl is developed for research purposes and, as researchers, we think it's important to be able to read source code, and we highly encourage you to do so (definitely not because we are lazy). :-)

## What's new
- Mar 2022: A new domain generalization method [EFDM](https://arxiv.org/abs/2203.07740) developed by [Yabin Zhang (PolyU)](https://ybzh.github.io/) and to appear at CVPR'22 is added to this repo. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/pull/36) for more details.
- Feb 2022: In case you don't know, a class in the painting domain of DomainNet (the official splits) only has test images (no training images), which could affect performance. See section 4.a in our [paper](https://arxiv.org/abs/2003.07325) for more details.
- Oct 2021: `v0.5.0`: **Important changes** made to `transforms.py`. 1) `center_crop` becomes a default transform in testing (applied after resizing the smaller edge to a certain size to keep the image aspect ratio). 2) For training, `Resize(cfg.INPUT.SIZE)` is deactivated when `random_crop` or `random_resized_crop` is used. These changes won't make any difference to the training transforms used in existing config files, nor to the testing transforms unless the raw images are not squared (the only difference is that now the image aspect ratio is respected).
- Oct 2021: `v0.4.3`: Copy the attributes in `self.dm` (data manager) to `SimpleTrainer` and make `self.dm` optional, which means from now on, you can build data loaders from any source you like rather than being forced to use `DataManager`.
- Sep 2021: `v0.4.2`: An important update is to set `drop_last=is_train and len(data_source)>=batch_size` when constructing a data loader, so that a dataset smaller than the batch size does not produce a loader of length 0.
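
To see why the `drop_last` rule in the `v0.4.2` entry matters, the arithmetic can be sketched in plain Python. This is an illustration, not Dassl's code: the helper names are hypothetical, and the batch count follows the usual convention that `drop_last=True` discards an incomplete final batch.

```python
def use_drop_last(is_train, num_samples, batch_size):
    """Mirror of the rule `drop_last = is_train and len(data_source) >= batch_size`."""
    return is_train and num_samples >= batch_size


def num_batches(num_samples, batch_size, drop_last):
    """Number of mini-batches yielded per epoch."""
    if drop_last:
        return num_samples // batch_size   # incomplete final batch is discarded
    return -(-num_samples // batch_size)   # ceiling division keeps the remainder


# 20 training samples with batch size 32:
naive = num_batches(20, 32, drop_last=True)                      # always dropping
guarded = num_batches(20, 32, use_drop_last(True, 20, 32))       # rule from v0.4.2
```

Always dropping the last batch would leave the loader with no batches at all when the dataset is smaller than the batch size, whereas the guarded rule keeps the single partial batch.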

<details>
    <summary>More</summary>

- Aug 2021: `v0.4.0`: The most noteworthy update is adding the learning rate warmup scheduler. The implementation is detailed [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/optim/lr_scheduler.py#L10) and the config variables are specified [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L171).
- Jul 2021: `v0.3.4`: Adds a new function `generate_fewshot_dataset()` to the base dataset class, which allows for the generation of a few-shot learning setting. One can customize a few-shot dataset by specifying `_C.DATASET.NUM_SHOTS` and give it to `generate_fewshot_dataset()`.
- Jul 2021: `v0.3.2`: Adds `_C.INPUT.INTERPOLATION` (default: `bilinear`). Available interpolation modes are `bilinear`, `nearest`, and `bicubic`.
- Jul 2021 `v0.3.1`: Now you can use `*.register(force=True)` to replace previously registered modules.
- Jul 2021 `v0.3.0`: Allows deploying the model with the best validation performance for the final test (for the purpose of model selection). Specifically, a new config variable named `_C.TEST.FINAL_MODEL` is introduced, which takes either `"last_step"` (default) or `"best_val"`. When set to `"best_val"`, the model will be evaluated on the `val` set after each epoch and the one with the best validation performance will be saved and used for the final test (see this [code](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/engine/trainer.py#L412)).
- Jul 2021 `v0.2.7`: Adds attribute `classnames` to the base dataset class. Now you can get a list of class names ordered by numeric labels by calling `trainer.dm.dataset.classnames`.
- Jun 2021 `v0.2.6`: Merges `MixStyle2` to `MixStyle`. A new variable `self.mix` is used to switch between random mixing and cross-domain mixing. Please see [this](https://github.com/KaiyangZhou/Dassl.pytorch/issues/23) for more details on the new features.
- Jun 2021 `v0.2.5`: Fixes a [bug](https://github.com/KaiyangZhou/Dassl.pytorch/commit/29881c7faee7405f80f5f674de4bbbf80d5dc77a) in the calculation of per-class recognition accuracy.
- Jun 2021 `v0.2.4`: Adds `extend_cfg(cfg)` to `train.py`. This function is particularly useful when you build your own methods on top of Dassl.pytorch and need to define some custom variables. Please see the repository [mixstyle-release](https://github.com/KaiyangZhou/mixstyle-release) or [ssdg-benchmark](https://github.com/KaiyangZhou/ssdg-benchmark) for examples.
- Jun 2021 New benchmarks for semi-supervised domain generalization at https://github.com/KaiyangZhou/ssdg-benchmark.
- Apr 2021 Do you know you can use `tools/parse_test_res.py` to read the log files and automatically calculate and print out the results including mean and standard deviation? Check the instructions in `tools/parse_test_res.py` for more details.
- Apr 2021 `v0.2.3`: A [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) layer can now be deactivated or activated by using `model.apply(deactivate_mixstyle)` or `model.apply(activate_mixstyle)` without modifying the source code. See [dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py) for the details.
- Apr 2021 `v0.2.2`: Adds `RandomClassSampler`, which samples from a certain number of classes a certain number of images to form a minibatch (the code is modified from [Torchreid](https://github.com/KaiyangZhou/deep-person-reid)).
- Apr 2021 `v0.2.1`: Slightly adjusts the ordering in `setup_cfg()` (see `tools/train.py`).
- Apr 2021 `v0.2.0`: Adds `_C.DATASET.ALL_AS_UNLABELED` (for the SSL setting) to the config variable list. When this variable is set to `True`, all labeled data will be included in the unlabeled data set.
- Apr 2021 `v0.1.9`: Adds [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf) to the benchmark datasets (see `dassl/data/datasets/dg/vlcs.py`).
- Mar 2021 `v0.1.8`: Allows `optim` and `sched` to be `None` in `register_model()`.
- Mar 2021 `v0.1.7`: Adds [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) models to [dassl/modeling/backbone/resnet.py](dassl/modeling/backbone/resnet.py). The training configs in `configs/trainers/dg/vanilla` can be used to train MixStyle models.
- Mar 2021 `v0.1.6`: Adds [CIFAR-10/100-C](https://arxiv.org/abs/1807.01697) to the benchmark datasets for evaluating a model's robustness to image corruptions.
- Mar 2021 We have just released a survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes the ten-year development in this topic with coverage on the history, related problems, datasets, methodologies, potential directions, and so on.
- Jan 2021 Our recent work, [MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp) (mixing instance-level feature statistics of samples of different domains for improving domain generalization), is accepted to ICLR'21. The code is available at https://github.com/KaiyangZhou/mixstyle-release where the cross-domain image classification part is based on Dassl.pytorch.
- May 2020 `v0.1.3`: Adds the `Digit-Single` dataset for benchmarking single-source DG methods. The corresponding CNN model is [dassl/modeling/backbone/cnn_digitsingle.py](dassl/modeling/backbone/cnn_digitsingle.py) and the dataset config file is [configs/datasets/dg/digit_single.yaml](configs/datasets/dg/digit_single.yaml). See [Volpi et al. NIPS'18](https://arxiv.org/abs/1805.12018) for how to do evaluation.
- May 2020 `v0.1.2`: 1) Adds [EfficientNet](https://arxiv.org/abs/1905.11946) models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). To use EfficientNet, set `MODEL.BACKBONE.NAME` to `efficientnet_b{N}` where `N={0, ..., 7}`. 2) `dassl/modeling/models` is renamed to `dassl/modeling/network` (`build_model()` to `build_network()` and `MODEL_REGISTRY` to `NETWORK_REGISTRY`).
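
As a rough picture of the few-shot setting mentioned in the `v0.3.4` entry above, the sketch below samples a fixed number of items per class. It is not the implementation of `generate_fewshot_dataset()`: the function name `sample_fewshot`, the `(path, label)` item format, and the handling of small classes are all assumptions made for illustration.

```python
import random
from collections import defaultdict


def sample_fewshot(items, num_shots, seed=0):
    """Pick at most `num_shots` items per class from (path, label) pairs.

    Hypothetical helper; a non-positive `num_shots` returns the full set.
    """
    if num_shots < 1:
        return list(items)

    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    by_label = defaultdict(list)
    for item in items:
        by_label[item[1]].append(item)

    subset = []
    for label in sorted(by_label):
        group = by_label[label]
        if len(group) > num_shots:
            group = rng.sample(group, num_shots)
        subset.extend(group)  # classes with too few items are kept whole
    return subset


items = [(f"img_{i}.jpg", i % 3) for i in range(30)]
subset = sample_fewshot(items, num_shots=4)
```

Here `num_shots` plays the role that `_C.DATASET.NUM_SHOTS` plays in the config: with 3 classes and 4 shots, the subset holds 12 items.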

</details>

## Overview

Dassl has implemented the following methods:

- Single-source domain adaptation
    - [Semi-supervised Domain Adaptation via Minimax Entropy (ICCV'19)](https://arxiv.org/abs/1904.06487) [[dassl/engine/da/mme.py](dassl/engine/da/mme.py)]
    - [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR'18)](https://arxiv.org/abs/1712.02560) [[dassl/engine/da/mcd.py](dassl/engine/da/mcd.py)]
    - [Self-ensembling for visual domain adaptation (ICLR'18)](https://arxiv.org/abs/1706.05208) [[dassl/engine/da/self_ensembling.py](dassl/engine/da/self_ensembling.py)]
    - [Revisiting Batch Normalization For Practical Domain Adaptation (ICLR-W'17)](https://arxiv.org/abs/1603.04779) [[dassl/engine/da/adabn.py](dassl/engine/da/adabn.py)]
    - [Adversarial Discriminative Domain Adaptation (CVPR'17)](https://arxiv.org/abs/1702.05464) [[dassl/engine/da/adda.py](dassl/engine/da/adda.py)]
    - [Domain-Adversarial Training of Neural Networks (JMLR'16)](https://arxiv.org/abs/1505.07818) [[dassl/engine/da/dann.py](dassl/engine/da/dann.py)]

- Multi-source domain adaptation
    - [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325) [[dassl/engine/da/dael.py](dassl/engine/da/dael.py)]
    - [Moment Matching for Multi-Source Domain Adaptation (ICCV'19)](https://arxiv.org/abs/1812.01754) [[dassl/engine/da/m3sda.py](dassl/engine/da/m3sda.py)]

- Domain generalization
    - [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization (CVPR'22)](https://arxiv.org/abs/2203.07740) [[dassl/modeling/ops/efdmix.py](dassl/modeling/ops/efdmix.py)]
    - [Domain Generalization with MixStyle (ICLR'21)](https://openreview.net/forum?id=6xHJ37MVxxp) [[dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py)]
    - [Deep Domain-Adversarial Image Generation for Domain Generalisation (AAAI'20)](https://arxiv.org/abs/2003.06054) [[dassl/engine/dg/ddaig.py](dassl/engine/dg/ddaig.py)]
    - [Generalizing Across Domains via Cross-Gradient Training (ICLR'18)](https://arxiv.org/abs/1804.10745) [[dassl/engine/dg/crossgrad.py](dassl/engine/dg/crossgrad.py)]

- Semi-supervised learning
    - [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence](https://arxiv.org/abs/2001.07685) [[dassl/engine/ssl/fixmatch.py](dassl/engine/ssl/fixmatch.py)]
    - [MixMatch: A Holistic Approach to Semi-Supervised Learning (NeurIPS'19)](https://arxiv.org/abs/1905.02249) [[dassl/engine/ssl/mixmatch.py](dassl/engine/ssl/mixmatch.py)]
    - [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (NeurIPS'17)](https://arxiv.org/abs/1703.01780) [[dassl/engine/ssl/mean_teacher.py](dassl/engine/ssl/mean_teacher.py)]
    - [Semi-supervised Learning by Entropy Minimization (NeurIPS'04)](http://papers.nips.cc/paper/2740-semi-supervised-learning-by-entropy-minimization.pdf) [[dassl/engine/ssl/entmin.py](dassl/engine/ssl/entmin.py)]

*Feel free to make a [PR](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) to add your methods here to make it easier for others to benchmark!*

Dassl supports the following datasets:

- Domain adaptation
    - [Office-31](https://scalable.mpi-inf.mpg.de/files/2013/04/saenko_eccv_2010.pdf)
    - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)
    - [VisDA17](http://ai.bu.edu/visda-2017/)
    - [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html)-[STL10](https://cs.stanford.edu/~acoates/stl10/)
    - [Digit-5](https://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA/code_MSDA_digit#digit-five-download)
    - [DomainNet](http://ai.bu.edu/M3SDA/)
    - [miniDomainNet](https://arxiv.org/abs/2003.07325)

- Domain generalization
    - [PACS](https://arxiv.org/abs/1710.03077)
    - [VLCS](https://people.csail.mit.edu/torralba/publications/datasets_cvpr11.pdf)
    - [Office-Home](http://hemanthdv.org/OfficeHome-Dataset/)
    - [Digits-DG](https://arxiv.org/abs/2003.06054)
    - [Digit-Single](https://arxiv.org/abs/1805.12018)
    - [CIFAR-10-C](https://arxiv.org/abs/1807.01697)
    - [CIFAR-100-C](https://arxiv.org/abs/1807.01697)

- Semi-supervised learning
    - [CIFAR10/100](https://www.cs.toronto.edu/~kriz/cifar.html)
    - [SVHN](http://ufldl.stanford.edu/housenumbers/)
    - [STL10](https://cs.stanford.edu/~acoates/stl10/)

## Get started

### Installation

Make sure [conda](https://www.anaconda.com/distribution/) is installed properly.

```bash
# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/

# Create a conda environment
conda create -n dassl python=3.7

# Activate the environment
conda activate dassl

# Install dependencies
pip install -r requirements.txt

# Install torch (version >= 1.7.1) and torchvision
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Install this library (no need to re-build if the source code is modified)
python setup.py develop
```

Follow the instructions in [DATASETS.md](./DATASETS.md) to preprocess the datasets.

### Training

The main interface is implemented in `tools/train.py`, which does the following:

1. initialize the config with `cfg = setup_cfg(args)` where `args` contains the command-line input (see `tools/train.py` for the list of input arguments);
2. instantiate a `trainer` with `build_trainer(cfg)` which loads the dataset and builds a deep neural network model;
3. call `trainer.train()` for training and evaluating the model.

Below we provide an example for training a source-only baseline on the popular domain adaptation dataset, Office-31:

```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31
```

`$DATA` denotes the location where datasets are installed. `--dataset-config-file` loads the common setting for the dataset (Office-31 in this case) such as image size and model architecture. `--config-file` loads the algorithm-specific setting such as hyper-parameters and optimization parameters.
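
The layering of the two config files can be pictured as successive merges on top of the defaults. The sketch below is a plain-dictionary illustration, not Dassl's actual code (Dassl keeps its config in a yacs `CfgNode`). The helper `merge_config` is hypothetical, the merge order (dataset file first, then trainer file) is an assumption, and the values are copied from `dassl/config/defaults.py`, `configs/datasets/da/office31.yaml` and `configs/trainers/da/source_only/office31.yaml`.

```python
import copy


def merge_config(base, override):
    """Return a copy of `base` with `override` merged in recursively (hypothetical helper)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# A few defaults, as listed in dassl/config/defaults.py
defaults = {
    "INPUT": {"SIZE": (224, 224), "TRANSFORMS": ()},
    "MODEL": {"BACKBONE": {"NAME": "", "PRETRAINED": True}},
    "OPTIM": {"NAME": "adam", "LR": 0.0003, "MAX_EPOCH": 10},
}

# Dataset-level settings (configs/datasets/da/office31.yaml)
dataset_cfg = {
    "INPUT": {"TRANSFORMS": ["random_flip", "random_translation", "normalize"]},
    "MODEL": {"BACKBONE": {"NAME": "resnet50"}},
}

# Algorithm-level settings (configs/trainers/da/source_only/office31.yaml)
trainer_cfg = {"OPTIM": {"NAME": "sgd", "LR": 0.002, "MAX_EPOCH": 20}}

cfg = merge_config(merge_config(defaults, dataset_cfg), trainer_cfg)
print(cfg["MODEL"]["BACKBONE"])
print(cfg["OPTIM"])
```

Keys that a later file does not mention, such as `MODEL.BACKBONE.PRETRAINED` here, keep the value set by an earlier layer.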

To use multiple sources, namely the multi-source domain adaptation task, one just needs to add more sources to `--source-domains`. For instance, to train a source-only baseline on miniDomainNet, one can run:

```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
```

After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.

To print out the results saved in the log files (so you do not need to go through every log file and calculate the mean/std yourself), you can use `tools/parse_test_res.py`. Usage instructions are given in the script itself.
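
The aggregation itself is simple. As a rough sketch (not the actual `tools/parse_test_res.py`), the snippet below pulls an accuracy value out of each log's text and reports the mean and sample standard deviation over runs. The log line format `* accuracy: 80.0%`, the regular expression and the function name `summarize` are assumptions made for illustration; the real script may parse different fields or use a different estimator.

```python
import re
import statistics

# Assumed log format: a line such as "* accuracy: 80.0%"
ACC_PATTERN = re.compile(r"accuracy:\s*([0-9.]+)%")


def summarize(log_texts):
    """Return (mean, std) of the last accuracy reported in each log."""
    accs = []
    for text in log_texts:
        matches = ACC_PATTERN.findall(text)
        if matches:
            accs.append(float(matches[-1]))
    if not accs:
        raise ValueError("no accuracy values found")
    # Sample standard deviation; a single run has no spread
    std = statistics.stdev(accs) if len(accs) > 1 else 0.0
    return statistics.mean(accs), std


# Made-up log contents for three runs (e.g. three seeds)
logs = [
    "epoch 20\n* accuracy: 80.0%\n",
    "epoch 20\n* accuracy: 82.0%\n",
    "epoch 20\n* accuracy: 84.0%\n",
]
mean, std = summarize(logs)
print(f"accuracy: {mean:.2f}% +- {std:.2f}%")
```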

For other trainers such as `MCD`, you can set `--trainer MCD` while keeping the config file unchanged, i.e. using the same training parameters as `SourceOnly` (in the simplest case). To modify the hyper-parameters in MCD, like `N_STEP_F` (number of steps to update the feature extractor), you can append `TRAINER.MCD.N_STEP_F 4` to the existing input arguments (otherwise the default value will be used). Alternatively, you can create a new `.yaml` config file to store your custom setting. See [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/config/defaults.py#L176) for a complete list of algorithm-specific hyper-parameters.
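
Trailing `KEY VALUE` pairs of this kind are addressed by their dotted path. The following is a minimal sketch of that idea with plain dictionaries. It is not Dassl's implementation (yacs offers `merge_from_list` for this purpose), the helper `apply_overrides` is hypothetical, and the starting value of `N_STEP_F` is a placeholder rather than a documented default.

```python
import ast


def apply_overrides(cfg, opts):
    """Apply ["A.B.C", "value", ...] pairs to a nested dict in place (hypothetical helper)."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for dotted_key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # KeyError if the path is unknown
        if leaf not in node:
            raise KeyError(dotted_key)
        try:
            value = ast.literal_eval(raw)  # "4" -> 4, "0.01" -> 0.01
        except (ValueError, SyntaxError):
            value = raw  # keep plain strings such as "sgd"
        node[leaf] = value
    return cfg


# N_STEP_F = 1 is a placeholder value, not Dassl's documented default
cfg = {"TRAINER": {"MCD": {"N_STEP_F": 1}}, "OPTIM": {"NAME": "adam"}}
apply_overrides(cfg, ["TRAINER.MCD.N_STEP_F", "4", "OPTIM.NAME", "sgd"])
print(cfg)
```

Rejecting unknown keys, as this sketch does, catches misspelled option names early instead of silently creating new entries.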

### Test
Model testing can be done by using `--eval-only`, which asks the code to run `trainer.test()`. You also need to provide the trained model and specify which model file (i.e. saved at which epoch) to use. For example, to use `model.pth.tar-20` saved at `output/source_only_office31/model`, you can run:

```bash
CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31_test \
--eval-only \
--model-dir output/source_only_office31 \
--load-epoch 20
```

Note that `--model-dir` takes as input the directory path which was specified in `--output-dir` in the training stage.
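
Put differently, the two flags together identify one checkpoint file. Below is a small sketch of how the path in this example is composed. The `model` sub-folder and the file name pattern are taken from the example above, while the helper `checkpoint_path` is hypothetical and not part of Dassl.

```python
import os.path as osp


def checkpoint_path(model_dir, epoch, name="model"):
    """Compose <model_dir>/<name>/model.pth.tar-<epoch> (hypothetical helper)."""
    return osp.join(model_dir, name, f"model.pth.tar-{epoch}")


# --model-dir output/source_only_office31 --load-epoch 20
path = checkpoint_path("output/source_only_office31", 20)
print(path)
```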

### Write a new trainer
A good practice is to go through `dassl/engine/trainer.py` to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass `TrainerXU`. For domain generalization, the new class can subclass `TrainerX`. The two classes differ mainly in whether a data loader for unlabeled data is used. With the base classes, a new trainer may only need to implement the `forward_backward()` method, which performs loss computation and model update. See `dassl/engine/da/source_only.py` for an example.
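
The division of labour between a base class and a new trainer follows the template-method pattern. The toy below mimics it without PyTorch or Dassl: `ToyTrainer` and `LeastSquaresTrainer` are hypothetical stand-ins (not `TrainerX` or `TrainerXU`), the generic loop lives in the base class, and the subclass only supplies `forward_backward()`.

```python
class ToyTrainer:
    """Stand-in base class: owns the generic training loop."""

    def __init__(self, batches, max_epoch):
        self.batches = batches
        self.max_epoch = max_epoch

    def train(self):
        loss_summary = None
        for epoch in range(self.max_epoch):
            for batch in self.batches:
                loss_summary = self.forward_backward(batch)
        return loss_summary

    def forward_backward(self, batch):
        raise NotImplementedError


class LeastSquaresTrainer(ToyTrainer):
    """Fits y = w * x with plain gradient descent."""

    def __init__(self, batches, max_epoch, lr=0.1):
        super().__init__(batches, max_epoch)
        self.w = 0.0
        self.lr = lr

    def forward_backward(self, batch):
        xs, ys = batch
        # Forward pass: mean squared error of the current prediction
        errors = [self.w * x - y for x, y in zip(xs, ys)]
        loss = sum(e * e for e in errors) / len(errors)
        # Backward pass and update: d(loss)/dw = 2 * mean(error * x)
        grad = 2 * sum(e * x for e, x in zip(errors, xs)) / len(errors)
        self.w -= self.lr * grad
        return {"loss": loss}


# Two mini-batches drawn from y = 2x
batches = [([1.0, 2.0], [2.0, 4.0]), ([0.5, 1.5], [1.0, 3.0])]
trainer = LeastSquaresTrainer(batches, max_epoch=50)
summary = trainer.train()
print(round(trainer.w, 3), summary)
```

Because the loop is inherited, swapping in a different method only means writing a different `forward_backward()`.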

### Add a new backbone/head/network
`backbone` corresponds to a convolutional neural network model which performs feature extraction. `head` (an optional module) is mounted on top of `backbone` for further processing and can be, for example, an MLP. `backbone` and `head` are basic building blocks for constructing a `SimpleNet()` (see `dassl/engine/trainer.py`) which serves as the primary model for a task. `network` contains custom neural network models, such as an image generator.

To add a new module, namely a backbone/head/network, you need to first register the module using the corresponding `registry`, i.e. `BACKBONE_REGISTRY` for `backbone`, `HEAD_REGISTRY` for `head` and `NETWORK_REGISTRY` for `network`. Note that for a new `backbone`, we require the model to subclass `Backbone` as defined in `dassl/modeling/backbone/backbone.py` and specify the `self._out_features` attribute.

We provide an example below for how to add a new `backbone`.
```python
from dassl.modeling import Backbone, BACKBONE_REGISTRY

class MyBackbone(Backbone):

    def __init__(self):
        super().__init__()
        # Create layers
        self.conv = ...

        self._out_features = 2048

    def forward(self, x):
        # Extract and return features (placeholder body; replace with real layers)
        ...

@BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return MyBackbone()
```
Then, you can set `MODEL.BACKBONE.NAME` to `my_backbone` to use your own architecture. For more details, please refer to the source code in `dassl/modeling`.
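
How registration ties a config string to a constructor can be seen in a stripped-down registry. This is only a sketch of the pattern: `ToyRegistry` is a hypothetical class written for this example, not Dassl's own registry implementation, and the lookup method name `get` is an assumption.

```python
class ToyRegistry:
    """Maps a name to a callable; a simplified stand-in for a registry."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        # Used as a decorator: @REGISTRY.register()
        def wrapper(fn_or_class):
            key = fn_or_class.__name__
            if key in self._obj_map:
                raise KeyError(f"{key} is already registered in {self._name}")
            self._obj_map[key] = fn_or_class
            return fn_or_class

        return wrapper

    def get(self, key):
        if key not in self._obj_map:
            raise KeyError(f"{key} not found in {self._name}")
        return self._obj_map[key]


TOY_BACKBONE_REGISTRY = ToyRegistry("BACKBONE")


@TOY_BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    # A real builder would return a model; a dict keeps the sketch dependency-free
    return {"arch": "my_backbone", "kwargs": kwargs}


# A config value such as MODEL.BACKBONE.NAME = "my_backbone" selects the builder
builder = TOY_BACKBONE_REGISTRY.get("my_backbone")
model = builder(pretrained=False)
print(model)
```

The decorator stores the function under its own name, which is why the config string has to match the function name exactly.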

### Add a dataset
An example code structure is shown below. Make sure you subclass `DatasetBase` and register the dataset with `@DATASET_REGISTRY.register()`. All you need is to load `train_x`, `train_u` (optional), `val` (optional) and `test`, among which `train_u` and `val` could be `None` or simply ignored. Each of these variables contains a list of `Datum` objects. A `Datum` object (implemented [here](https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/data/datasets/base_dataset.py#L12)) contains information for a single image, like `impath` (string) and `label` (int).

```python
from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase

@DATASET_REGISTRY.register()
class NewDataset(DatasetBase):

    dataset_dir = ''

    def __init__(self, cfg):
        
        train_x = ...
        train_u = ...  # optional, can be None
        val = ...  # optional, can be None
        test = ...

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)
```
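
To make the data layout concrete, the sketch below builds such lists without depending on Dassl. `ToyDatum` is a hypothetical stand-in that carries only the two fields named above (`impath` and `label`); the file names, the class list and the helper `make_split` are invented for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyDatum:
    """Stand-in for a Datum: one image path and its integer label."""
    impath: str
    label: int


def make_split(files_by_class, classnames):
    """Turn {classname: [paths]} into a flat list of ToyDatum objects."""
    label_of = {name: idx for idx, name in enumerate(classnames)}
    items = []
    for name in classnames:
        for impath in files_by_class.get(name, []):
            items.append(ToyDatum(impath=impath, label=label_of[name]))
    return items


classnames = ["bike", "mug"]
train_x = make_split(
    {"bike": ["bike/001.jpg", "bike/002.jpg"], "mug": ["mug/001.jpg"]},
    classnames,
)
test = make_split({"mug": ["mug/101.jpg"]}, classnames)
train_u, val = None, None  # optional splits may be left out

print(len(train_x), train_x[-1])
```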

We suggest you take a look at the datasets code in some projects like [this](https://github.com/KaiyangZhou/CoOp), which is built on top of Dassl.

## Relevant Research

We would like to share here our research relevant to Dassl.

- [Domain Adaptive Ensemble Learning](https://arxiv.org/abs/2003.07325), TIP, 2021.
- [MixStyle Neural Networks for Domain Generalization and Adaptation](https://arxiv.org/abs/2107.02053), arXiv preprint, 2021.
- [Semi-Supervised Domain Generalization with Stochastic StyleMatch](https://arxiv.org/abs/2106.00592), arXiv preprint, 2021.
- [Domain Generalization in Vision: A Survey](https://arxiv.org/abs/2103.02503), arXiv preprint, 2021.
- [Domain Generalization with MixStyle](https://openreview.net/forum?id=6xHJ37MVxxp), in ICLR 2021.
- [Learning to Generate Novel Domains for Domain Generalization](https://arxiv.org/abs/2007.03304), in ECCV 2020.
- [Deep Domain-Adversarial Image Generation for Domain Generalisation](https://arxiv.org/abs/2003.06054), in AAAI 2020.

## Citation

If you find this code useful for your research, please give credit to the following paper:

```
@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={IEEE Transactions on Image Processing (TIP)},
  year={2021}
}
```


================================================
FILE: Dassl.ProGrad.pytorch/configs/README.md
================================================
The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, data augmentation, network architecture) used by most papers. The `trainers/` folder contains method-specific config files which define optimization algorithms (e.g., optimizer, epoch) and hyperparameter settings.


================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml
================================================
INPUT:
  SIZE: (32, 32)
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFARSTL"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml
================================================
INPUT:
  SIZE: (32, 32)
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  TRANSFORMS: ["normalize"]

DATASET:
  NAME: "Digit5"

MODEL:
  BACKBONE:
    NAME: "cnn_digit5_m3sda"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "DomainNet"

MODEL:
  BACKBONE:
    NAME: "resnet101"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/mini_domainnet.yaml
================================================
INPUT:
  SIZE: (96, 96)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "miniDomainNet"

MODEL:
  BACKBONE:
    NAME: "resnet18"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "Office31"

MODEL:
  BACKBONE:
    NAME: "resnet50"
  HEAD:
    NAME: "mlp"
    HIDDEN_LAYERS: [256]
    DROPOUT: 0.

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/office_home.yaml
================================================
INPUT:
  SIZE: (224, 224)

DATASET:
  NAME: "OfficeHome"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "center_crop", "normalize"]

DATASET:
  NAME: "VisDA17"

MODEL:
  BACKBONE:
    NAME: "resnet101"

TEST:
  PER_CLASS_RESULT: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFAR100C"
  CIFAR_C_TYPE: "fog"
  CIFAR_C_LEVEL: 5

MODEL:
  BACKBONE:
    NAME: "wide_resnet_16_4"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFAR10C"
  CIFAR_C_TYPE: "fog"
  CIFAR_C_LEVEL: 5

MODEL:
  BACKBONE:
    NAME: "wide_resnet_16_4"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "DigitSingle"

MODEL:
  BACKBONE:
    NAME: "cnn_digitsingle"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "DigitsDG"

MODEL:
  BACKBONE:
    NAME: "cnn_digitsdg"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/office_home_dg.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "OfficeHomeDG"

MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "PACS"

MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml
================================================
INPUT:
  SIZE: (224, 224)
  TRANSFORMS: ["random_flip", "random_translation", "normalize"]

DATASET:
  NAME: "VLCS"

MODEL:
  BACKBONE:
    NAME: "resnet18"
    PRETRAINED: True

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]

DATASET:
  NAME: "CIFAR10"
  NUM_LABELED: 4000
  VAL_PERCENT: 0.

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4

DATASET:
  NAME: "CIFAR100"
  NUM_LABELED: 10000
  VAL_PERCENT: 0.

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml
================================================
INPUT:
  SIZE: (96, 96)
  TRANSFORMS: ["random_flip", "random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4

DATASET:
  NAME: "STL10"
  STL10_FOLD: 0

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml
================================================
INPUT:
  SIZE: (32, 32)
  TRANSFORMS: ["random_crop", "normalize"]
  PIXEL_MEAN: [0.5, 0.5, 0.5]
  PIXEL_STD: [0.5, 0.5, 0.5]
  CROP_PADDING: 4

DATASET:
  NAME: "SVHN"
  NUM_LABELED: 1000
  VAL_PERCENT: 0.

MODEL:
  BACKBONE:
    NAME: "wide_resnet_28_2"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 256
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 256

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 4
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 6
  TEST:
    BATCH_SIZE: 30

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/dael/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 192
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 200

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 256
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 256

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 4
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 6
  TEST:
    BATCH_SIZE: 30

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 192
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 200

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/digit5.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 256
  TEST:
    BATCH_SIZE: 256

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [30]
  MAX_EPOCH: 30
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/office31.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 32
  TEST:
    BATCH_SIZE: 32

OPTIM:
  NAME: "sgd"
  LR: 0.002
  STEPSIZE: [20]
  MAX_EPOCH: 20

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/da/source_only/visda17.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 32
  TEST:
    BATCH_SIZE: 32

OPTIM:
  NAME: "sgd"
  LR: 0.0001
  STEPSIZE: [2]
  MAX_EPOCH: 2

TRAIN:
  PRINT_FREQ: 50
  COUNT_ITER: "train_u"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/digits_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 120
  TEST:
    BATCH_SIZE: 100

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/office_home_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TEST:
    BATCH_SIZE: 100

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml
================================================
DATALOADER:
  TRAIN_X:
    SAMPLER: "RandomDomainSampler"
    BATCH_SIZE: 30
  TEST:
    BATCH_SIZE: 100

OPTIM:
  NAME: "sgd"
  LR: 0.002
  MAX_EPOCH: 40
  LR_SCHEDULER: "cosine"

TRAINER:
  DAEL:
    STRONG_TRANSFORMS: ["random_flip", "cutout", "randaugment2", "normalize"]

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/digits_dg.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]

DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50

TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x32_gctx"
    LMDA: 0.3

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/office_home_dg.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]

DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 16
  TEST:
    BATCH_SIZE: 16

OPTIM:
  NAME: "sgd"
  LR: 0.0005
  STEPSIZE: [20]
  MAX_EPOCH: 25

TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x64_gctx"
    WARMUP: 3
    LMDA: 0.3

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml
================================================
INPUT:
  PIXEL_MEAN: [0., 0., 0.]
  PIXEL_STD: [1., 1., 1.]

DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 16
  TEST:
    BATCH_SIZE: 16

OPTIM:
  NAME: "sgd"
  LR: 0.0005
  STEPSIZE: [20]
  MAX_EPOCH: 25

TRAINER:
  DDAIG:
    G_ARCH: "fcn_3x64_gctx"
    WARMUP: 3
    LMDA: 0.3

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/digits_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [20]
  MAX_EPOCH: 50

TRAIN:
  PRINT_FREQ: 20

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/mini_domainnet.yaml
================================================
DATALOADER:
  NUM_WORKERS: 8
  TRAIN_X:
    BATCH_SIZE: 128
  TEST:
    BATCH_SIZE: 128

OPTIM:
  NAME: "sgd"
  LR: 0.005
  MAX_EPOCH: 60
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/office_home_dg.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8

OPTIM:
  NAME: "sgd"
  LR: 0.001
  MAX_EPOCH: 50
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TEST:
    BATCH_SIZE: 100
  NUM_WORKERS: 8

OPTIM:
  NAME: "sgd"
  LR: 0.001
  MAX_EPOCH: 50
  LR_SCHEDULER: "cosine"

================================================
FILE: Dassl.ProGrad.pytorch/configs/trainers/ssl/fixmatch/cifar10.yaml
================================================
DATALOADER:
  TRAIN_X:
    BATCH_SIZE: 64
  TRAIN_U:
    SAME_AS_X: False
    BATCH_SIZE: 448
  TEST:
    BATCH_SIZE: 500

OPTIM:
  NAME: "sgd"
  LR: 0.05
  STEPSIZE: [4000]
  MAX_EPOCH: 4000
  LR_SCHEDULER: "cosine"

TRAIN:
  COUNT_ITER: "train_u"
  PRINT_FREQ: 10

TRAINER:
  FIXMATCH:
    STRONG_TRANSFORMS: ["random_flip", "randaugment_fixmatch", "normalize", "cutout"]

================================================
FILE: Dassl.ProGrad.pytorch/dassl/__init__.py
================================================
"""
Dassl
------
PyTorch toolbox for domain adaptation and semi-supervised learning.

URL: https://github.com/KaiyangZhou/Dassl.pytorch

@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={arXiv preprint arXiv:2003.07325},
  year={2020}
}
"""

__version__ = "0.5.0"
__author__ = "Kaiyang Zhou"
__homepage__ = "https://kaiyangzhou.github.io/"


================================================
FILE: Dassl.ProGrad.pytorch/dassl/config/__init__.py
================================================
from .defaults import _C as cfg_default


def get_cfg_default():
    return cfg_default.clone()


================================================
FILE: Dassl.ProGrad.pytorch/dassl/config/defaults.py
================================================
from yacs.config import CfgNode as CN

###########################
# Config definition
###########################

_C = CN()

_C.VERSION = 1

# Directory to save the output files (like log.txt and model weights)
_C.OUTPUT_DIR = "./output"
# Path to a directory where the files were saved previously
_C.RESUME = ""
# Set seed to negative value to randomize everything
# Set seed to positive value to use a fixed seed
_C.SEED = -1
_C.USE_CUDA = True
# Print detailed information
# E.g. trainer, dataset, and backbone
_C.VERBOSE = True

###########################
# Input
###########################
_C.INPUT = CN()
_C.INPUT.SIZE = (224, 224)
# Mode of interpolation in resize functions
_C.INPUT.INTERPOLATION = "bilinear"
# For available choices please refer to transforms.py
_C.INPUT.TRANSFORMS = ()
# If True, tfm_train and tfm_test will be None
_C.INPUT.NO_TRANSFORM = False
# Default mean and std come from ImageNet
_C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406]
_C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225]
# Padding for random crop
_C.INPUT.CROP_PADDING = 4
# Cutout
_C.INPUT.CUTOUT_N = 1
_C.INPUT.CUTOUT_LEN = 16
# Gaussian noise
_C.INPUT.GN_MEAN = 0.0
_C.INPUT.GN_STD = 0.15
# RandAugment
_C.INPUT.RANDAUGMENT_N = 2
_C.INPUT.RANDAUGMENT_M = 10
# ColorJitter (brightness, contrast, saturation, hue)
_C.INPUT.COLORJITTER_B = 0.4
_C.INPUT.COLORJITTER_C = 0.4
_C.INPUT.COLORJITTER_S = 0.4
_C.INPUT.COLORJITTER_H = 0.1
# Random gray scale's probability
_C.INPUT.RGS_P = 0.2
# Gaussian blur
_C.INPUT.GB_P = 0.5  # probability of applying this operation
_C.INPUT.GB_K = 21  # kernel size (should be an odd number)

###########################
# Dataset
###########################
_C.DATASET = CN()
# Directory where datasets are stored
_C.DATASET.ROOT = ""
_C.DATASET.NAME = ""
# List of names of source domains
_C.DATASET.SOURCE_DOMAINS = ()
# List of names of target domains
_C.DATASET.TARGET_DOMAINS = ()
# Number of labeled instances in total
# Useful for the semi-supervised learning
_C.DATASET.NUM_LABELED = -1
# Number of images per class
_C.DATASET.NUM_SHOTS = -1
# Percentage of validation data (only used for SSL datasets)
# Set to 0 if you do not want to use val data
# Using val data for hyperparameter tuning was done in Oliver et al. 2018
_C.DATASET.VAL_PERCENT = 0.1
# Fold index for STL-10 dataset (normal range is 0 - 9)
# Negative number means None
_C.DATASET.STL10_FOLD = -1
# CIFAR-10/100-C's corruption type and intensity level
_C.DATASET.CIFAR_C_TYPE = ""
_C.DATASET.CIFAR_C_LEVEL = 1
# Use all data in the unlabeled data set (e.g. FixMatch)
_C.DATASET.ALL_AS_UNLABELED = False

###########################
# Dataloader
###########################
_C.DATALOADER = CN()
_C.DATALOADER.NUM_WORKERS = 4
# Apply transformations to an image K times (during training)
_C.DATALOADER.K_TRANSFORMS = 1
# img0 denotes image tensor without augmentation
# Useful for consistency learning
_C.DATALOADER.RETURN_IMG0 = False
# Setting for the train_x data-loader
_C.DATALOADER.TRAIN_X = CN()
_C.DATALOADER.TRAIN_X.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_X.BATCH_SIZE = 32
# Parameter for RandomDomainSampler
# 0 or -1 means sampling from all domains
_C.DATALOADER.TRAIN_X.N_DOMAIN = 0
# Parameter of RandomClassSampler
# Number of instances per class
_C.DATALOADER.TRAIN_X.N_INS = 16

# Setting for the train_u data-loader
_C.DATALOADER.TRAIN_U = CN()
# Set to false if you want to have unique
# data loader params for train_u
_C.DATALOADER.TRAIN_U.SAME_AS_X = True
_C.DATALOADER.TRAIN_U.SAMPLER = "RandomSampler"
_C.DATALOADER.TRAIN_U.BATCH_SIZE = 32
_C.DATALOADER.TRAIN_U.N_DOMAIN = 0
_C.DATALOADER.TRAIN_U.N_INS = 16

# Setting for the test data-loader
_C.DATALOADER.TEST = CN()
_C.DATALOADER.TEST.SAMPLER = "SequentialSampler"
_C.DATALOADER.TEST.BATCH_SIZE = 32

###########################
# Model
###########################
_C.MODEL = CN()
# Path to model weights (for initialization)
_C.MODEL.INIT_WEIGHTS = ""
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = ""
_C.MODEL.BACKBONE.PRETRAINED = True
# Definition of embedding layers
_C.MODEL.HEAD = CN()
# If empty, no embedding layers are constructed and the
# backbone's output is passed directly to the classifier
_C.MODEL.HEAD.NAME = ""
# Structure of hidden layers (a list), e.g. [512, 512]
# If undefined, no embedding layer will be constructed
_C.MODEL.HEAD.HIDDEN_LAYERS = ()
_C.MODEL.HEAD.ACTIVATION = "relu"
_C.MODEL.HEAD.BN = True
_C.MODEL.HEAD.DROPOUT = 0.0

###########################
# Optimization
###########################
_C.OPTIM = CN()
_C.OPTIM.NAME = "adam"
_C.OPTIM.LR = 0.0003
_C.OPTIM.WEIGHT_DECAY = 5e-4
_C.OPTIM.MOMENTUM = 0.9
_C.OPTIM.SGD_DAMPNING = 0
_C.OPTIM.SGD_NESTEROV = False
_C.OPTIM.RMSPROP_ALPHA = 0.99
_C.OPTIM.ADAM_BETA1 = 0.9
_C.OPTIM.ADAM_BETA2 = 0.999
# STAGED_LR allows different layers to have
# different lr, e.g. pre-trained base layers
# can be assigned a smaller lr than the new
# classification layer
_C.OPTIM.STAGED_LR = False
_C.OPTIM.NEW_LAYERS = ()
_C.OPTIM.BASE_LR_MULT = 0.1
# Learning rate scheduler
_C.OPTIM.LR_SCHEDULER = "single_step"
# -1 or 0 means the stepsize is equal to max_epoch
_C.OPTIM.STEPSIZE = (-1, )
_C.OPTIM.GAMMA = 0.1
_C.OPTIM.MAX_EPOCH = 10
# Set WARMUP_EPOCH larger than 0 to activate warmup training
_C.OPTIM.WARMUP_EPOCH = -1
# Either linear or constant
_C.OPTIM.WARMUP_TYPE = "linear"
# Constant learning rate when type=constant
_C.OPTIM.WARMUP_CONS_LR = 1e-5
# Minimum learning rate when type=linear
_C.OPTIM.WARMUP_MIN_LR = 1e-5
# Recount epoch for the next scheduler (last_epoch=-1)
# Otherwise last_epoch=warmup_epoch
_C.OPTIM.WARMUP_RECOUNT = True

###########################
# Train
###########################
_C.TRAIN = CN()
# How often (epoch) to save model during training
# Set to 0 or negative value to only save the last one
_C.TRAIN.CHECKPOINT_FREQ = 0
# How often (batch) to print training information
_C.TRAIN.PRINT_FREQ = 10
# Use 'train_x', 'train_u' or 'smaller_one' to count
# the number of iterations in an epoch (for DA and SSL)
_C.TRAIN.COUNT_ITER = "train_x"

###########################
# Test
###########################
_C.TEST = CN()
_C.TEST.EVALUATOR = "Classification"
_C.TEST.PER_CLASS_RESULT = False
# Compute confusion matrix, which will be saved
# to $OUTPUT_DIR/cmat.pt
_C.TEST.COMPUTE_CMAT = False
# If NO_TEST=True, no testing will be conducted
_C.TEST.NO_TEST = False
# Use test or val set for FINAL evaluation
_C.TEST.SPLIT = "test"
# Which model to test after training
# Either last_step or best_val
_C.TEST.FINAL_MODEL = "last_step"

###########################
# Trainer specifics
###########################
_C.TRAINER = CN()
_C.TRAINER.NAME = ""

# MCD
_C.TRAINER.MCD = CN()
_C.TRAINER.MCD.N_STEP_F = 4  # number of steps to train F
# MME
_C.TRAINER.MME = CN()
_C.TRAINER.MME.LMDA = 0.1  # weight for the entropy loss
# SelfEnsembling
_C.TRAINER.SE = CN()
_C.TRAINER.SE.EMA_ALPHA = 0.999
_C.TRAINER.SE.CONF_THRE = 0.95
_C.TRAINER.SE.RAMPUP = 300

# M3SDA
_C.TRAINER.M3SDA = CN()
_C.TRAINER.M3SDA.LMDA = 0.5  # weight for the moment distance loss
_C.TRAINER.M3SDA.N_STEP_F = 4  # follow MCD
# DAEL
_C.TRAINER.DAEL = CN()
_C.TRAINER.DAEL.WEIGHT_U = 0.5  # weight on the unlabeled loss
_C.TRAINER.DAEL.CONF_THRE = 0.95  # confidence threshold
_C.TRAINER.DAEL.STRONG_TRANSFORMS = ()

# CrossGrad
_C.TRAINER.CG = CN()
_C.TRAINER.CG.EPS_F = 1.0  # scaling parameter for D's gradients
_C.TRAINER.CG.EPS_D = 1.0  # scaling parameter for F's gradients
_C.TRAINER.CG.ALPHA_F = 0.5  # balancing weight for the label net's loss
_C.TRAINER.CG.ALPHA_D = 0.5  # balancing weight for the domain net's loss
# DDAIG
_C.TRAINER.DDAIG = CN()
_C.TRAINER.DDAIG.G_ARCH = ""  # generator's architecture
_C.TRAINER.DDAIG.LMDA = 0.3  # perturbation weight
_C.TRAINER.DDAIG.CLAMP = False  # clamp perturbation values
_C.TRAINER.DDAIG.CLAMP_MIN = -1.0
_C.TRAINER.DDAIG.CLAMP_MAX = 1.0
_C.TRAINER.DDAIG.WARMUP = 0
_C.TRAINER.DDAIG.ALPHA = 0.5  # balancing weight for the losses

# EntMin
_C.TRAINER.ENTMIN = CN()
_C.TRAINER.ENTMIN.LMDA = 1e-3  # weight on the entropy loss
# Mean Teacher
_C.TRAINER.MEANTEA = CN()
_C.TRAINER.MEANTEA.WEIGHT_U = 1.0  # weight on the unlabeled loss
_C.TRAINER.MEANTEA.EMA_ALPHA = 0.999
_C.TRAINER.MEANTEA.RAMPUP = 5  # epochs used to ramp up the loss_u weight
# MixMatch
_C.TRAINER.MIXMATCH = CN()
_C.TRAINER.MIXMATCH.WEIGHT_U = 100.0  # weight on the unlabeled loss
_C.TRAINER.MIXMATCH.TEMP = 2.0  # temperature for sharpening the probability
_C.TRAINER.MIXMATCH.MIXUP_BETA = 0.75
_C.TRAINER.MIXMATCH.RAMPUP = 20000  # steps used to ramp up the loss_u weight
# FixMatch
_C.TRAINER.FIXMATCH = CN()
_C.TRAINER.FIXMATCH.WEIGHT_U = 1.0  # weight on the unlabeled loss
_C.TRAINER.FIXMATCH.CONF_THRE = 0.95  # confidence threshold
_C.TRAINER.FIXMATCH.STRONG_TRANSFORMS = ()
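
# Illustrative sketch (not part of the original config file): how
# OPTIM.STEPSIZE, OPTIM.GAMMA and OPTIM.MAX_EPOCH might interact under the
# "single_step" scheduler, assuming it behaves like a standard step decay
# (the learning rate is multiplied by GAMMA every STEPSIZE epochs). The
# helper name single_step_lr is hypothetical; the actual scheduler is
# built elsewhere in the repository.

```python
def single_step_lr(base_lr, epoch, stepsize, gamma, max_epoch):
    """Return the learning rate at `epoch` under a simple step decay."""
    # Mirror the config comment: -1 or 0 means stepsize == max_epoch
    if stepsize <= 0:
        stepsize = max_epoch
    return base_lr * gamma ** (epoch // stepsize)


# With the defaults (LR=0.0003, STEPSIZE=-1, GAMMA=0.1, MAX_EPOCH=10)
# the rate is never decayed during the 10 training epochs.
default_lrs = [single_step_lr(0.0003, e, -1, 0.1, 10) for e in range(10)]
```

# Under this assumption, a decay only takes effect when STEPSIZE is set
# to a positive value smaller than MAX_EPOCH.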


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/__init__.py
================================================
from .data_manager import DataManager, DatasetWrapper


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/data_manager.py
================================================
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset as TorchDataset

from dassl.utils import read_image

from .datasets import build_dataset
from .samplers import build_sampler
from .transforms import build_transform

INTERPOLATION_MODES = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "nearest": Image.NEAREST,
}


def build_data_loader(
    cfg,
    sampler_type="SequentialSampler",
    data_source=None,
    batch_size=64,
    n_domain=0,
    n_ins=2,
    tfm=None,
    is_train=True,
    dataset_wrapper=None,
):
    # Build sampler
    sampler = build_sampler(
        sampler_type,
        cfg=cfg,
        data_source=data_source,
        batch_size=batch_size,
        n_domain=n_domain,
        n_ins=n_ins,
    )

    if dataset_wrapper is None:
        dataset_wrapper = DatasetWrapper

    # Build data loader
    data_loader = torch.utils.data.DataLoader(
        dataset_wrapper(cfg, data_source, transform=tfm, is_train=is_train),
        batch_size=batch_size,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        drop_last=is_train and len(data_source) >= batch_size,
        pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA),
    )
    assert len(data_loader) > 0

    return data_loader


class DataManager:

    def __init__(
        self,
        cfg,
        custom_tfm_train=None,
        custom_tfm_test=None,
        dataset_wrapper=None
    ):
        # Load dataset
        dataset = build_dataset(cfg)
        # Build transform
        if custom_tfm_train is None:
            tfm_train = build_transform(cfg, is_train=True)
        else:
            print("* Using custom transform for training")
            tfm_train = custom_tfm_train

        if custom_tfm_test is None:
            tfm_test = build_transform(cfg, is_train=False)
        else:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test

        # Build train_loader_x
        train_loader_x = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TRAIN_X.SAMPLER,
            data_source=dataset.train_x,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            n_domain=cfg.DATALOADER.TRAIN_X.N_DOMAIN,
            n_ins=cfg.DATALOADER.TRAIN_X.N_INS,
            tfm=tfm_train,
            is_train=True,
            dataset_wrapper=dataset_wrapper,
        )

        # Build train_loader_u
        train_loader_u = None
        if dataset.train_u:
            sampler_type_ = cfg.DATALOADER.TRAIN_U.SAMPLER
            batch_size_ = cfg.DATALOADER.TRAIN_U.BATCH_SIZE
            n_domain_ = cfg.DATALOADER.TRAIN_U.N_DOMAIN
            n_ins_ = cfg.DATALOADER.TRAIN_U.N_INS

            if cfg.DATALOADER.TRAIN_U.SAME_AS_X:
                sampler_type_ = cfg.DATALOADER.TRAIN_X.SAMPLER
                batch_size_ = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
                n_domain_ = cfg.DATALOADER.TRAIN_X.N_DOMAIN
                n_ins_ = cfg.DATALOADER.TRAIN_X.N_INS

            train_loader_u = build_data_loader(
                cfg,
                sampler_type=sampler_type_,
                data_source=dataset.train_u,
                batch_size=batch_size_,
                n_domain=n_domain_,
                n_ins=n_ins_,
                tfm=tfm_train,
                is_train=True,
                dataset_wrapper=dataset_wrapper,
            )

        # Build val_loader
        val_loader = None
        if dataset.val:
            val_loader = build_data_loader(
                cfg,
                sampler_type=cfg.DATALOADER.TEST.SAMPLER,
                data_source=dataset.val,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False,
                dataset_wrapper=dataset_wrapper,
            )

        # Build test_loader
        test_loader = build_data_loader(
            cfg,
            sampler_type=cfg.DATALOADER.TEST.SAMPLER,
            data_source=dataset.test,
            batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
            tfm=tfm_test,
            is_train=False,
            dataset_wrapper=dataset_wrapper,
        )

        # Attributes
        self._num_classes = dataset.num_classes
        self._num_source_domains = len(cfg.DATASET.SOURCE_DOMAINS)
        self._lab2cname = dataset.lab2cname

        # Dataset and data-loaders
        self.dataset = dataset
        self.train_loader_x = train_loader_x
        self.train_loader_u = train_loader_u
        self.val_loader = val_loader
        self.test_loader = test_loader

        if cfg.VERBOSE:
            self.show_dataset_summary(cfg)

    @property
    def num_classes(self):
        return self._num_classes

    @property
    def num_source_domains(self):
        return self._num_source_domains

    @property
    def lab2cname(self):
        return self._lab2cname

    def show_dataset_summary(self, cfg):
        print("***** Dataset statistics *****")

        print("  Dataset: {}".format(cfg.DATASET.NAME))

        if cfg.DATASET.SOURCE_DOMAINS:
            print("  Source domains: {}".format(cfg.DATASET.SOURCE_DOMAINS))
        if cfg.DATASET.TARGET_DOMAINS:
            print("  Target domains: {}".format(cfg.DATASET.TARGET_DOMAINS))

        print("  # classes: {:,}".format(self.num_classes))

        print("  # train_x: {:,}".format(len(self.dataset.train_x)))

        if self.dataset.train_u:
            print("  # train_u: {:,}".format(len(self.dataset.train_u)))

        if self.dataset.val:
            print("  # val: {:,}".format(len(self.dataset.val)))

        print("  # test: {:,}".format(len(self.dataset.test)))


class DatasetWrapper(TorchDataset):

    def __init__(self, cfg, data_source, transform=None, is_train=False):
        self.cfg = cfg
        self.data_source = data_source
        self.transform = transform  # accept list (tuple) as input
        self.is_train = is_train
        # Augmenting an image K>1 times is only allowed during training
        self.k_tfm = cfg.DATALOADER.K_TRANSFORMS if is_train else 1
        self.return_img0 = cfg.DATALOADER.RETURN_IMG0

        if self.k_tfm > 1 and transform is None:
            raise ValueError(
                "Cannot augment the image {} times "
                "because transform is None".format(self.k_tfm)
            )

        # Build transform that doesn't apply any data augmentation
        interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]
        to_tensor = []
        to_tensor += [T.Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]
        to_tensor += [T.ToTensor()]
        if "normalize" in cfg.INPUT.TRANSFORMS:
            normalize = T.Normalize(
                mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
            )
            to_tensor += [normalize]
        self.to_tensor = T.Compose(to_tensor)

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, idx):
        item = self.data_source[idx]

        output = {
            "label": item.label,
            "domain": item.domain,
            "impath": item.impath
        }

        img0 = read_image(item.impath)

        if self.transform is not None:
            if isinstance(self.transform, (list, tuple)):
                for i, tfm in enumerate(self.transform):
                    img = self._transform_image(tfm, img0)
                    keyname = "img"
                    if (i + 1) > 1:
                        keyname += str(i + 1)
                    output[keyname] = img
            else:
                img = self._transform_image(self.transform, img0)
                output["img"] = img

        if self.return_img0:
            output["img0"] = self.to_tensor(img0)

        return output

    def _transform_image(self, tfm, img0):
        img_list = []

        for _ in range(self.k_tfm):
            img_list.append(tfm(img0))

        img = img_list
        if len(img) == 1:
            img = img[0]

        return img
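
# Illustrative sketch (not part of the original file): the key naming used
# by DatasetWrapper.__getitem__, reproduced without torch or PIL. The
# function name apply_transforms and the integer "image" are stand-ins
# chosen for this example only.

```python
def apply_transforms(img0, transforms, k_tfm=1):
    """Mimic how DatasetWrapper names its outputs: img, img2, img3, ..."""

    def run(tfm):
        # Apply the same transform k_tfm times; unwrap a single result
        outs = [tfm(img0) for _ in range(k_tfm)]
        return outs[0] if len(outs) == 1 else outs

    output = {}
    if isinstance(transforms, (list, tuple)):
        for i, tfm in enumerate(transforms):
            key = "img" if i == 0 else "img" + str(i + 1)
            output[key] = run(tfm)
    else:
        output["img"] = run(transforms)
    return output


two_views = apply_transforms(3, [lambda x: x + 1, lambda x: x * 2])
k_views = apply_transforms(3, lambda x: x + 1, k_tfm=2)
```

# A list of transforms yields one key per transform, while K_TRANSFORMS > 1
# turns each value into a list of K augmented copies.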


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py
================================================
from .build import DATASET_REGISTRY, build_dataset  # isort:skip
from .base_dataset import Datum, DatasetBase  # isort:skip

from .da import *
from .dg import *
from .ssl import *


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
================================================
import os
import random
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict
import gdown

from dassl.utils import check_isfile


class Datum:
    """Data instance which defines the basic attributes.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.
    """

    def __init__(self, impath="", label=0, domain=0, classname=""):
        assert isinstance(impath, str)
        assert check_isfile(impath)

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname


class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning
    """

    dataset_dir = ""  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data

        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    def get_num_classes(self, data_source):
        """Count number of classes.

        Args:
            data_source (list): a list of Datum objects.
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    def get_lab2cname(self, data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    "Input domain must belong to {}, "
                    "but got [{}]".format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        if not osp.exists(osp.dirname(dst)):
            os.makedirs(osp.dirname(dst))

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print("Extracting file ...")

        try:
            with tarfile.open(dst) as tar:
                tar.extractall(path=osp.dirname(dst))
        except tarfile.ReadError:
            # Not a tar archive, so fall back to zip
            with zipfile.ZipFile(dst, "r") as zip_ref:
                zip_ref.extractall(osp.dirname(dst))

        print("File extracted to {}".format(osp.dirname(dst)))

    def generate_fewshot_dataset(
        self, *data_sources, num_shots=-1, repeat=False
    ):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a small number of images.

        Args:
            data_sources: each individual is a list containing Datum objects.
            num_shots (int): number of instances per class to sample.
            repeat (bool): repeat images if needed (default: False).
        """
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f"Creating a {num_shots}-shot dataset")

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output
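
# Illustrative sketch (not part of the original file): the per-class
# sampling rule of generate_fewshot_dataset, shown on plain
# (impath, label) tuples instead of Datum objects so that it runs without
# image files. The name sample_few_shot and the seed argument are
# additions made for this example.

```python
import random
from collections import defaultdict


def sample_few_shot(items, num_shots, repeat=False, seed=0):
    """Pick num_shots items per label from a list of (impath, label)."""
    rng = random.Random(seed)

    # Group items by label, as split_dataset_by_label does
    by_label = defaultdict(list)
    for impath, label in items:
        by_label[label].append((impath, label))

    sampled = []
    for label, group in by_label.items():
        if len(group) >= num_shots:
            sampled.extend(rng.sample(group, num_shots))
        elif repeat:
            # Too few items: sample with replacement up to num_shots
            sampled.extend(rng.choices(group, k=num_shots))
        else:
            # Too few items and no repetition: keep what is available
            sampled.extend(group)
    return sampled


toy = [("a%d.jpg" % i, 0) for i in range(5)] + [("b0.jpg", 1)]
no_repeat = sample_few_shot(toy, num_shots=2)
with_repeat = sample_few_shot(toy, num_shots=2, repeat=True)
```

# Without repeat, a class with fewer than num_shots images contributes
# fewer items, so the resulting dataset can be unbalanced.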


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
================================================
from dassl.utils import Registry, check_availability

DATASET_REGISTRY = Registry("DATASET")


def build_dataset(cfg):
    avai_datasets = DATASET_REGISTRY.registered_names()
    check_availability(cfg.DATASET.NAME, avai_datasets)
    if cfg.VERBOSE:
        print("Loading dataset: {}".format(cfg.DATASET.NAME))
    return DATASET_REGISTRY.get(cfg.DATASET.NAME)(cfg)
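
# Illustrative sketch (not part of the original file): a minimal registry
# that supports the three calls used here, register(), get() and
# registered_names(). MiniRegistry and ToyDataset are hypothetical names;
# the real Registry lives in dassl.utils and may differ in its details.

```python
class MiniRegistry:
    """Map class names to classes so they can be looked up by string."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        # Used as a decorator: @REGISTRY.register()
        def decorator(obj):
            key = obj.__name__
            if key in self._obj_map:
                raise KeyError(
                    "{} is already registered in {}".format(key, self._name)
                )
            self._obj_map[key] = obj
            return obj

        return decorator

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(
                "{} is not registered in {}".format(name, self._name)
            )
        return self._obj_map[name]

    def registered_names(self):
        return list(self._obj_map.keys())


TOY_REGISTRY = MiniRegistry("DATASET")


@TOY_REGISTRY.register()
class ToyDataset:

    def __init__(self, cfg):
        self.cfg = cfg


# Same pattern as build_dataset: look up the class, then call it with cfg
dataset = TOY_REGISTRY.get("ToyDataset")(cfg={"name": "ToyDataset"})
```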


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py
================================================
from .digit5 import Digit5
from .visda17 import VisDA17
from .cifarstl import CIFARSTL
from .office31 import Office31
from .domainnet import DomainNet
from .office_home import OfficeHome
from .mini_domainnet import miniDomainNet


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
================================================
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class CIFARSTL(DatasetBase):
    """CIFAR-10 and STL-10.

    CIFAR-10:
        - 60,000 32x32 colour images.
        - 10 classes, with 6,000 images per class.
        - 50,000 training images and 10,000 test images.
        - URL: https://www.cs.toronto.edu/~kriz/cifar.html.

    STL-10:
        - 10 classes: airplane, bird, car, cat, deer, dog, horse,
        monkey, ship, truck.
        - Images are 96x96 pixels, color.
        - 500 training images (10 pre-defined folds), 800 test images
        per class.
        - URL: https://cs.stanford.edu/~acoates/stl10/.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "cifar_stl"
    domains = ["cifar", "stl"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            data_dir = osp.join(self.dataset_dir, dname, split)
            class_names = listdir_nohidden(data_dir)

            for class_name in class_names:
                class_dir = osp.join(data_dir, class_name)
                imnames = listdir_nohidden(class_dir)
                label = int(class_name.split("_")[0])

                for imname in imnames:
                    impath = osp.join(class_dir, imname)
                    item = Datum(impath=impath, label=label, domain=domain)
                    items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
================================================
import random
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase

# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}


def read_image_list(im_dir, n_max=None, n_repeat=None):
    items = []

    for imname in listdir_nohidden(im_dir):
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])
        impath = osp.join(im_dir, imname)
        items.append((impath, label))

    if n_max is not None:
        items = random.sample(items, n_max)

    if n_repeat is not None:
        items *= n_repeat

    return items


def load_mnist(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_mnist_m(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST_M[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_svhn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SVHN[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_syn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SYN[split])
    n_max = 25000 if split == "train" else 9000
    return read_image_list(data_dir, n_max=n_max)


def load_usps(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, USPS[split])
    n_repeat = 3 if split == "train" else None
    return read_image_list(data_dir, n_repeat=n_repeat)


@DATASET_REGISTRY.register()
class Digit5(DatasetBase):
    """Five digit datasets.

    It contains:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house numbers.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from
    the training set and 9,000 images from the test set. For USPS which has only
    9,298 images in total, we use the entire dataset but replicate its training
    set 3 times so as to match the training set size of other domains.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
    """

    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            # Look up the module-level loader (e.g. load_mnist) by name
            load_func = globals()["load_" + dname]
            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = load_func(domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=str(label)
                )
                items.append(item)

        return items
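
# Illustrative sketch (not part of the original file): how
# read_image_list derives a label from a file name and how the USPS
# training list is replicated. The file names below are hypothetical;
# the only assumption taken from the code is that the label is the second
# underscore-separated field of the name without its extension.

```python
import os.path as osp


def parse_digit_label(imname):
    """Extract the label from a name shaped like <index>_<label>.<ext>."""
    imname_noext = osp.splitext(imname)[0]
    return int(imname_noext.split("_")[1])


names = ["000001_7.png", "000002_3.png"]  # hypothetical file names
items = [(name, parse_digit_label(name)) for name in names]

# n_repeat=3 duplicates the whole list, as done for the USPS training set
items_repeated = items * 3
```

# Repeating the list changes only how often each image is drawn per epoch;
# no new images are created.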


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
================================================
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class DomainNet(DatasetBase):
    """DomainNet.

    Statistics:
        - 6 distinct domains: Clipart, Infograph, Painting, Quickdraw,
        Real, Sketch.
        - Around 0.6M images.
        - 345 categories.
        - URL: http://ai.bu.edu/M3SDA/.

    Special note: the t-shirt class (327) is missing in painting_train.txt.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
    """

    dataset_dir = "domainnet"
    domains = [
        "clipart", "infograph", "painting", "quickdraw", "real", "sketch"
    ]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            filename = dname + "_" + split + ".txt"
            split_file = osp.join(self.split_dir, filename)

            with open(split_file, "r") as f:
                lines = f.readlines()
                for line in lines:
                    line = line.strip()
                    impath, label = line.split(" ")
                    classname = impath.split("/")[1]
                    impath = osp.join(self.dataset_dir, impath)
                    label = int(label)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=classname
                    )
                    items.append(item)

        return items
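
# Illustrative sketch (not part of the original file): parsing one line of
# a split file in the way _read_data does. The example line and the
# dataset directory are hypothetical; the code only assumes the
# "<relative/path> <label>" layout with the class name as the second
# path component.

```python
import os.path as osp


def parse_split_line(line, dataset_dir):
    """Return (absolute impath, integer label, classname) for one line."""
    impath, label = line.strip().split(" ")
    classname = impath.split("/")[1]
    return osp.join(dataset_dir, impath), int(label), classname


example = parse_split_line(
    "clipart/airplane/img_0001.jpg 1\n", "/data/domainnet"
)
```

# Because the line is split on a single space, this layout only works when
# image paths themselves contain no spaces.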


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/mini_domainnet.py
================================================
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class miniDomainNet(DatasetBase):
    """A subset of DomainNet.

    Reference:
        - Peng et al. Moment Matching for Multi-Source Domain
        Adaptation. ICCV 2019.
        - Zhou et al. Domain Adaptive Ensemble Learning.
    """

    dataset_dir = "domainnet"
    domains = ["clipart", "painting", "real", "sketch"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.split_dir = osp.join(self.dataset_dir, "splits_mini")

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            filename = dname + "_" + split + ".txt"
            split_file = osp.join(self.split_dir, filename)

            with open(split_file, "r") as f:
                lines = f.readlines()
                for line in lines:
                    line = line.strip()
                    impath, label = line.split(" ")
                    classname = impath.split("/")[1]
                    impath = osp.join(self.dataset_dir, impath)
                    label = int(label)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=classname
                    )
                    items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
================================================
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class Office31(DatasetBase):
    """Office-31.

    Statistics:
        - 4,110 images.
        - 31 classes related to office objects.
        - 3 domains: Amazon, Webcam, Dslr.
        - URL: https://people.eecs.berkeley.edu/~jhoffman/domainadapt/.

    Reference:
        - Saenko et al. Adapting visual category models to
        new domains. ECCV 2010.
    """

    dataset_dir = "office31"
    domains = ["amazon", "webcam", "dslr"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            class_names = listdir_nohidden(domain_dir)
            class_names.sort()

            for label, class_name in enumerate(class_names):
                class_path = osp.join(domain_dir, class_name)
                imnames = listdir_nohidden(class_path)

                for imname in imnames:
                    impath = osp.join(class_path, imname)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=class_name
                    )
                    items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
================================================
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class OfficeHome(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home"
    domains = ["art", "clipart", "product", "real_world"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS)
        train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS)
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS)

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, input_domains):
        items = []

        for domain, dname in enumerate(input_domains):
            domain_dir = osp.join(self.dataset_dir, dname)
            class_names = listdir_nohidden(domain_dir)
            class_names.sort()

            for label, class_name in enumerate(class_names):
                class_path = osp.join(domain_dir, class_name)
                imnames = listdir_nohidden(class_path)

                for imname in imnames:
                    impath = osp.join(class_path, imname)
                    item = Datum(
                        impath=impath,
                        label=label,
                        domain=domain,
                        classname=class_name.lower(),
                    )
                    items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
================================================
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class VisDA17(DatasetBase):
    """VisDA17.

    Focusing on simulation-to-reality domain shift.

    URL: http://ai.bu.edu/visda-2017/.

    Reference:
        - Peng et al. VisDA: The Visual Domain Adaptation
        Challenge. ArXiv 2017.
    """

    dataset_dir = "visda17"
    domains = ["synthetic", "real"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train_x = self._read_data("synthetic")
        train_u = self._read_data("real")
        test = self._read_data("real")

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data(self, dname):
        filedir = "train" if dname == "synthetic" else "validation"
        image_list = osp.join(self.dataset_dir, filedir, "image_list.txt")
        items = []
        # There is only one source domain
        domain = 0

        with open(image_list, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()
                impath, label = line.split(" ")
                classname = impath.split("/")[0]
                impath = osp.join(self.dataset_dir, filedir, impath)
                label = int(label)
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=classname
                )
                items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py
================================================
from .pacs import PACS
from .vlcs import VLCS
from .cifar_c import CIFAR10C, CIFAR100C
from .digits_dg import DigitsDG
from .digit_single import DigitSingle
from .office_home_dg import OfficeHomeDG


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
================================================
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase

AVAI_C_TYPES = [
    "brightness",
    "contrast",
    "defocus_blur",
    "elastic_transform",
    "fog",
    "frost",
    "gaussian_blur",
    "gaussian_noise",
    "glass_blur",
    "impulse_noise",
    "jpeg_compression",
    "motion_blur",
    "pixelate",
    "saturate",
    "shot_noise",
    "snow",
    "spatter",
    "speckle_noise",
    "zoom_blur",
]


@DATASET_REGISTRY.register()
class CIFAR10C(DatasetBase):
    """CIFAR-10 -> CIFAR-10-C.

    Dataset link: https://zenodo.org/record/2535967#.YFwtV2Qzb0o

    Statistics:
        - 2 domains: the normal CIFAR-10 vs. a corrupted CIFAR-10
        - 10 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    dataset_dir = ""
    domains = ["cifar10", "cifar10_c"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = root

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )
        source_domain = cfg.DATASET.SOURCE_DOMAINS[0]
        target_domain = cfg.DATASET.TARGET_DOMAINS[0]
        assert source_domain == self.domains[0]
        assert target_domain == self.domains[1]

        c_type = cfg.DATASET.CIFAR_C_TYPE
        c_level = cfg.DATASET.CIFAR_C_LEVEL

        if not c_type:
            raise ValueError(
                "Please specify DATASET.CIFAR_C_TYPE in the config file"
            )

        assert (
            c_type in AVAI_C_TYPES
        ), f'C_TYPE is expected to belong to {AVAI_C_TYPES}, but got "{c_type}"'
        assert 1 <= c_level <= 5

        train_dir = osp.join(self.dataset_dir, source_domain, "train")
        test_dir = osp.join(
            self.dataset_dir, target_domain, c_type, str(c_level)
        )

        if not osp.exists(test_dir):
            raise ValueError(f'Directory not found: "{test_dir}"')

        train = self._read_data(train_dir)
        test = self._read_data(test_dir)

        super().__init__(train_x=train, test=test)

    def _read_data(self, data_dir):
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label, domain=0)
                items.append(item)

        return items


@DATASET_REGISTRY.register()
class CIFAR100C(CIFAR10C):
    """CIFAR-100 -> CIFAR-100-C.

    Dataset link: https://zenodo.org/record/3555552#.YFxpQmQzb0o

    Statistics:
        - 2 domains: the normal CIFAR-100 vs. a corrupted CIFAR-100
        - 100 categories

    Reference:
        - Hendrycks et al. Benchmarking neural network robustness
        to common corruptions and perturbations. ICLR 2019.
    """

    dataset_dir = ""
    domains = ["cifar100", "cifar100_c"]

    def __init__(self, cfg):
        super().__init__(cfg)


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
================================================
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase

# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}


def read_image_list(im_dir, n_max=None, n_repeat=None):
    items = []

    for imname in listdir_nohidden(im_dir):
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])
        impath = osp.join(im_dir, imname)
        items.append((impath, label))

    if n_max is not None:
        # Note that the sampling process is NOT random,
        # which follows that in Volpi et al. NIPS'18.
        items = items[:n_max]

    if n_repeat is not None:
        items *= n_repeat

    return items


def load_mnist(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_mnist_m(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, MNIST_M[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_svhn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SVHN[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_syn(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, SYN[split])
    n_max = 10000 if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_usps(dataset_dir, split="train"):
    data_dir = osp.join(dataset_dir, USPS[split])
    return read_image_list(data_dir)


@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
    """Digit recognition datasets for single-source domain generalization.

    There are five digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    Protocol:
        Volpi et al. train a model using 10,000 images from MNIST and
        evaluate the model on the test split of the other four datasets. However,
        the code does not restrict you to only use MNIST as the source dataset.
        Instead, you can use any dataset as the source. But note that only the
        first 10,000 training images of the source dataset are used (USPS is
        not capped).

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data
        Augmentation. NIPS 2018.
    """

    # Reuse the digit-5 folder instead of creating a new folder
    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        items = []

        for domain, dname in enumerate(input_domains):
            func = "load_" + dname
            domain_dir = osp.join(self.dataset_dir, dname)
            # Look up the module-level loader by name instead of calling eval()
            items_d = globals()[func](domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(impath=impath, label=label, domain=domain)
                items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
================================================
import glob
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class DigitsDG(DatasetBase):
    """Digits-DG.

    It contains 4 digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Zhou et al. Deep Domain-Adversarial Image Generation for Domain
        Generalisation. AAAI 2020.
    """

    dataset_dir = "digits_dg"
    domains = ["mnist", "mnist_m", "svhn", "syn"]
    data_url = "https://drive.google.com/uc?id=15V7EsHfCcfbKgsDmzQKj_DfXt_XYp_P7"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "digits_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = self.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = self.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)

    @staticmethod
    def read_data(dataset_dir, input_domains, split):

        def _load_data_from_directory(directory):
            folders = listdir_nohidden(directory)
            folders.sort()
            items_ = []

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(directory, folder, "*.jpg"))

                for impath in impaths:
                    items_.append((impath, label))

            return items_

        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                train_dir = osp.join(dataset_dir, dname, "train")
                impath_label_list = _load_data_from_directory(train_dir)
                val_dir = osp.join(dataset_dir, dname, "val")
                impath_label_list += _load_data_from_directory(val_dir)
            else:
                split_dir = osp.join(dataset_dir, dname, split)
                impath_label_list = _load_data_from_directory(split_dir)

            for impath, label in impath_label_list:
                # Use os.path so the parent folder is found on any platform
                class_name = osp.basename(osp.dirname(impath)).lower()
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=class_name
                )
                items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/office_home_dg.py
================================================
import os.path as osp

from ..build import DATASET_REGISTRY
from .digits_dg import DigitsDG
from ..base_dataset import DatasetBase


@DATASET_REGISTRY.register()
class OfficeHomeDG(DatasetBase):
    """Office-Home.

    Statistics:
        - Around 15,500 images.
        - 65 classes related to office and home objects.
        - 4 domains: Art, Clipart, Product, Real World.
        - URL: http://hemanthdv.org/OfficeHome-Dataset/.

    Reference:
        - Venkateswara et al. Deep Hashing Network for Unsupervised
        Domain Adaptation. CVPR 2017.
    """

    dataset_dir = "office_home_dg"
    domains = ["art", "clipart", "product", "real_world"]
    data_url = "https://drive.google.com/uc?id=1gkbf_KaxoBws-GWT3XIPZ7BnkqbAxIFa"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "office_home_dg.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "train"
        )
        val = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.SOURCE_DOMAINS, "val"
        )
        test = DigitsDG.read_data(
            self.dataset_dir, cfg.DATASET.TARGET_DOMAINS, "all"
        )

        super().__init__(train_x=train, val=val, test=test)


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
================================================
import os.path as osp

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class PACS(DatasetBase):
    """PACS.

    Statistics:
        - 4 domains: Photo (1,670), Art (2,048), Cartoon
        (2,344), Sketch (3,929).
        - 7 categories: dog, elephant, giraffe, guitar, horse,
        house and person.

    Reference:
        - Li et al. Deeper, broader and artier domain generalization.
        ICCV 2017.
    """

    dataset_dir = "pacs"
    domains = ["art_painting", "cartoon", "photo", "sketch"]
    data_url = "https://drive.google.com/uc?id=1m4X4fROCCXMO0lRLrr6Zz9Vb3974NWhE"
    # the following images contain errors and should be ignored
    _error_paths = ["sketch/dog/n02103406_4068-1.png"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.image_dir = osp.join(self.dataset_dir, "images")
        self.split_dir = osp.join(self.dataset_dir, "splits")

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "pacs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "all")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        items = []

        for domain, dname in enumerate(input_domains):
            if split == "all":
                file_train = osp.join(
                    self.split_dir, dname + "_train_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file_train)
                file_val = osp.join(
                    self.split_dir, dname + "_crossval_kfold.txt"
                )
                impath_label_list += self._read_split_pacs(file_val)
            else:
                file = osp.join(
                    self.split_dir, dname + "_" + split + "_kfold.txt"
                )
                impath_label_list = self._read_split_pacs(file)

            for impath, label in impath_label_list:
                classname = impath.split("/")[-2]
                item = Datum(
                    impath=impath,
                    label=label,
                    domain=domain,
                    classname=classname
                )
                items.append(item)

        return items

    def _read_split_pacs(self, split_file):
        items = []

        with open(split_file, "r") as f:
            lines = f.readlines()

            for line in lines:
                line = line.strip()
                impath, label = line.split(" ")
                if impath in self._error_paths:
                    continue
                impath = osp.join(self.image_dir, impath)
                label = int(label) - 1
                items.append((impath, label))

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
================================================
import glob
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class VLCS(DatasetBase):
    """VLCS.

    Statistics:
        - 4 domains: CALTECH, LABELME, PASCAL, SUN
        - 5 categories: bird, car, chair, dog, and person.

    Reference:
        - Torralba and Efros. Unbiased look at dataset bias. CVPR 2011.
    """

    dataset_dir = "VLCS"
    domains = ["caltech", "labelme", "pascal", "sun"]
    data_url = "https://drive.google.com/uc?id=1r0WL5DDqKfSPp9E3tRENwHaXNs1olLZd"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        if not osp.exists(self.dataset_dir):
            dst = osp.join(root, "vlcs.zip")
            self.download_data(self.data_url, dst, from_gdrive=True)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, "crossval")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, "test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split):
        items = []

        for domain, dname in enumerate(input_domains):
            dname = dname.upper()
            path = osp.join(self.dataset_dir, dname, split)
            folders = listdir_nohidden(path)
            folders.sort()

            for label, folder in enumerate(folders):
                impaths = glob.glob(osp.join(path, folder, "*.jpg"))

                for impath in impaths:
                    item = Datum(impath=impath, label=label, domain=domain)
                    items.append(item)

        return items


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/__init__.py
================================================
from .svhn import SVHN
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
================================================
import math
import random
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class CIFAR10(DatasetBase):
    """CIFAR10 for SSL.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")

        assert cfg.DATASET.NUM_LABELED > 0

        train_x, train_u, val = self._read_data_train(
            train_dir, cfg.DATASET.NUM_LABELED, cfg.DATASET.VAL_PERCENT
        )
        test = self._read_data_test(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        if len(val) == 0:
            val = None

        super().__init__(train_x=train_x, train_u=train_u, val=val, test=test)

    def _read_data_train(self, data_dir, num_labeled, val_percent):
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        num_labeled_per_class = num_labeled / len(class_names)
        items_x, items_u, items_v = [], [], []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            # Split into train and val following Oliver et al. 2018
            # Set cfg.DATASET.VAL_PERCENT to 0 to not use val data
            num_val = math.floor(len(imnames) * val_percent)
            imnames_train = imnames[num_val:]
            imnames_val = imnames[:num_val]

            # Note we do shuffle after split
            random.shuffle(imnames_train)

            for i, imname in enumerate(imnames_train):
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)

                if (i + 1) <= num_labeled_per_class:
                    items_x.append(item)

                else:
                    items_u.append(item)

            for imname in imnames_val:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items_v.append(item)

        return items_x, items_u, items_v

    def _read_data_test(self, data_dir):
        class_names = listdir_nohidden(data_dir)
        class_names.sort()
        items = []

        for label, class_name in enumerate(class_names):
            class_dir = osp.join(data_dir, class_name)
            imnames = listdir_nohidden(class_dir)

            for imname in imnames:
                impath = osp.join(class_dir, imname)
                item = Datum(impath=impath, label=label)
                items.append(item)

        return items


@DATASET_REGISTRY.register()
class CIFAR100(CIFAR10):
    """CIFAR100 for SSL.

    Reference:
        - Krizhevsky. Learning Multiple Layers of Features
        from Tiny Images. Tech report.
    """

    dataset_dir = "cifar100"

    def __init__(self, cfg):
        super().__init__(cfg)


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
================================================
import numpy as np
import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase


@DATASET_REGISTRY.register()
class STL10(DatasetBase):
    """STL-10 dataset.

    Description:
    - 10 classes: airplane, bird, car, cat, deer, dog, horse,
    monkey, ship, truck.
    - Images are 96x96 pixels, color.
    - 500 training images per class, 800 test images per class.
    - 100,000 unlabeled images for unsupervised learning.

    Reference:
        - Coates et al. An Analysis of Single Layer Networks in
        Unsupervised Feature Learning. AISTATS 2011.
    """

    dataset_dir = "stl10"

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)
        train_dir = osp.join(self.dataset_dir, "train")
        test_dir = osp.join(self.dataset_dir, "test")
        unlabeled_dir = osp.join(self.dataset_dir, "unlabeled")
        fold_file = osp.join(
            self.dataset_dir, "stl10_binary", "fold_indices.txt"
        )

        # Only use the first five splits
        assert 0 <= cfg.DATASET.STL10_FOLD <= 4

        train_x = self._read_data_train(
            train_dir, cfg.DATASET.STL10_FOLD, fold_file
        )
        train_u = self._read_data_all(unlabeled_dir)
        test = self._read_data_all(test_dir)

        if cfg.DATASET.ALL_AS_UNLABELED:
            train_u = train_u + train_x

        super().__init__(train_x=train_x, train_u=train_u, test=test)

    def _read_data_train(self, data_dir, fold, fold_file):
        imnames = listdir_nohidden(data_dir)
        imnames.sort()
        items = []

        list_idx = list(range(len(imnames)))
        if fold >= 0:
            with open(fold_file, "r") as f:
                str_idx = f.read().splitlines()[fold]
                # Fold indices exceed 255, which uint8 cannot represent
                list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ")

        for i in list_idx:
            imname = imnames[i]
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items

    def _read_data_all(self, data_dir):
        imnames = listdir_nohidden(data_dir)
        items = []

        for imname in imnames:
            impath = osp.join(data_dir, imname)
            label = osp.splitext(imname)[0].split("_")[1]
            if label == "none":
                label = -1
            else:
                label = int(label)
            item = Datum(impath=impath, label=label)
            items.append(item)

        return items
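
# Illustration (not part of the Dassl source): a minimal, dependency-free
# sketch of the two parsing steps used by the reader methods above. The file
# names below are hypothetical; the only things taken from the code are that
# the label is the second "_"-separated field of the file stem and that
# unlabeled images carry the token "none". The helper names are assumptions.

```python
import os.path as osp


def parse_label(imname):
    """Return the label encoded in the file stem, or -1 for 'none'."""
    label = osp.splitext(imname)[0].split("_")[1]
    return -1 if label == "none" else int(label)


def parse_fold_line(line):
    """Turn one whitespace-separated line of indices into Python ints."""
    # Plain ints have no fixed width, so large indices are preserved.
    return [int(token) for token in line.split()]


labeled = parse_label("00012_3.png")        # hypothetical file name
unlabeled = parse_label("00007_none.png")   # hypothetical file name
fold = parse_fold_line("0 255 256 4999")
```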


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
================================================
from .cifar import CIFAR10
from ..build import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class SVHN(CIFAR10):
    """SVHN for SSL.

    Reference:
        - Netzer et al. Reading Digits in Natural Images with
        Unsupervised Feature Learning. NIPS-W 2011.
    """

    dataset_dir = "svhn"

    def __init__(self, cfg):
        super().__init__(cfg)


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/samplers.py
================================================
import copy
import numpy as np
import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler, RandomSampler, SequentialSampler


class RandomDomainSampler(Sampler):
    """Randomly samples N domains each with K images
    to form a minibatch of size N*K.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_domain (int): number of domains to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_domain):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())

        # Make sure each domain has equal number of images
        if n_domain is None or n_domain <= 0:
            n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            selected_domains = random.sample(self.domains, self.n_domain)

            for domain in selected_domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length
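
# Illustration (not part of the Dassl source): a standalone sketch of the
# sampling loop in RandomDomainSampler, written without torch so it can be run
# on a toy list of domain names. The function name, the seed argument and the
# toy data are assumptions made for the example.

```python
import random
from collections import defaultdict


def sample_domain_batches(domains, batch_size, n_domain, seed=0):
    """Return a flat index list; every `batch_size` slice holds
    `batch_size // n_domain` indices from each of `n_domain` domains."""
    rng = random.Random(seed)
    pools = defaultdict(list)
    for i, d in enumerate(domains):
        pools[d].append(i)
    names = list(pools)

    assert batch_size % n_domain == 0
    k = batch_size // n_domain

    out = []
    stop = False
    while not stop:
        for d in rng.sample(names, n_domain):
            picked = rng.sample(pools[d], k)
            out.extend(picked)
            for idx in picked:
                pools[d].remove(idx)
            # Stop after this batch once any domain cannot fill another one.
            if len(pools[d]) < k:
                stop = True
    return out


domains = ["a"] * 6 + ["b"] * 6 + ["c"] * 6
idxs = sample_domain_batches(domains, batch_size=4, n_domain=2)
```

# Each index is drawn without replacement, so an epoch ends as soon as one of
# the sampled domains runs out of images.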


class SeqDomainSampler(Sampler):
    """Sequential domain sampler, which randomly samples K
    images from each domain to form a minibatch.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source

        # Keep track of image indices for each domain
        self.domain_dict = defaultdict(list)
        for i, item in enumerate(data_source):
            self.domain_dict[item.domain].append(i)
        self.domains = list(self.domain_dict.keys())
        self.domains.sort()

        # Make sure each domain has equal number of images
        n_domain = len(self.domains)
        assert batch_size % n_domain == 0
        self.n_img_per_domain = batch_size // n_domain

        self.batch_size = batch_size
        # n_domain denotes number of domains sampled in a minibatch
        self.n_domain = n_domain
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        domain_dict = copy.deepcopy(self.domain_dict)
        final_idxs = []
        stop_sampling = False

        while not stop_sampling:
            for domain in self.domains:
                idxs = domain_dict[domain]
                selected_idxs = random.sample(idxs, self.n_img_per_domain)
                final_idxs.extend(selected_idxs)

                for idx in selected_idxs:
                    domain_dict[domain].remove(idx)

                remaining = len(domain_dict[domain])
                if remaining < self.n_img_per_domain:
                    stop_sampling = True

        return iter(final_idxs)

    def __len__(self):
        return self.length


class RandomClassSampler(Sampler):
    """Randomly samples N classes each with K instances to
    form a minibatch of size N*K.

    Modified from https://github.com/KaiyangZhou/deep-person-reid.

    Args:
        data_source (list): list of Datums.
        batch_size (int): batch size.
        n_ins (int): number of instances per class to sample in a minibatch.
    """

    def __init__(self, data_source, batch_size, n_ins):
        if batch_size < n_ins:
            raise ValueError(
                "batch_size={} must be no less "
                "than n_ins={}".format(batch_size, n_ins)
            )

        self.data_source = data_source
        self.batch_size = batch_size
        self.n_ins = n_ins
        self.ncls_per_batch = self.batch_size // self.n_ins
        self.index_dic = defaultdict(list)
        for index, item in enumerate(data_source):
            self.index_dic[item.label].append(index)
        self.labels = list(self.index_dic.keys())
        assert len(self.labels) >= self.ncls_per_batch

        # estimate number of images in an epoch
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        batch_idxs_dict = defaultdict(list)

        for label in self.labels:
            idxs = copy.deepcopy(self.index_dic[label])
            if len(idxs) < self.n_ins:
                idxs = np.random.choice(idxs, size=self.n_ins, replace=True)
            random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(idx)
                if len(batch_idxs) == self.n_ins:
                    batch_idxs_dict[label].append(batch_idxs)
                    batch_idxs = []

        avai_labels = copy.deepcopy(self.labels)
        final_idxs = []

        while len(avai_labels) >= self.ncls_per_batch:
            selected_labels = random.sample(avai_labels, self.ncls_per_batch)
            for label in selected_labels:
                batch_idxs = batch_idxs_dict[label].pop(0)
                final_idxs.extend(batch_idxs)
                if len(batch_idxs_dict[label]) == 0:
                    avai_labels.remove(label)

        return iter(final_idxs)

    def __len__(self):
        return self.length
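
# Illustration (not part of the Dassl source): a standalone sketch of the
# "N classes x K instances" scheme used by RandomClassSampler. It uses only
# the standard library (random.choices stands in for np.random.choice with
# replacement); the function name, seed argument and toy labels are
# assumptions.

```python
import random
from collections import defaultdict


def sample_class_batches(labels, batch_size, n_ins, seed=0):
    rng = random.Random(seed)
    n_cls = batch_size // n_ins

    index_by_label = defaultdict(list)
    for i, y in enumerate(labels):
        index_by_label[y].append(i)

    # Split every class into chunks of exactly n_ins indices.
    chunks = defaultdict(list)
    for y, idxs in index_by_label.items():
        idxs = list(idxs)
        if len(idxs) < n_ins:
            # Too few images: oversample with replacement.
            idxs = rng.choices(idxs, k=n_ins)
        rng.shuffle(idxs)
        for start in range(0, len(idxs) - n_ins + 1, n_ins):
            chunks[y].append(idxs[start:start + n_ins])

    # Repeatedly draw n_cls classes and take one chunk from each.
    available = list(chunks)
    out = []
    while len(available) >= n_cls:
        for y in rng.sample(available, n_cls):
            out.extend(chunks[y].pop(0))
            if not chunks[y]:
                available.remove(y)
    return out


labels = [0] * 4 + [1] * 4 + [2] * 4 + [3]
idxs = sample_class_batches(labels, batch_size=4, n_ins=2)
```

# Leftover indices that do not fill a whole chunk are dropped, which is why
# the sampler above measures its length by running one full iteration.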


def build_sampler(
    sampler_type,
    cfg=None,
    data_source=None,
    batch_size=32,
    n_domain=0,
    n_ins=16
):
    if sampler_type == "RandomSampler":
        return RandomSampler(data_source)

    elif sampler_type == "SequentialSampler":
        return SequentialSampler(data_source)

    elif sampler_type == "RandomDomainSampler":
        return RandomDomainSampler(data_source, batch_size, n_domain)

    elif sampler_type == "SeqDomainSampler":
        return SeqDomainSampler(data_source, batch_size)

    elif sampler_type == "RandomClassSampler":
        return RandomClassSampler(data_source, batch_size, n_ins)

    else:
        raise ValueError("Unknown sampler type: {}".format(sampler_type))


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py
================================================
from .transforms import build_transform


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
================================================
"""
Source: https://github.com/DeepVoltaire/AutoAugment
"""
import numpy as np
import random
from PIL import Image, ImageOps, ImageEnhance


class ImageNetPolicy:
    """Randomly choose one of the 25 ImageNet sub-policies listed below
    (some entries are repeated).

    Example:
        >>> policy = ImageNetPolicy()
        >>> transformed = policy(image)

    Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     ImageNetPolicy(),
        >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy:
    """Randomly choose one of the best 25 Sub-policies on CIFAR10.

    Example:
        >>> policy = CIFAR10Policy()
        >>> transformed = policy(image)

    Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     CIFAR10Policy(),
        >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy:
    """Randomly choose one of the best 25 Sub-policies on SVHN.

    Example:
        >>> policy = SVHNPolicy()
        >>> transformed = policy(image)

    Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     SVHNPolicy(),
        >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):

    def __init__(
        self,
        p1,
        operation1,
        magnitude_idx1,
        p2,
        operation2,
        magnitude_idx2,
        fillcolor=(128, 128, 128),
    ):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # np.int was removed in NumPy 1.24; the builtin int is equivalent
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10,
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(
                rot, Image.new("RGBA", rot.size, (128, ) * 4), rot
            ).convert(img.mode)

        func = {
            "shearX":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "shearY":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "translateX":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (
                    1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0,
                    1, 0
                ),
                fillcolor=fillcolor,
            ),
            "translateY":
            lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (
                    1, 0, 0, 0, 1,
                    magnitude * img.size[1] * random.choice([-1, 1])
                ),
                fillcolor=fillcolor,
            ),
            "rotate":
            lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color":
            lambda img, magnitude: ImageEnhance.Color(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize":
            lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize":
            lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast":
            lambda img, magnitude: ImageEnhance.Contrast(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "sharpness":
            lambda img, magnitude: ImageEnhance.Sharpness(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "brightness":
            lambda img, magnitude: ImageEnhance.Brightness(img).
            enhance(1 + magnitude * random.choice([-1, 1])),
            "autocontrast":
            lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize":
            lambda img, magnitude: ImageOps.equalize(img),
            "invert":
            lambda img, magnitude: ImageOps.invert(img),
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
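
# Illustration (not part of the source): a dependency-free sketch of how a
# sub-policy turns a magnitude index (0-9) into a value and gates each
# operation by its probability. The small linspace helper stands in for
# np.linspace, numbers stand in for images, and all names are assumptions.

```python
import random


def linspace(start, stop, num):
    """Evenly spaced values from start to stop (both included)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


# Three of the lookup tables, rebuilt without NumPy.
RANGES = {
    "rotate": linspace(0, 30, 10),
    "shearX": linspace(0, 0.3, 10),
    "posterize": [int(round(v)) for v in linspace(8, 4, 10)],
}


def apply_sub_policy(value, steps, rng):
    """Apply each (probability, op, magnitude) step with its probability."""
    for p, op, mag in steps:
        if rng.random() < p:
            value = op(value, mag)
    return value


def add(x, mag):
    # Toy operation standing in for an image transform.
    return x + mag


rng = random.Random(0)
always = apply_sub_policy(0, [(1.0, add, 1), (1.0, add, 2)], rng)
never = apply_sub_policy(0, [(0.0, add, 1), (0.0, add, 2)], rng)
```

# With probability 1.0 both steps run and with 0.0 neither does, because
# random() returns values in [0, 1).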


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
================================================
"""
Credit to
1) https://github.com/ildoonet/pytorch-randaugment
2) https://github.com/kakaobrain/fast-autoaugment
"""
import numpy as np
import random
import PIL
import torch
import PIL.ImageOps
import PIL.ImageDraw
import PIL.ImageEnhance
from PIL import Image


def ShearX(img, v):
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v):
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))


def TranslateX(img, v):
    # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateXabs(img, v):
    # v is an absolute horizontal offset in pixels, not a fraction of the width
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v):
    # [-150, 150] => percentage: [-0.45, 0.45]
    assert -0.45 <= v <= 0.45
    if random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def TranslateYabs(img, v):
    # v is an absolute vertical offset in pixels, not a fraction of the height
    assert 0 <= v
    if random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def Rotate(img, v):
    assert -30 <= v <= 30
    if random.random() > 0.5:
        v = -v
    return img.rotate(v)


def AutoContrast(img, _):
    return PIL.ImageOps.autocontrast(img)


def Invert(img, _):
    return PIL.ImageOps.invert(img)


def Equalize(img, _):
    return PIL.ImageOps.equalize(img)


def Flip(img, _):
    return PIL.ImageOps.mirror(img)


def Solarize(img, v):
    assert 0 <= v <= 256
    return PIL.ImageOps.solarize(img, v)


def SolarizeAdd(img, addition=0, threshold=128):
    # np.int was removed in NumPy 1.24; the builtin int is equivalent
    img_np = np.array(img).astype(int)
    img_np = img_np + addition
    img_np = np.clip(img_np, 0, 255)
    img_np = img_np.astype(np.uint8)
    img = Image.fromarray(img_np)
    return PIL.ImageOps.solarize(img, threshold)


def Posterize(img, v):
    assert 4 <= v <= 8
    v = int(v)
    return PIL.ImageOps.posterize(img, v)


def Contrast(img, v):
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Color(img, v):
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Color(img).enhance(v)


def Brightness(img, v):
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Sharpness(img, v):
    assert 0.0 <= v <= 2.0
    return PIL.ImageEnhance.Sharpness(img).enhance(v)


def Cutout(img, v):
    # [0, 60] => percentage: [0, 0.2]
    assert 0.0 <= v <= 0.2
    if v <= 0.0:
        return img

    v = v * img.size[0]
    return CutoutAbs(img, v)


def CutoutAbs(img, v):
    # v is the side length of the square patch in pixels
    # assert 0 <= v <= 20
    if v < 0:
        return img
    w, h = img.size
    x0 = np.random.uniform(w)
    y0 = np.random.uniform(h)

    x0 = int(max(0, x0 - v/2.0))
    y0 = int(max(0, y0 - v/2.0))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    xy = (x0, y0, x1, y1)
    color = (125, 123, 114)
    # color = (0, 0, 0)
    img = img.copy()
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img
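
# Illustration (not part of the source): the rectangle arithmetic of CutoutAbs
# isolated as a pure function. Here the sampled point (cx, cy) is passed in
# explicitly instead of being drawn at random, which is an assumption made so
# the result is reproducible; the function name is also hypothetical.

```python
def cutout_box(w, h, cx, cy, v):
    """Return (x0, y0, x1, y1) for a patch of side v around (cx, cy),
    clamped to a w x h image."""
    x0 = int(max(0, cx - v / 2.0))
    y0 = int(max(0, cy - v / 2.0))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)
    return x0, y0, x1, y1


center_box = cutout_box(32, 32, 16, 16, 8)
corner_box = cutout_box(32, 32, 2, 30, 8)
```

# Near a border the patch is clipped rather than shifted, so it can end up
# smaller than v x v.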


def SamplePairing(imgs):
    # [0, 0.4]
    def f(img1, v):
        i = np.random.choice(len(imgs))
        img2 = PIL.Image.fromarray(imgs[i])
        return PIL.Image.blend(img1, img2, v)

    return f


def Identity(img, v):
    return img


class Lighting:
    """Lighting noise (AlexNet - style PCA - based noise)."""

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = torch.Tensor(eigval)
        self.eigvec = torch.Tensor(eigvec)

    def __call__(self, img):
        if self.alphastd == 0:
            return img

        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        rgb = (
            self.eigvec.type_as(img).clone().mul(
                alpha.view(1, 3).expand(3, 3)
            ).mul(self.eigval.view(1, 3).expand(3, 3)).sum(1).squeeze()
        )

        return img.add(rgb.view(3, 1, 1).expand_as(img))


class CutoutDefault:
    """
    Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1:y2, x1:x2] = 0.0
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img


def randaugment_list():
    # 16 operations and their ranges
    # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
    # augs = [
    #     (Identity, 0., 1.0),
    #     (ShearX, 0., 0.3),  # 0
    #     (ShearY, 0., 0.3),  # 1
    #     (TranslateX, 0., 0.33),  # 2
    #     (TranslateY, 0., 0.33),  # 3
    #     (Rotate, 0, 30),  # 4
    #     (AutoContrast, 0, 1),  # 5
    #     (Invert, 0, 1),  # 6
    #     (Equalize, 0, 1),  # 7
    #     (Solarize, 0, 110),  # 8
    #     (Posterize, 4, 8),  # 9
    #     # (Contrast, 0.1, 1.9),  # 10
    #     (Color, 0.1, 1.9),  # 11
    #     (Brightness, 0.1, 1.9),  # 12
    #     (Sharpness, 0.1, 1.9),  # 13
    #     # (Cutout, 0, 0.2),  # 14
    #     # (SamplePairing(imgs), 0, 0.4)  # 15
    # ]

    # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
    augs = [
        (AutoContrast, 0, 1),
        (Equalize, 0, 1),
        (Invert, 0, 1),
        (Rotate, 0, 30),
        (Posterize, 4, 8),
        (Solarize, 0, 256),
        (SolarizeAdd, 0, 110),
        (Color, 0.1, 1.9),
        (Contrast, 0.1, 1.9),
        (Brightness, 0.1, 1.9),
        (Sharpness, 0.1, 1.9),
        (ShearX, 0.0, 0.3),
        (ShearY, 0.0, 0.3),
        (CutoutAbs, 0, 40),
        (TranslateXabs, 0.0, 100),
        (TranslateYabs, 0.0, 100),
    ]

    return augs


def randaugment_list2():
    augs = [
        (AutoContrast, 0, 1),
        (Brightness, 0.1, 1.9),
        (Color, 0.1, 1.9),
        (Contrast, 0.1, 1.9),
        (Equalize, 0, 1),
        (Identity, 0, 1),
        (Invert, 0, 1),
        (Posterize, 4, 8),
        (Rotate, -30, 30),
        (Sharpness, 0.1, 1.9),
        (ShearX, -0.3, 0.3),
        (ShearY, -0.3, 0.3),
        (Solarize, 0, 256),
        (TranslateX, -0.3, 0.3),
        (TranslateY, -0.3, 0.3),
    ]

    return augs


def fixmatch_list():
    # https://arxiv.org/abs/2001.07685
    augs = [
        (AutoContrast, 0, 1),
        (Brightness, 0.05, 0.95),
        (Color, 0.05, 0.95),
        (Contrast, 0.05, 0.95),
        (Equalize, 0, 1),
        (Identity, 0, 1),
        (Posterize, 4, 8),
        (Rotate, -30, 30),
        (Sharpness, 0.05, 0.95),
        (ShearX, -0.3, 0.3),
        (ShearY, -0.3, 0.3),
        (Solarize, 0, 256),
        (TranslateX, -0.3, 0.3),
        (TranslateY, -0.3, 0.3),
    ]

    return augs


class RandAugment:

    def __init__(self, n=2, m=10):
        assert 0 <= m <= 30
        self.n = n
        self.m = m
        self.augment_list = randaugment_list()

    def __call__(self, img):
        ops = random.choices(self.augment_list, k=self.n)

        for op, minval, maxval in ops:
            val = (self.m / 30) * (maxval-minval) + minval
            img = op(img, val)

        return img
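
# Illustration (not part of the source): the magnitude mapping used in
# RandAugment.__call__, isolated as a function. The global level m in [0, 30]
# is interpolated linearly into each operation's (minval, maxval) range. The
# function name and the m_max parameter are assumptions.

```python
def scale_magnitude(m, minval, maxval, m_max=30):
    """Map a level m in [0, m_max] linearly onto [minval, maxval]."""
    return (m / m_max) * (maxval - minval) + minval


# A few (name, minval, maxval) triples copied from the list above.
ops = [("Rotate", 0, 30), ("Posterize", 4, 8), ("Color", 0.1, 1.9)]
values = {name: scale_magnitude(10, lo, hi) for name, lo, hi in ops}
```

# RandAugment2 and RandAugmentFixMatch use the same interpolation but draw the
# level from random.random() for every operation instead of fixing m.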


class RandAugment2:

    def __init__(self, n=2, p=0.6):
        self.n = n
        self.p = p
        self.augment_list = randaugment_list2()

    def __call__(self, img):
        ops = random.choices(self.augment_list, k=self.n)

        for op, minval, maxval in ops:
            if random.random() > self.p:
                continue
            m = random.random()
            val = m * (maxval-minval) + minval
            img = op(img, val)

        return img


class RandAugmentFixMatch:

    def __init__(self, n=2):
        self.n = n
        self.augment_list = fixmatch_list()

    def __call__(self, img):
        ops = random.choices(self.augment_list, k=self.n)

        for op, minval, maxval in ops:
            m = random.random()
            val = m * (maxval-minval) + minval
            img = op(img, val)

        return img


================================================
FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
================================================
import numpy as np
import random
import torch
from PIL import Image
from torchvision.transforms import (
    Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, ColorJitter,
    RandomApply, GaussianBlur, RandomGrayscale, RandomResizedCrop,
    RandomHorizontalFlip
)

from .autoaugment import SVHNPolicy, CIFAR10Policy, ImageNetPolicy
from .randaugment import RandAugment, RandAugment2, RandAugmentFixMatch

AVAI_CHOICES = [
    "random_flip",
    "random_resized_crop",
    "normalize",
    "instance_norm",
    "random_crop",
    "random_translation",
    "center_crop",  # This has become a default operation for test
    "cutout",
    "imagenet_policy",
    "cifar10_policy",
    "svhn_policy",
    "randaugment",
    "randaugment_fixmatch",
    "randaugment2",
    "gaussian_noise",
    "colorjitter",
    "randomgrayscale",
    "gaussian_blur",
]

INTERPOLATION_MODES = {
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
    "nearest": Image.NEAREST,
}


class Random2DTranslation:
    """Given an image of (height, width), we resize it to
    (height*1.125, width*1.125), and then perform random cropping.

    Args:
        height (int): target image height.
        width (int): target image width.
        p (float, optional): probability that this operation takes place.
            Default is 0.5.
        interpolation (int, optional): desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
        self.height = height
        self.width = width
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        if random.uniform(0, 1) > self.p:
            return img.resize((self.width, self.height), self.interpolation)

        new_width = int(round(self.width * 1.125))
        new_height = int(round(self.height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)

        x_maxrange = new_width - self.width
        y_maxrange = new_height - self.height
        x1 = int(round(random.uniform(0, x_maxrange)))
        y1 = int(round(random.uniform(0, y_maxrange)))
        cropped_img = resized_img.crop(
            (x1, y1, x1 + self.width, y1 + self.height)
        )

        return cropped_img
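
# Illustration (not part of the source): the size and crop-box arithmetic of
# Random2DTranslation without PIL. The function name, the scale parameter and
# the explicit random generator are assumptions made for the example.

```python
import random


def translation_crop_box(width, height, rng, scale=1.125):
    """Return the enlarged size and a random crop box of the target size."""
    new_w = int(round(width * scale))
    new_h = int(round(height * scale))
    # The top-left corner may move anywhere inside the extra margin.
    x1 = int(round(rng.uniform(0, new_w - width)))
    y1 = int(round(rng.uniform(0, new_h - height)))
    return (new_w, new_h), (x1, y1, x1 + width, y1 + height)


new_size, box = translation_crop_box(224, 224, random.Random(0))
```

# For a 224 x 224 target the image is first resized to 252 x 252, leaving a
# margin of 28 pixels in each direction for the random shift.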


class InstanceNormalization:
    """Normalize data using per-channel mean and standard deviation.

    Reference:
        - Ulyanov et al. Instance normalization: The missing ingredient
          for fast stylization. ArXiv 2016.
        - Shu et al. A DIRT-T Approach to Unsupervised Domain Adaptation.
          ICLR 2018.
    """

    def __init__(self, eps=1e-8):
        self.eps = eps

    def __call__(self, img):
        C, H, W = img.shape
        img_re = img.reshape(C, H * W)
        mean = img_re.mean(1).view(C, 1, 1)
        std = img_re.std(1).view(C, 1, 1)
        return (img-mean) / (std + self.eps)
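
# Illustration (not part of the source): per-channel normalization on nested
# lists, mirroring InstanceNormalization without torch. statistics.stdev is
# the sample standard deviation (n - 1 in the denominator), which is assumed
# here to match the default of Tensor.std. The function name and toy image
# are hypothetical.

```python
from statistics import mean, stdev


def instance_normalize(img, eps=1e-8):
    """Normalize a C x H x W nested list channel by channel."""
    out = []
    for channel in img:
        flat = [v for row in channel for v in row]
        mu = mean(flat)
        sd = stdev(flat)
        out.append([[(v - mu) / (sd + eps) for v in row] for row in channel])
    return out


img = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[10.0, 10.0], [20.0, 20.0]],
]
normalized = instance_normalize(img)
```

# Each channel ends up with roughly zero mean and unit standard deviation,
# independent of the statistics of the other channels.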


class Cutout:
    """Randomly mask out one or more patches from an image.

    https://github.com/uoguelph-mlrg/Cutout

    Args:
        n_holes (int, optional): number of patches to cut out
            of each image. Default is 1.
        length (int, optional): length (in pixels) of each square
            patch. Default is 16.
    """

    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): tensor image of size (C, H, W).

        Returns:
            Tensor: image with n_holes of dimension
                length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1:y2, x1:x2] = 0.0

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        return img * mask
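

# Illustrative sketch (not part of the original module): how Cutout above
# clips a square patch centred at (y, x) to the image bounds. For a centre
# inside the image, np.clip(v, 0, h) reduces to the max/min used here. The
# helper name ``_example_cutout_box`` is hypothetical.
def _example_cutout_box(y, x, length, h, w):
    """Return (y1, y2, x1, x2) of the zeroed patch, clipped to the image."""
    half = length // 2
    y1, y2 = max(0, y - half), min(h, y + half)
    x1, x2 = max(0, x - half), min(w, x + half)
    return y1, y2, x1, x2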


class GaussianNoise:
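    """Add Gaussian noise to a tensor image with probability ``p``.

    Args:
        mean (float, optional): mean of the noise. Default is 0.
        std (float, optional): standard deviation of the noise.
            Default is 0.15.
        p (float, optional): probability that noise is added.
            Default is 0.5.
    """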

    def __init__(self, mean=0, std=0.15, p=0.5):
        self.mean = mean
        self.std = std
        self.p = p

    def __call__(self, img):
        if random.uniform(0, 1) > self.p:
            return img
        noise = torch.randn(img.size()) * self.std + self.mean
        return img + noise


def build_transform(cfg, is_train=True, choices=None):
    """Build transformation function.

    Args:
        cfg (CfgNode): config.
        is_train (bool, optional): for training (True) or test (False).
            Default is True.
        choices (list, optional): list of strings which will overwrite
            cfg.INPUT.TRANSFORMS if given. Default is None.
    """
    if cfg.INPUT.NO_TRANSFORM:
        print("Note: no transform is applied!")
        return None

    if choices is None:
        choices = cfg.INPUT.TRANSFORMS

    for choice in choices:
        assert choice in AVAI_CHOICES, f"Unknown transform: {choice}"

    target_size = f"{cfg.INPUT.SIZE[0]}x{cfg.INPUT.SIZE[1]}"

    normalize = Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)

    if is_train:
        return _build_transform_train(cfg, choices, target_size, normalize)
    else:
        return _build_transform_test(cfg, choices, target_size, normalize)


def _build_transform_train(cfg, choices, target_size, normalize):
    print("Building transform_train")
    tfm_train = []

    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]

    # Make sure the image size matches the target size
    conditions = []
    conditions += ["random_crop" not in choices]
    conditions += ["random_resized_crop" not in choices]
    if all(conditions):
        print(f"+ resize to {target_size}")
        tfm_train += [Resize(cfg.INPUT.SIZE, interpolation=interp_mode)]

    if "random_translation" in choices:
        print("+ random translation")
        tfm_train += [
            Random2DTranslation(cfg.INPUT.SIZE[0], cfg.INPUT.SIZE[1])
        ]

    if "random_crop" in choices:
        crop_padding = cfg.INPUT.CROP_PADDING
        print("+ random crop (padding = {})".format(crop_padding))
        tfm_train += [RandomCrop(cfg.INPUT.SIZE, padding=crop_padding)]

    if "random_resized_crop" in choices:
        print(f"+ random resized crop (size={cfg.INPUT.SIZE})")
        tfm_train += [
            RandomResizedCrop(cfg.INPUT.SIZE, interpolation=interp_mode)
        ]

    if "center_crop" in choices:
        print(f"+ center crop (size={cfg.INPUT.SIZE})")
        tfm_train += [CenterCrop(cfg.INPUT.SIZE)]

    if "random_flip" in choices:
        print("+ random flip")
        tfm_train += [RandomHorizontalFlip()]

    if "imagenet_policy" in choices:
        print("+ imagenet policy")
        tfm_train += [ImageNetPolicy()]

    if "cifar10_policy" in choices:
        print("+ cifar10 policy")
        tfm_train += [CIFAR10Policy()]

    if "svhn_policy" in choices:
        print("+ svhn policy")
        tfm_train += [SVHNPolicy()]

    if "randaugment" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        m_ = cfg.INPUT.RANDAUGMENT_M
        print("+ randaugment (n={}, m={})".format(n_, m_))
        tfm_train += [RandAugment(n_, m_)]

    if "randaugment_fixmatch" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        print("+ randaugment_fixmatch (n={})".format(n_))
        tfm_train += [RandAugmentFixMatch(n_)]

    if "randaugment2" in choices:
        n_ = cfg.INPUT.RANDAUGMENT_N
        print("+ randaugment2 (n={})".format(n_))
        tfm_train += [RandAugment2(n_)]

    if "colorjitter" in choices:
        print("+ color jitter")
        tfm_train += [
            ColorJitter(
                brightness=cfg.INPUT.COLORJITTER_B,
                contrast=cfg.INPUT.COLORJITTER_C,
                saturation=cfg.INPUT.COLORJITTER_S,
                hue=cfg.INPUT.COLORJITTER_H,
            )
        ]

    if "randomgrayscale" in choices:
        print("+ random gray scale")
        tfm_train += [RandomGrayscale(p=cfg.INPUT.RGS_P)]

    if "gaussian_blur" in choices:
        print(f"+ gaussian blur (kernel={cfg.INPUT.GB_K})")
        tfm_train += [
            RandomApply([GaussianBlur(cfg.INPUT.GB_K)], p=cfg.INPUT.GB_P)
        ]

    print("+ to torch tensor of range [0, 1]")
    tfm_train += [ToTensor()]

    if "cutout" in choices:
        cutout_n = cfg.INPUT.CUTOUT_N
        cutout_len = cfg.INPUT.CUTOUT_LEN
        print("+ cutout (n_holes={}, length={})".format(cutout_n, cutout_len))
        tfm_train += [Cutout(cutout_n, cutout_len)]

    if "normalize" in choices:
        print(
            "+ normalization (mean={}, "
            "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
        )
        tfm_train += [normalize]

    if "gaussian_noise" in choices:
        print(
            "+ gaussian noise (mean={}, std={})".format(
                cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD
            )
        )
        tfm_train += [GaussianNoise(cfg.INPUT.GN_MEAN, cfg.INPUT.GN_STD)]

    if "instance_norm" in choices:
        print("+ instance normalization")
        tfm_train += [InstanceNormalization()]

    tfm_train = Compose(tfm_train)

    return tfm_train


def _build_transform_test(cfg, choices, target_size, normalize):
    print("Building transform_test")
    tfm_test = []

    interp_mode = INTERPOLATION_MODES[cfg.INPUT.INTERPOLATION]

    print(f"+ resize the smaller edge to {max(cfg.INPUT.SIZE)}")
    tfm_test += [Resize(max(cfg.INPUT.SIZE), interpolation=interp_mode)]

    print(f"+ {target_size} center crop")
    tfm_test += [CenterCrop(cfg.INPUT.SIZE)]

    print("+ to torch tensor of range [0, 1]")
    tfm_test += [ToTensor()]

    if "normalize" in choices:
        print(
            "+ normalization (mean={}, "
            "std={})".format(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD)
        )
        tfm_test += [normalize]

    if "instance_norm" in choices:
        print("+ instance normalization")
        tfm_test += [InstanceNormalization()]

    tfm_test = Compose(tfm_test)

    return tfm_test


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/__init__.py
================================================
from .build import TRAINER_REGISTRY, build_trainer  # isort:skip
from .trainer import TrainerX, TrainerXU, TrainerBase, SimpleTrainer, SimpleNet  # isort:skip

from .da import *
from .dg import *
from .ssl import *


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/build.py
================================================
from dassl.utils import Registry, check_availability

TRAINER_REGISTRY = Registry("TRAINER")


def build_trainer(cfg):
    avai_trainers = TRAINER_REGISTRY.registered_names()
    check_availability(cfg.TRAINER.NAME, avai_trainers)
    if cfg.VERBOSE:
        print("Loading trainer: {}".format(cfg.TRAINER.NAME))
    return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py
================================================
from .mcd import MCD
from .mme import MME
from .adda import ADDA
from .dael import DAEL
from .dann import DANN
from .adabn import AdaBN
from .m3sda import M3SDA
from .source_only import SourceOnly
from .self_ensembling import SelfEnsembling


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
================================================
import torch

from dassl.utils import check_isfile
from dassl.engine import TRAINER_REGISTRY, TrainerXU


@TRAINER_REGISTRY.register()
class AdaBN(TrainerXU):
    """Adaptive Batch Normalization.

    https://arxiv.org/abs/1603.04779.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.done_reset_bn_stats = False

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def before_epoch(self):
        if not self.done_reset_bn_stats:
            for m in self.model.modules():
                classname = m.__class__.__name__
                if classname.find("BatchNorm") != -1:
                    m.reset_running_stats()

            self.done_reset_bn_stats = True

    def forward_backward(self, batch_x, batch_u):
        input_u = batch_u["img"].to(self.device)

        with torch.no_grad():
            self.model(input_u)

        return None


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adda.py
================================================
import copy
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head


@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
    """Adversarial Discriminative Domain Adaptation.

    https://arxiv.org/abs/1702.05464.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.open_layers = ["backbone"]
        if isinstance(self.model.head, nn.Module):
            self.open_layers.append("head")

        self.source_model = copy.deepcopy(self.model)
        self.source_model.eval()
        for param in self.source_model.parameters():
            param.requires_grad_(False)

        self.build_critic()

        self.bce = nn.BCEWithLogitsLoss()

    def check_cfg(self, cfg):
        assert check_isfile(
            cfg.MODEL.INIT_WEIGHTS
        ), "The weights of source model must be provided"

    def build_critic(self):
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim // 2],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim // 2, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)

    def forward_backward(self, batch_x, batch_u):
        open_specified_layers(self.model, self.open_layers)
        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())

        loss_critic = self.bce(logit_xd, domain_x)
        loss_critic += self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")

        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")

        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dael.py
================================================
import torch
import torch.nn as nn

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot


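class Experts(nn.Module):
    """One linear classifier (expert) per source domain.

    ``forward(i, x)`` returns the softmax probabilities of expert ``i``.
    """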

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        x = self.linears[i](x)
        x = self.softmax(x)
        return x


@TRAINER_REGISTRY.register()
class DAEL(TrainerXU):
    """Domain Adaptive Ensemble Learning.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.weight_u = cfg.TRAINER.DAEL.WEIGHT_U
        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch_x, batch_u):
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

        input_x = torch.split(input_x, self.split_batch, 0)
        input_x2 = torch.split(input_x2, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Generate pseudo label
        with torch.no_grad():
            feat_u = self.F(input_u)
            pred_u = []
            for k in range(self.num_source_domains):
                pred_uk = self.E(k, feat_u)
                pred_uk = pred_uk.unsqueeze(1)
                pred_u.append(pred_uk)
            pred_u = torch.cat(pred_u, 1)  # (B, K, C)
            # Get the highest probability and index (label) for each expert
            experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
            # Get the most confident expert
            max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
            # For each sample, take the label predicted by that expert
            pseudo_label_u = experts_max_idx.gather(
                1, max_expert_idx.unsqueeze(1)
            ).squeeze(1)  # (B)
            pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
            pseudo_label_u = pseudo_label_u.to(self.device)
            label_u_mask = (max_expert_p >= self.conf_thre).float()

        loss_x = 0
        loss_cr = 0
        acc_x = 0

        feat_x = [self.F(x) for x in input_x]
        feat_x2 = [self.F(x) for x in input_x2]
        feat_u2 = self.F(input_u2)

        for feat_xi, feat_x2i, label_xi, i in zip(
            feat_x, feat_x2, label_x, domain_x
        ):
            cr_s = [j for j in domain_x if j != i]

            # Learning expert
            pred_xi = self.E(i, feat_xi)
            loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
            expert_label_xi = pred_xi.detach()
            acc_x += compute_accuracy(pred_xi.detach(),
                                      label_xi.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat_x2i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_xi)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc_x /= self.n_domain

        # Unsupervised loss
        pred_u = []
        for k in range(self.num_source_domains):
            pred_uk = self.E(k, feat_u2)
            pred_uk = pred_uk.unsqueeze(1)
            pred_u.append(pred_uk)
        pred_u = torch.cat(pred_u, 1)
        pred_u = pred_u.mean(1)
        l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
        loss_u = (l_u * label_u_mask).mean()

        loss = 0
        loss += loss_x
        loss += loss_cr
        loss += loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": acc_x,
            "loss_cr": loss_cr.item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]

        label_x = create_onehot(label_x, self.num_classes)

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, input_x2, label_x, domain_x, input_u, input_u2

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
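

# Illustrative sketch (not part of the original trainer): the pseudo-labeling
# rule used in DAEL.forward_backward above, written for a single sample with
# plain Python lists. The most confident expert supplies the label, and the
# sample only contributes to the loss if that confidence reaches the
# threshold. The helper name ``_example_pseudo_label`` is hypothetical.
def _example_pseudo_label(expert_probs, conf_thre):
    """expert_probs holds K lists of C class probabilities for one sample.

    Returns (label, keep), where keep mirrors label_u_mask above.
    """
    best_p, best_label = -1.0, None
    for probs in expert_probs:
        p = max(probs)
        if p > best_p:
            best_p, best_label = p, probs.index(p)
    return best_label, best_p >= conf_thre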


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
================================================
import numpy as np
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling import build_head
from dassl.modeling.ops import ReverseGrad


@TRAINER_REGISTRY.register()
class DANN(TrainerXU):
    """Domain-Adversarial Neural Networks.

    https://arxiv.org/abs/1505.07818.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.build_critic()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()

    def build_critic(self):
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=[fdim, fdim],
            activation="leaky_relu",
        )
        self.critic = nn.Sequential(critic_body, nn.Linear(fdim, 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)
        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
        domain_x = torch.ones(input_x.shape[0], 1).to(self.device)
        domain_u = torch.zeros(input_u.shape[0], 1).to(self.device)

        global_step = self.batch_idx + self.epoch * self.num_batches
        progress = global_step / (self.max_epoch * self.num_batches)
        lmda = 2 / (1 + np.exp(-10 * progress)) - 1

        logit_x, feat_x = self.model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        loss_x = self.ce(logit_x, label_x)

        feat_x = self.revgrad(feat_x, grad_scaling=lmda)
        feat_u = self.revgrad(feat_u, grad_scaling=lmda)
        output_xd = self.critic(feat_x)
        output_ud = self.critic(feat_u)
        loss_d = self.bce(output_xd, domain_x) + self.bce(output_ud, domain_u)

        loss = loss_x + loss_d
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_d": loss_d.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
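

# Illustrative sketch (not part of the original trainer): the gradient
# reversal weight computed in DANN.forward_backward above, using math.exp in
# place of np.exp. It grows from 0 at the start of training towards 1 at the
# end. The helper name ``_example_dann_lambda`` is hypothetical.
def _example_dann_lambda(global_step, total_steps):
    """Return lmda = 2 / (1 + exp(-10 * progress)) - 1."""
    import math

    progress = global_step / total_steps
    return 2 / (1 + math.exp(-10 * progress)) - 1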


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet


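class PairClassifiers(nn.Module):
    """Two linear classifiers applied to the same features.

    Returns both logits in training mode and only the first in eval mode.
    """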

    def __init__(self, fdim, num_classes):
        super().__init__()
        self.c1 = nn.Linear(fdim, num_classes)
        self.c2 = nn.Linear(fdim, num_classes)

    def forward(self, x):
        z1 = self.c1(x)
        if not self.training:
            return z1
        z2 = self.c2(x)
        return z1, z2


@TRAINER_REGISTRY.register()
class M3SDA(TrainerXU):
    """Moment Matching for Multi-Source Domain Adaptation.

    https://arxiv.org/abs/1812.01754.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.n_step_F = cfg.TRAINER.M3SDA.N_STEP_F
        self.lmda = cfg.TRAINER.M3SDA.LMDA

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert not cfg.DATALOADER.TRAIN_U.SAME_AS_X

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building C")
        self.C = nn.ModuleList(
            [
                PairClassifiers(fdim, self.num_classes)
                for _ in range(self.num_source_domains)
            ]
        )
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, domain_x, input_u = parsed

        input_x = torch.split(input_x, self.split_batch, 0)
        label_x = torch.split(label_x, self.split_batch, 0)
        domain_x = torch.split(domain_x, self.split_batch, 0)
        domain_x = [d[0].item() for d in domain_x]

        # Step A
        loss_x = 0
        feat_x = []

        for x, y, d in zip(input_x, label_x, domain_x):
            f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            feat_x.append(f)

        loss_x /= self.n_domain

        feat_u = self.F(input_u)
        loss_msda = self.moment_distance(feat_x, feat_u)

        loss_step_A = loss_x + loss_msda * self.lmda
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_u = self.F(input_u)

        loss_x, loss_dis = 0, 0

        for x, y, d in zip(input_x, label_x, domain_x):
            with torch.no_grad():
                f = self.F(x)
            z1, z2 = self.C[d](f)
            loss_x += F.cross_entropy(z1, y) + F.cross_entropy(z2, y)

            z1, z2 = self.C[d](feat_u)
            p1 = F.softmax(z1, 1)
            p2 = F.softmax(z2, 1)
            loss_dis += self.discrepancy(p1, p2)

        loss_x /= self.n_domain
        loss_dis /= self.n_domain

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, "C")

        # Step C
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)

            loss_dis = 0

            for d in domain_x:
                z1, z2 = self.C[d](feat_u)
                p1 = F.softmax(z1, 1)
                p2 = F.softmax(z2, 1)
                loss_dis += self.discrepancy(p1, p2)

            loss_dis /= self.n_domain
            loss_step_C = loss_dis

            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def moment_distance(self, x, u):
        # x (list): a list of feature matrices, one per source domain.
        # u (torch.Tensor): feature matrix of the target domain.
        x_mean = [xi.mean(0) for xi in x]
        u_mean = u.mean(0)
        dist1 = self.pairwise_distance(x_mean, u_mean)

        x_var = [xi.var(0) for xi in x]
        u_var = u.var(0)
        dist2 = self.pairwise_distance(x_var, u_var)

        return (dist1+dist2) / 2

    def pairwise_distance(self, x, u):
        # x (list): a list of feature vectors, one per source domain.
        # u (torch.Tensor): feature vector of the target domain.
        dist = 0
        count = 0

        for xi in x:
            dist += self.euclidean(xi, u)
            count += 1

        for i in range(len(x) - 1):
            for j in range(i + 1, len(x)):
                dist += self.euclidean(x[i], x[j])
                count += 1

        return dist / count

    def euclidean(self, input1, input2):
        return ((input1 - input2)**2).sum().sqrt()

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"]
        label_x = batch_x["label"]
        domain_x = batch_x["domain"]
        input_u = batch_u["img"]

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)

        return input_x, label_x, domain_x, input_u

    def model_inference(self, input):
        f = self.F(input)
        p = 0
        for C_i in self.C:
            z = C_i(f)
            p += F.softmax(z, 1)
        p = p / len(self.C)
        return p
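

# Illustrative sketch (not part of the original trainer): the averaging done
# by M3SDA.pairwise_distance above, written with plain Python lists. Each
# source-target distance and each source-source distance counts once. The
# helper name ``_example_pairwise_distance`` is hypothetical.
def _example_pairwise_distance(x, u):
    """x is a list of source vectors and u is the target vector."""
    import math

    def euclidean(a, b):
        return math.sqrt(sum((ai - bi)**2 for ai, bi in zip(a, b)))

    # Source-to-target distances
    dists = [euclidean(xi, u) for xi in x]
    # Source-to-source distances over all unordered pairs
    for i in range(len(x) - 1):
        for j in range(i + 1, len(x)):
            dists.append(euclidean(x[i], x[j]))
    return sum(dists) / len(dists)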


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class MCD(TrainerXU):
    """Maximum Classifier Discrepancy.

    https://arxiv.org/abs/1712.02560.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.n_step_F = cfg.TRAINER.MCD.N_STEP_F

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building C1")
        self.C1 = nn.Linear(fdim, self.num_classes)
        self.C1.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C1)))
        self.optim_C1 = build_optimizer(self.C1, cfg.OPTIM)
        self.sched_C1 = build_lr_scheduler(self.optim_C1, cfg.OPTIM)
        self.register_model("C1", self.C1, self.optim_C1, self.sched_C1)

        print("Building C2")
        self.C2 = nn.Linear(fdim, self.num_classes)
        self.C2.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C2)))
        self.optim_C2 = build_optimizer(self.C2, cfg.OPTIM)
        self.sched_C2 = build_lr_scheduler(self.optim_C2, cfg.OPTIM)
        self.register_model("C2", self.C2, self.optim_C2, self.sched_C2)

    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed

        # Step A
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)

        # Step B
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2

        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ["C1", "C2"])

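        # Step C: update F alone, n_step_F times, to minimize the discrepancy
        # between the two classifiers on unlabeled data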
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, "F")

        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def discrepancy(self, y1, y2):
        return (y1 - y2).abs().mean()

    def model_inference(self, input):
        feat = self.F(input)
        return self.C1(feat)
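
The discrepancy used in Steps B and C is the mean absolute difference between the two classifiers' softmax outputs. The following is a minimal, framework-free sketch of that quantity, assuming plain Python lists in place of tensors; the `softmax` helper and the example logits are hypothetical and not part of the repository.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def discrepancy(probs1, probs2):
    """Mean absolute difference over every (sample, class) entry."""
    diffs = [
        abs(a - b)
        for row1, row2 in zip(probs1, probs2)
        for a, b in zip(row1, row2)
    ]
    return sum(diffs) / len(diffs)


# Hypothetical logits from two classifier heads for a batch of two samples.
logits_c1 = [[2.0, 0.0], [0.0, 0.0]]
logits_c2 = [[0.0, 2.0], [0.0, 0.0]]

pred1 = [softmax(row) for row in logits_c1]
pred2 = [softmax(row) for row in logits_c2]

# The heads disagree on the first sample and agree on the second.
loss_dis = discrepancy(pred1, pred2)
print(loss_dis)
```

In Step B this value is subtracted from the supervised loss (the classifiers maximize it), while in Step C it is minimized with respect to the feature extractor only.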


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import ReverseGrad
from dassl.engine.trainer import SimpleNet


class Prototypes(nn.Module):

    def __init__(self, fdim, num_classes, temp=0.05):
        super().__init__()
        self.prototypes = nn.Linear(fdim, num_classes, bias=False)
        self.temp = temp

    def forward(self, x):
        x = F.normalize(x, p=2, dim=1)
        out = self.prototypes(x)
        out = out / self.temp
        return out


@TRAINER_REGISTRY.register()
class MME(TrainerXU):
    """Minimax Entropy.

    https://arxiv.org/abs/1904.06487.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.lmda = cfg.TRAINER.MME.LMDA

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building C")
        self.C = Prototypes(self.F.fdim, self.num_classes)
        self.C.to(self.device)
        print("# params: {:,}".format(count_num_param(self.C)))
        self.optim_C = build_optimizer(self.C, cfg.OPTIM)
        self.sched_C = build_lr_scheduler(self.optim_C, cfg.OPTIM)
        self.register_model("C", self.C, self.optim_C, self.sched_C)

        self.revgrad = ReverseGrad()

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)

        # Supervised step: cross-entropy on labeled data updates both F and C
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)
        self.model_backward_and_update(loss_x)

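        # Entropy step: loss_u is the negative entropy of the predictions on
        # unlabeled data, so C is updated to maximize entropy, while the
        # gradient reversal layer makes F minimize it
        feat_u = self.F(input_u)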
        feat_u = self.revgrad(feat_u)
        logit_u = self.C(feat_u)
        prob_u = F.softmax(logit_u, 1)
        loss_u = -(-prob_u * torch.log(prob_u + 1e-5)).sum(1).mean()
        self.model_backward_and_update(loss_u * self.lmda)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.C(self.F(input))
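
To make the two building blocks of this trainer concrete, here is a small sketch in plain Python, assuming lists instead of tensors. `prototype_logits` mirrors what `Prototypes.forward` computes (L2-normalize, bias-free linear map, divide by the temperature) and `neg_entropy` mirrors the expression used for `loss_u` for a single sample. Both function names and the example values are hypothetical.

```python
import math


def prototype_logits(feature, prototypes, temp=0.05):
    """Scaled similarity between a unit-norm feature and each class prototype."""
    norm = math.sqrt(sum(v * v for v in feature))
    norm = max(norm, 1e-12)  # avoid dividing by zero for an all-zero feature
    unit = [v / norm for v in feature]
    return [sum(u * w for u, w in zip(unit, row)) / temp for row in prototypes]


def neg_entropy(probs, eps=1e-5):
    """Negative entropy of one probability vector, using the same 1e-5 offset."""
    return -sum(-p * math.log(p + eps) for p in probs)


# Hypothetical 2-D feature and two class prototypes (one weight row per class).
feature = [3.0, 4.0]
prototypes = [[1.0, 0.0], [0.0, 1.0]]
logits = prototype_logits(feature, prototypes)

# A uniform prediction has high entropy, so its negative entropy is lower
# than that of a confident (one-hot) prediction.
loss_uniform = neg_entropy([0.5, 0.5])
loss_confident = neg_entropy([1.0, 0.0])
print(logits, loss_uniform, loss_confident)
```

Because the loss is the negative entropy, a plain gradient step on the classifier pushes predictions towards higher entropy; the reversed gradient reaching the feature extractor pushes them the other way.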


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
================================================
import copy
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.modeling.ops.utils import sigmoid_rampup, ema_model_update


@TRAINER_REGISTRY.register()
class SelfEnsembling(TrainerXU):
    """Self-ensembling for visual domain adaptation.

    https://arxiv.org/abs/1706.05208.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.ema_alpha = cfg.TRAINER.SE.EMA_ALPHA
        self.conf_thre = cfg.TRAINER.SE.CONF_THRE
        self.rampup = cfg.TRAINER.SE.RAMPUP

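        # The teacher starts as a copy of the student; its parameters receive
        # no gradients and change only through the EMA update
        self.teacher = copy.deepcopy(self.model)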
        self.teacher.train()
        for param in self.teacher.parameters():
            param.requires_grad_(False)

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.K_TRANSFORMS == 2

    def forward_backward(self, batch_x, batch_u):
        global_step = self.batch_idx + self.epoch * self.num_batches
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u1, input_u2 = parsed

        logit_x = self.model(input_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        prob_u = F.softmax(self.model(input_u1), 1)
        t_prob_u = F.softmax(self.teacher(input_u2), 1)
        loss_u = ((prob_u - t_prob_u)**2).sum(1)

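        # Either keep only samples where the teacher is confident, or
        # scale the consistency loss with a sigmoid ramp-up weight
        if self.conf_thre: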
            max_prob = t_prob_u.max(1)[0]
            mask = (max_prob > self.conf_thre).float()
            loss_u = (loss_u * mask).mean()
        else:
            weight_u = sigmoid_rampup(global_step, self.rampup)
            loss_u = loss_u.mean() * weight_u

        loss = loss_x + loss_u
        self.model_backward_and_update(loss)

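        # Cap the EMA decay at self.ema_alpha; early steps use a smaller
        # value so the teacher follows the student closely at the start
        ema_alpha = min(1 - 1 / (global_step+1), self.ema_alpha)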
        ema_model_update(self.model, self.teacher, ema_alpha)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input_x = batch_x["img"][0]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u1, input_u2 = input_u

        input_x = input_x.to(self.device)
        label_x = label_x.to(self.device)
        input_u1 = input_u1.to(self.device)
        input_u2 = input_u2.to(self.device)

        return input_x, label_x, input_u1, input_u2
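
The decay schedule in `forward_backward` can be illustrated without any framework. The sketch below assumes that `ema_model_update` performs the usual exponential moving average, `teacher = alpha * teacher + (1 - alpha) * student`; that helper is imported from `dassl.modeling.ops.utils` and is not shown in this file, so the update rule here is an assumption. The names `ema_alpha_at` and `ema_update` and the value `0.99` are hypothetical.

```python
def ema_alpha_at(global_step, ema_alpha):
    """Decay used at a given step: grows from 0 and is capped at ema_alpha."""
    return min(1 - 1 / (global_step + 1), ema_alpha)


def ema_update(student, teacher, alpha):
    """Assumed EMA rule, applied elementwise to flat parameter lists."""
    return [alpha * t + (1 - alpha) * s for s, t in zip(student, teacher)]


cap = 0.99  # hypothetical value for cfg.TRAINER.SE.EMA_ALPHA

# At step 0 the decay is 0, so the teacher simply copies the student.
teacher = ema_update([1.0, 2.0], [0.0, 0.0], ema_alpha_at(0, cap))

# Later steps blend the two, and the decay never exceeds the cap.
for step in (0, 1, 9, 999):
    print(step, ema_alpha_at(step, cap))
```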


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
================================================
from torch.nn import functional as F

from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy


@TRAINER_REGISTRY.register()
class SourceOnly(TrainerXU):
    """Baseline model for domain adaptation, which is
    trained using source data only.
    """

    def forward_backward(self, batch_x, batch_u):
        input, label = self.parse_batch_train(batch_x, batch_u)
        output = self.model(input)
        loss = F.cross_entropy(output, label)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        input = batch_x["img"]
        label = batch_x["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py
================================================
from .ddaig import DDAIG
from .daeldg import DAELDG
from .vanilla import Vanilla
from .crossgrad import CrossGrad


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
================================================
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class CrossGrad(TrainerX):
    """Cross-gradient training.

    https://arxiv.org/abs/1804.10745.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.eps_f = cfg.TRAINER.CG.EPS_F
        self.eps_d = cfg.TRAINER.CG.EPS_D
        self.alpha_f = cfg.TRAINER.CG.ALPHA_F
        self.alpha_d = cfg.TRAINER.CG.ALPHA_D

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, self.num_classes)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)

        print("Building D")
        self.D = SimpleNet(cfg, cfg.MODEL, self.num_source_domains)
        self.D.to(self.device)
        print("# params: {:,}".format(count_num_param(self.D)))
        self.optim_D = build_optimizer(self.D, cfg.OPTIM)
        self.sched_D = build_lr_scheduler(self.optim_D, cfg.OPTIM)
        self.register_model("D", self.D, self.optim_D, self.sched_D)

    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        input.requires_grad = True

        # Compute domain perturbation
        loss_d = F.cross_entropy(self.D(input), domain)
        loss_d.backward()
        grad_d = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_d = input.data + self.eps_f * grad_d

        # Compute label perturbation
        input.grad.data.zero_()
        loss_f = F.cross_entropy(self.F(input), label)
        loss_f.backward()
        grad_f = torch.clamp(input.grad.data, min=-0.1, max=0.1)
        input_f = input.data + self.eps_d * grad_f

        input = input.detach()

        # Update label net
        loss_f1 = F.cross_entropy(self.F(input), label)
        loss_f2 = F.cross_entropy(self.F(input_d), label)
        loss_f = (1 - self.alpha_f) * loss_f1 + self.alpha_f * loss_f2
        self.model_backward_and_update(loss_f, "F")

        # Update domain net
        loss_d1 = F.cross_entropy(self.D(input), domain)
        loss_d2 = F.cross_entropy(self.D(input_f), domain)
        loss_d = (1 - self.alpha_d) * loss_d1 + self.alpha_d * loss_d2
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {"loss_f": loss_f.item(), "loss_d": loss_d.item()}

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def model_inference(self, input):
        return self.F(input)
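
The perturbation step above clips the input gradient to the range [-0.1, 0.1] and adds a scaled copy of it to the input. A minimal sketch of that arithmetic on a flat list of values follows, assuming the gradient has already been computed; the helper names and the numbers are hypothetical.

```python
def clamp(values, lo=-0.1, hi=0.1):
    """Clip every entry to [lo, hi], like torch.clamp with min and max."""
    return [min(max(v, lo), hi) for v in values]


def perturb(x, grad, eps):
    """Return x + eps * clamp(grad), computed elementwise."""
    return [xi + eps * gi for xi, gi in zip(x, clamp(grad))]


# Hypothetical input values and gradient of a loss with respect to them.
x = [0.5, -0.2, 0.0]
grad = [0.3, -0.05, -2.0]

# Large gradient entries are clipped before the step is taken.
x_perturbed = perturb(x, grad, eps=1.0)
print(x_perturbed)
```

In the trainer, the gradient of the domain loss produces the perturbed input fed to the label network `F`, and the gradient of the label loss produces the perturbed input fed to the domain network `D`.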


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
================================================
import torch
import torch.nn as nn

from dassl.data import DataManager
from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.engine.trainer import SimpleNet
from dassl.data.transforms import build_transform
from dassl.modeling.ops.utils import create_onehot


class Experts(nn.Module):

    def __init__(self, n_source, fdim, num_classes):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(fdim, num_classes) for _ in range(n_source)]
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, i, x):
        x = self.linears[i](x)
        x = self.softmax(x)
        return x


@TRAINER_REGISTRY.register()
class DAELDG(TrainerX):
    """Domain Adaptive Ensemble Learning.

    DG version: only use labeled source data.

    https://arxiv.org/abs/2003.07325.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        n_domain = cfg.DATALOADER.TRAIN_X.N_DOMAIN
        batch_size = cfg.DATALOADER.TRAIN_X.BATCH_SIZE
        if n_domain <= 0:
            n_domain = self.num_source_domains
        self.split_batch = batch_size // n_domain
        self.n_domain = n_domain

        self.conf_thre = cfg.TRAINER.DAEL.CONF_THRE

    def check_cfg(self, cfg):
        assert cfg.DATALOADER.TRAIN_X.SAMPLER == "RandomDomainSampler"
        assert len(cfg.TRAINER.DAEL.STRONG_TRANSFORMS) > 0

    def build_data_loader(self):
        cfg = self.cfg
        tfm_train = build_transform(cfg, is_train=True)
        custom_tfm_train = [tfm_train]
        choices = cfg.TRAINER.DAEL.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train += [tfm_train_strong]
        dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = dm.train_loader_x
        self.train_loader_u = dm.train_loader_u
        self.val_loader = dm.val_loader
        self.test_loader = dm.test_loader
        self.num_classes = dm.num_classes
        self.num_source_domains = dm.num_source_domains
        self.lab2cname = dm.lab2cname

    def build_model(self):
        cfg = self.cfg

        print("Building F")
        self.F = SimpleNet(cfg, cfg.MODEL, 0)
        self.F.to(self.device)
        print("# params: {:,}".format(count_num_param(self.F)))
        self.optim_F = build_optimizer(self.F, cfg.OPTIM)
        self.sched_F = build_lr_scheduler(self.optim_F, cfg.OPTIM)
        self.register_model("F", self.F, self.optim_F, self.sched_F)
        fdim = self.F.fdim

        print("Building E")
        self.E = Experts(self.num_source_domains, fdim, self.num_classes)
        self.E.to(self.device)
        print("# params: {:,}".format(count_num_param(self.E)))
        self.optim_E = build_optimizer(self.E, cfg.OPTIM)
        self.sched_E = build_lr_scheduler(self.optim_E, cfg.OPTIM)
        self.register_model("E", self.E, self.optim_E, self.sched_E)

    def forward_backward(self, batch):
        parsed_data = self.parse_batch_train(batch)
        input, input2, label, domain = parsed_data

        input = torch.split(input, self.split_batch, 0)
        input2 = torch.split(input2, self.split_batch, 0)
        label = torch.split(label, self.split_batch, 0)
        domain = torch.split(domain, self.split_batch, 0)
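        # Each chunk is expected to hold samples from a single domain (see the
        # RandomDomainSampler check), so its first entry identifies the domain
        domain = [d[0].item() for d in domain]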

        loss_x = 0
        loss_cr = 0
        acc = 0

        feat = [self.F(x) for x in input]
        feat2 = [self.F(x) for x in input2]

        for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
            cr_s = [j for j in domain if j != i]

            # Learning expert
            pred_i = self.E(i, feat_i)
            loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
            expert_label_i = pred_i.detach()
            acc += compute_accuracy(pred_i.detach(),
                                    label_i.max(1)[1])[0].item()

            # Consistency regularization
            cr_pred = []
            for j in cr_s:
                pred_j = self.E(j, feat2_i)
                pred_j = pred_j.unsqueeze(1)
                cr_pred.append(pred_j)
            cr_pred = torch.cat(cr_pred, 1)
            cr_pred = cr_pred.mean(1)
            loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()

        loss_x /= self.n_domain
        loss_cr /= self.n_domain
        acc /= self.n_domain

        loss = 0
        loss += loss_x
        loss += loss_cr
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc": acc,
            "loss_cr": loss_cr.item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        input = batch["img"]
        input2 = batch["img2"]
        label = batch["label"]
        domain = batch["domain"]

        label = create_onehot(label, self.num_classes)

        input = input.to(self.device)
        input2 = input2.to(self.device)
        label = label.to(self.device)

        return input, input2, label, domain

    def model_inference(self, input):
        f = self.F(input)
        p = []
        for k in range(self.num_source_domains):
            p_k = self.E(k, f)
            p_k = p_k.unsqueeze(1)
            p.append(p_k)
        p = torch.cat(p, 1)
        p = p.mean(1)
        return p
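
The consistency regularization term in `forward_backward` compares the averaged prediction of the non-expert heads with the (detached) prediction of the domain's own expert. The following sketch computes that term for a single sample with plain Python lists; the trainer additionally averages over the samples in each chunk and over domains. The function names and probabilities are hypothetical.

```python
def mean_prediction(preds):
    """Average several probability vectors class by class."""
    n = len(preds)
    return [sum(col) / n for col in zip(*preds)]


def consistency_loss(non_expert_preds, expert_pred):
    """Squared error between the ensemble average and the expert's target."""
    avg = mean_prediction(non_expert_preds)
    return sum((a - e) ** 2 for a, e in zip(avg, expert_pred))


# Hypothetical predictions for one sample from a three-domain setup:
# the expert of the sample's own domain and the two other experts.
expert_pred = [0.9, 0.1]
non_expert_preds = [[0.6, 0.4], [0.8, 0.2]]

loss_cr = consistency_loss(non_expert_preds, expert_pred)
print(loss_cr)
```

The same averaging over experts is what `model_inference` uses at test time, where all heads contribute to the final prediction.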


================================================
FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
================================================
import torch
from torch.nn import functional as F

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import count_num_param
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.modeling import build_network
from dassl.engine.trainer import SimpleNet


@TRAINER_REGISTRY.register()
class DDAIG(TrainerX):
    """Deep Domain
SYMBOL INDEX (911 symbols across 122 files)

FILE: Dassl.ProGrad.pytorch/dassl/config/__init__.py
  function get_cfg_default (line 4) | def get_cfg_default():

FILE: Dassl.ProGrad.pytorch/dassl/data/data_manager.py
  function build_data_loader (line 19) | def build_data_loader(
  class DataManager (line 57) | class DataManager:
    method __init__ (line 59) | def __init__(
    method num_classes (line 160) | def num_classes(self):
    method num_source_domains (line 164) | def num_source_domains(self):
    method lab2cname (line 168) | def lab2cname(self):
    method show_dataset_summary (line 171) | def show_dataset_summary(self, cfg):
  class DatasetWrapper (line 194) | class DatasetWrapper(TorchDataset):
    method __init__ (line 196) | def __init__(self, cfg, data_source, transform=None, is_train=False):
    method __len__ (line 223) | def __len__(self):
    method __getitem__ (line 226) | def __getitem__(self, idx):
    method _transform_image (line 254) | def _transform_image(self, tfm, img0):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py
  class Datum (line 12) | class Datum:
    method __init__ (line 22) | def __init__(self, impath="", label=0, domain=0, classname=""):
    method impath (line 32) | def impath(self):
    method label (line 36) | def label(self):
    method domain (line 40) | def domain(self):
    method classname (line 44) | def classname(self):
  class DatasetBase (line 48) | class DatasetBase:
    method __init__ (line 58) | def __init__(self, train_x=None, train_u=None, val=None, test=None):
    method train_x (line 68) | def train_x(self):
    method train_u (line 72) | def train_u(self):
    method val (line 76) | def val(self):
    method test (line 80) | def test(self):
    method lab2cname (line 84) | def lab2cname(self):
    method classnames (line 88) | def classnames(self):
    method num_classes (line 92) | def num_classes(self):
    method get_num_classes (line 95) | def get_num_classes(self, data_source):
    method get_lab2cname (line 106) | def get_lab2cname(self, data_source):
    method check_input_domains (line 121) | def check_input_domains(self, source_domains, target_domains):
    method is_input_domain_valid (line 125) | def is_input_domain_valid(self, input_domains):
    method download_data (line 133) | def download_data(self, url, dst, from_gdrive=True):
    method generate_fewshot_dataset (line 155) | def generate_fewshot_dataset(
    method split_dataset_by_label (line 199) | def split_dataset_by_label(self, data_source):
    method split_dataset_by_domain (line 213) | def split_dataset_by_domain(self, data_source):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/build.py
  function build_dataset (line 6) | def build_dataset(cfg):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py
  class CIFARSTL (line 10) | class CIFARSTL(DatasetBase):
    method __init__ (line 37) | def __init__(self, cfg):
    method _read_data (line 51) | def _read_data(self, input_domains, split="train"):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py
  function read_image_list (line 17) | def read_image_list(im_dir, n_max=None, n_repeat=None):
  function load_mnist (line 35) | def load_mnist(dataset_dir, split="train"):
  function load_mnist_m (line 41) | def load_mnist_m(dataset_dir, split="train"):
  function load_svhn (line 47) | def load_svhn(dataset_dir, split="train"):
  function load_syn (line 53) | def load_syn(dataset_dir, split="train"):
  function load_usps (line 59) | def load_usps(dataset_dir, split="train"):
  class Digit5 (line 66) | class Digit5(DatasetBase):
    method __init__ (line 93) | def __init__(self, cfg):
    method _read_data (line 107) | def _read_data(self, input_domains, split="train"):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py
  class DomainNet (line 8) | class DomainNet(DatasetBase):
    method __init__ (line 30) | def __init__(self, cfg):
    method _read_data (line 46) | def _read_data(self, input_domains, split="train"):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/mini_domainnet.py
  class miniDomainNet (line 8) | class miniDomainNet(DatasetBase):
    method __init__ (line 20) | def __init__(self, cfg):
    method _read_data (line 35) | def _read_data(self, input_domains, split="train"):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py
  class Office31 (line 10) | class Office31(DatasetBase):
    method __init__ (line 27) | def __init__(self, cfg):
    method _read_data (line 41) | def _read_data(self, input_domains):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py
  class OfficeHome (line 10) | class OfficeHome(DatasetBase):
    method __init__ (line 27) | def __init__(self, cfg):
    method _read_data (line 41) | def _read_data(self, input_domains):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py
  class VisDA17 (line 8) | class VisDA17(DatasetBase):
    method __init__ (line 23) | def __init__(self, cfg):
    method _read_data (line 37) | def _read_data(self, dname):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py
  class CIFAR10C (line 32) | class CIFAR10C(DatasetBase):
    method __init__ (line 49) | def __init__(self, cfg):
    method _read_data (line 87) | def _read_data(self, data_dir):
  class CIFAR100C (line 105) | class CIFAR100C(CIFAR10C):
    method __init__ (line 122) | def __init__(self, cfg):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py
  function read_image_list (line 16) | def read_image_list(im_dir, n_max=None, n_repeat=None):
  function load_mnist (line 36) | def load_mnist(dataset_dir, split="train"):
  function load_mnist_m (line 42) | def load_mnist_m(dataset_dir, split="train"):
  function load_svhn (line 48) | def load_svhn(dataset_dir, split="train"):
  function load_syn (line 54) | def load_syn(dataset_dir, split="train"):
  function load_usps (line 60) | def load_usps(dataset_dir, split="train"):
  class DigitSingle (line 66) | class DigitSingle(DatasetBase):
    method __init__ (line 98) | def __init__(self, cfg):
    method _read_data (line 112) | def _read_data(self, input_domains, split="train"):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py
  class DigitsDG (line 11) | class DigitsDG(DatasetBase):
    method __init__ (line 35) | def __init__(self, cfg):
    method read_data (line 60) | def read_data(dataset_dir, input_domains, split):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/office_home_dg.py
  class OfficeHomeDG (line 9) | class OfficeHomeDG(DatasetBase):
    method __init__ (line 27) | def __init__(self, cfg):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py
  class PACS (line 8) | class PACS(DatasetBase):
    method __init__ (line 28) | def __init__(self, cfg):
    method _read_data (line 48) | def _read_data(self, input_domains, split):
    method _read_split_pacs (line 79) | def _read_split_pacs(self, split_file):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py
  class VLCS (line 11) | class VLCS(DatasetBase):
    method __init__ (line 26) | def __init__(self, cfg):
    method _read_data (line 44) | def _read_data(self, input_domains, split):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py
  class CIFAR10 (line 12) | class CIFAR10(DatasetBase):
    method __init__ (line 22) | def __init__(self, cfg):
    method _read_data_train (line 43) | def _read_data_train(self, data_dir, num_labeled, val_percent):
    method _read_data_test (line 79) | def _read_data_test(self, data_dir):
  class CIFAR100 (line 97) | class CIFAR100(CIFAR10):
    method __init__ (line 107) | def __init__(self, cfg):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py
  class STL10 (line 11) | class STL10(DatasetBase):
    method __init__ (line 28) | def __init__(self, cfg):
    method _read_data_train (line 52) | def _read_data_train(self, data_dir, fold, fold_file):
    method _read_data_all (line 73) | def _read_data_all(self, data_dir):

FILE: Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py
  class SVHN (line 6) | class SVHN(CIFAR10):
    method __init__ (line 16) | def __init__(self, cfg):

FILE: Dassl.ProGrad.pytorch/dassl/data/samplers.py
  class RandomDomainSampler (line 8) | class RandomDomainSampler(Sampler):
    method __init__ (line 18) | def __init__(self, data_source, batch_size, n_domain):
    method __iter__ (line 38) | def __iter__(self):
    method __len__ (line 60) | def __len__(self):
  class SeqDomainSampler (line 64) | class SeqDomainSampler(Sampler):
    method __init__ (line 73) | def __init__(self, data_source, batch_size):
    method __iter__ (line 93) | def __iter__(self):
    method __len__ (line 113) | def __len__(self):
  class RandomClassSampler (line 117) | class RandomClassSampler(Sampler):
    method __init__ (line 129) | def __init__(self, data_source, batch_size, n_ins):
    method __iter__ (line 149) | def __iter__(self):
    method __len__ (line 177) | def __len__(self):
  function build_sampler (line 181) | def build_sampler(

FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py
  class ImageNetPolicy (line 9) | class ImageNetPolicy:
    method __init__ (line 23) | def __init__(self, fillcolor=(128, 128, 128)):
    method __call__ (line 52) | def __call__(self, img):
    method __repr__ (line 56) | def __repr__(self):
  class CIFAR10Policy (line 60) | class CIFAR10Policy:
    method __init__ (line 74) | def __init__(self, fillcolor=(128, 128, 128)):
    method __call__ (line 103) | def __call__(self, img):
    method __repr__ (line 107) | def __repr__(self):
  class SVHNPolicy (line 111) | class SVHNPolicy:
    method __init__ (line 125) | def __init__(self, fillcolor=(128, 128, 128)):
    method __call__ (line 154) | def __call__(self, img):
    method __repr__ (line 158) | def __repr__(self):
  class SubPolicy (line 162) | class SubPolicy(object):
    method __init__ (line 164) | def __init__(
    method __call__ (line 268) | def __call__(self, img):

FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py
  function ShearX (line 16) | def ShearX(img, v):
  function ShearY (line 23) | def ShearY(img, v):
  function TranslateX (line 30) | def TranslateX(img, v):
  function TranslateXabs (line 39) | def TranslateXabs(img, v):
  function TranslateY (line 47) | def TranslateY(img, v):
  function TranslateYabs (line 56) | def TranslateYabs(img, v):
  function Rotate (line 64) | def Rotate(img, v):
  function AutoContrast (line 71) | def AutoContrast(img, _):
  function Invert (line 75) | def Invert(img, _):
  function Equalize (line 79) | def Equalize(img, _):
  function Flip (line 83) | def Flip(img, _):
  function Solarize (line 87) | def Solarize(img, v):
  function SolarizeAdd (line 92) | def SolarizeAdd(img, addition=0, threshold=128):
  function Posterize (line 101) | def Posterize(img, v):
  function Contrast (line 107) | def Contrast(img, v):
  function Color (line 112) | def Color(img, v):
  function Brightness (line 117) | def Brightness(img, v):
  function Sharpness (line 122) | def Sharpness(img, v):
  function Cutout (line 127) | def Cutout(img, v):
  function CutoutAbs (line 137) | def CutoutAbs(img, v):
  function SamplePairing (line 159) | def SamplePairing(imgs):
  function Identity (line 169) | def Identity(img, v):
  class Lighting (line 173) | class Lighting:
    method __init__ (line 176) | def __init__(self, alphastd, eigval, eigvec):
    method __call__ (line 181) | def __call__(self, img):
  class CutoutDefault (line 195) | class CutoutDefault:
    method __init__ (line 200) | def __init__(self, length):
    method __call__ (line 203) | def __call__(self, img):
  function randaugment_list (line 221) | def randaugment_list():
  function randaugment_list2 (line 267) | def randaugment_list2():
  function fixmatch_list (line 289) | def fixmatch_list():
  class RandAugment (line 311) | class RandAugment:
    method __init__ (line 313) | def __init__(self, n=2, m=10):
    method __call__ (line 319) | def __call__(self, img):
  class RandAugment2 (line 329) | class RandAugment2:
    method __init__ (line 331) | def __init__(self, n=2, p=0.6):
    method __call__ (line 336) | def __call__(self, img):
  class RandAugmentFixMatch (line 349) | class RandAugmentFixMatch:
    method __init__ (line 351) | def __init__(self, n=2):
    method __call__ (line 355) | def __call__(self, img):

FILE: Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py
  class Random2DTranslation (line 42) | class Random2DTranslation:
    method __init__ (line 55) | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
    method __call__ (line 61) | def __call__(self, img):
  class InstanceNormalization (line 80) | class InstanceNormalization:
    method __init__ (line 90) | def __init__(self, eps=1e-8):
    method __call__ (line 93) | def __call__(self, img):
  class Cutout (line 101) | class Cutout:
    method __init__ (line 113) | def __init__(self, n_holes=1, length=16):
    method __call__ (line 117) | def __call__(self, img):
  class GaussianNoise (line 147) | class GaussianNoise:
    method __init__ (line 150) | def __init__(self, mean=0, std=0.15, p=0.5):
    method __call__ (line 155) | def __call__(self, img):
  function build_transform (line 162) | def build_transform(cfg, is_train=True, choices=None):
  function _build_transform_train (line 192) | def _build_transform_train(cfg, choices, target_size, normalize):
  function _build_transform_test (line 313) | def _build_transform_test(cfg, choices, target_size, normalize):

FILE: Dassl.ProGrad.pytorch/dassl/engine/build.py
  function build_trainer (line 6) | def build_trainer(cfg):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py
  class AdaBN (line 8) | class AdaBN(TrainerXU):
    method __init__ (line 14) | def __init__(self, cfg):
    method check_cfg (line 18) | def check_cfg(self, cfg):
    method before_epoch (line 23) | def before_epoch(self):
    method forward_backward (line 32) | def forward_backward(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/adda.py
  class ADDA (line 12) | class ADDA(TrainerXU):
    method __init__ (line 18) | def __init__(self, cfg):
    method check_cfg (line 33) | def check_cfg(self, cfg):
    method build_critic (line 38) | def build_critic(self):
    method forward_backward (line 57) | def forward_backward(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dael.py
  class Experts (line 14) | class Experts(nn.Module):
    method __init__ (line 16) | def __init__(self, n_source, fdim, num_classes):
    method forward (line 23) | def forward(self, i, x):
  class DAEL (line 30) | class DAEL(TrainerXU):
    method __init__ (line 36) | def __init__(self, cfg):
    method check_cfg (line 48) | def check_cfg(self, cfg):
    method build_data_loader (line 53) | def build_data_loader(self):
    method build_model (line 69) | def build_model(self):
    method forward_backward (line 89) | def forward_backward(self, batch_x, batch_u):
    method parse_batch_train (line 183) | def parse_batch_train(self, batch_x, batch_u):
    method model_inference (line 201) | def model_inference(self, input):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/dann.py
  class DANN (line 14) | class DANN(TrainerXU):
    method __init__ (line 20) | def __init__(self, cfg):
    method build_critic (line 26) | def build_critic(self):
    method forward_backward (line 46) | def forward_backward(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py
  class PairClassifiers (line 11) | class PairClassifiers(nn.Module):
    method __init__ (line 13) | def __init__(self, fdim, num_classes):
    method forward (line 18) | def forward(self, x):
  class M3SDA (line 27) | class M3SDA(TrainerXU):
    method __init__ (line 33) | def __init__(self, cfg):
    method check_cfg (line 45) | def check_cfg(self, cfg):
    method build_model (line 49) | def build_model(self):
    method forward_backward (line 74) | def forward_backward(self, batch_x, batch_u):
    method moment_distance (line 153) | def moment_distance(self, x, u):
    method pairwise_distance (line 166) | def pairwise_distance(self, x, u):
    method euclidean (line 183) | def euclidean(self, input1, input2):
    method discrepancy (line 186) | def discrepancy(self, y1, y2):
    method parse_batch_train (line 189) | def parse_batch_train(self, batch_x, batch_u):
    method model_inference (line 201) | def model_inference(self, input):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py
  class MCD (line 12) | class MCD(TrainerXU):
    method __init__ (line 18) | def __init__(self, cfg):
    method build_model (line 22) | def build_model(self):
    method forward_backward (line 50) | def forward_backward(self, batch_x, batch_u):
    method discrepancy (line 100) | def discrepancy(self, y1, y2):
    method model_inference (line 103) | def model_inference(self, input):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/mme.py
  class Prototypes (line 13) | class Prototypes(nn.Module):
    method __init__ (line 15) | def __init__(self, fdim, num_classes, temp=0.05):
    method forward (line 20) | def forward(self, x):
  class MME (line 28) | class MME(TrainerXU):
    method __init__ (line 34) | def __init__(self, cfg):
    method build_model (line 38) | def build_model(self):
    method forward_backward (line 59) | def forward_backward(self, batch_x, batch_u):
    method model_inference (line 85) | def model_inference(self, input):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py
  class SelfEnsembling (line 10) | class SelfEnsembling(TrainerXU):
    method __init__ (line 16) | def __init__(self, cfg):
    method check_cfg (line 27) | def check_cfg(self, cfg):
    method forward_backward (line 30) | def forward_backward(self, batch_x, batch_u):
    method parse_batch_train (line 67) | def parse_batch_train(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py
  class SourceOnly (line 8) | class SourceOnly(TrainerXU):
    method forward_backward (line 13) | def forward_backward(self, batch_x, batch_u):
    method parse_batch_train (line 29) | def parse_batch_train(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py
  class CrossGrad (line 11) | class CrossGrad(TrainerX):
    method __init__ (line 17) | def __init__(self, cfg):
    method build_model (line 24) | def build_model(self):
    method forward_backward (line 43) | def forward_backward(self, batch):
    method model_inference (line 82) | def model_inference(self, input):

FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py
  class Experts (line 14) | class Experts(nn.Module):
    method __init__ (line 16) | def __init__(self, n_source, fdim, num_classes):
    method forward (line 23) | def forward(self, i, x):
  class DAELDG (line 30) | class DAELDG(TrainerX):
    method __init__ (line 38) | def __init__(self, cfg):
    method check_cfg (line 49) | def check_cfg(self, cfg):
    method build_data_loader (line 53) | def build_data_loader(self):
    method build_model (line 69) | def build_model(self):
    method forward_backward (line 89) | def forward_backward(self, batch):
    method parse_batch_train (line 146) | def parse_batch_train(self, batch):
    method model_inference (line 160) | def model_inference(self, input):

FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py
  class DDAIG (line 12) | class DDAIG(TrainerX):
    method __init__ (line 18) | def __init__(self, cfg):
    method build_model (line 27) | def build_model(self):
    method forward_backward (line 54) | def forward_backward(self, batch):
    method model_inference (line 106) | def model_inference(self, input):

FILE: Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py
  class Vanilla (line 8) | class Vanilla(TrainerX):
    method forward_backward (line 11) | def forward_backward(self, batch):
    method parse_batch_train (line 27) | def parse_batch_train(self, batch):

FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py
  class EntMin (line 9) | class EntMin(TrainerXU):
    method __init__ (line 15) | def __init__(self, cfg):
    method forward_backward (line 19) | def forward_backward(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py
  class FixMatch (line 11) | class FixMatch(TrainerXU):
    method __init__ (line 18) | def __init__(self, cfg):
    method check_cfg (line 23) | def check_cfg(self, cfg):
    method build_data_loader (line 26) | def build_data_loader(self):
    method assess_y_pred_quality (line 40) | def assess_y_pred_quality(self, y_pred, y_true, mask):
    method forward_backward (line 52) | def forward_backward(self, batch_x, batch_u):
    method parse_batch_train (line 96) | def parse_batch_train(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py
  class MeanTeacher (line 10) | class MeanTeacher(TrainerXU):
    method __init__ (line 16) | def __init__(self, cfg):
    method forward_backward (line 27) | def forward_backward(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py
  class MixMatch (line 12) | class MixMatch(TrainerXU):
    method __init__ (line 18) | def __init__(self, cfg):
    method check_cfg (line 25) | def check_cfg(self, cfg):
    method forward_backward (line 28) | def forward_backward(self, batch_x, batch_u):
    method parse_batch_train (line 88) | def parse_batch_train(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py
  class SupBaseline (line 8) | class SupBaseline(TrainerXU):
    method forward_backward (line 11) | def forward_backward(self, batch_x, batch_u):
    method parse_batch_train (line 27) | def parse_batch_train(self, batch_x, batch_u):

FILE: Dassl.ProGrad.pytorch/dassl/engine/trainer.py
  class SimpleNet (line 23) | class SimpleNet(nn.Module):
    method __init__ (line 28) | def __init__(self, cfg, model_cfg, num_classes, **kwargs):
    method fdim (line 59) | def fdim(self):
    method forward (line 62) | def forward(self, x, return_feature=False):
  class TrainerBase (line 78) | class TrainerBase:
    method __init__ (line 81) | def __init__(self):
    method register_model (line 87) | def register_model(self, name="model", model=None, optim=None, sched=N...
    method get_model_names (line 109) | def get_model_names(self, names=None):
    method save_model (line 119) | def save_model(self, epoch, directory, is_best=False, model_name=""):
    method resume_model_if_exist (line 145) | def resume_model_if_exist(self, directory):
    method load_model (line 172) | def load_model(self, directory, epoch=None):
    method set_model_mode (line 206) | def set_model_mode(self, mode="train", names=None):
    method update_lr (line 217) | def update_lr(self, names=None):
    method detect_anomaly (line 224) | def detect_anomaly(self, loss):
    method init_writer (line 228) | def init_writer(self, log_dir):
    method close_writer (line 236) | def close_writer(self):
    method write_scalar (line 240) | def write_scalar(self, tag, scalar_value, global_step=None):
    method train (line 248) | def train(self, start_epoch, max_epoch):
    method before_train (line 260) | def before_train(self):
    method after_train (line 263) | def after_train(self):
    method before_epoch (line 266) | def before_epoch(self):
    method after_epoch (line 269) | def after_epoch(self):
    method run_epoch (line 272) | def run_epoch(self):
    method test (line 275) | def test(self):
    method parse_batch_train (line 278) | def parse_batch_train(self, batch):
    method parse_batch_test (line 281) | def parse_batch_test(self, batch):
    method forward_backward (line 284) | def forward_backward(self, batch):
    method model_inference (line 287) | def model_inference(self, input):
    method model_zero_grad (line 290) | def model_zero_grad(self, names=None):
    method model_backward (line 296) | def model_backward(self, loss):
    method model_update (line 300) | def model_update(self, names=None):
    method model_backward_and_update (line 306) | def model_backward_and_update(self, loss, names=None):
    method prograd_backward_and_update (line 311) | def prograd_backward_and_update(
  class SimpleTrainer (line 352) | class SimpleTrainer(TrainerBase):
    method __init__ (line 355) | def __init__(self, cfg):
    method check_cfg (line 375) | def check_cfg(self, cfg):
    method build_data_loader (line 387) | def build_data_loader(self):
    method build_model (line 405) | def build_model(self):
    method train (line 432) | def train(self):
    method before_train (line 435) | def before_train(self):
    method after_train (line 449) | def after_train(self):
    method after_epoch (line 467) | def after_epoch(self):
    method output_test (line 490) | def output_test(self, split=None):
    method test (line 529) | def test(self, split=None):
    method model_inference (line 557) | def model_inference(self, input):
    method parse_batch_test (line 560) | def parse_batch_test(self, batch):
    method get_current_lr (line 569) | def get_current_lr(self, names=None):
  class TrainerXU (line 575) | class TrainerXU(SimpleTrainer):
    method run_epoch (line 585) | def run_epoch(self):
    method parse_batch_train (line 661) | def parse_batch_train(self, batch_x, batch_u):
  class TrainerX (line 673) | class TrainerX(SimpleTrainer):
    method run_epoch (line 676) | def run_epoch(self):
    method parse_batch_train (line 726) | def parse_batch_train(self, batch):

FILE: Dassl.ProGrad.pytorch/dassl/evaluation/build.py
  function build_evaluator (line 6) | def build_evaluator(cfg, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py
  class EvaluatorBase (line 10) | class EvaluatorBase:
    method __init__ (line 13) | def __init__(self, cfg):
    method reset (line 16) | def reset(self):
    method process (line 19) | def process(self, mo, gt):
    method evaluate (line 22) | def evaluate(self):
  class Classification (line 27) | class Classification(EvaluatorBase):
    method __init__ (line 30) | def __init__(self, cfg, lab2cname=None, **kwargs):
    method reset (line 42) | def reset(self):
    method process (line 50) | def process(self, mo, gt):
    method evaluate (line 67) | def evaluate(self):

FILE: Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py
  function compute_accuracy (line 1) | def compute_accuracy(output, target, topk=(1, )):

FILE: Dassl.ProGrad.pytorch/dassl/metrics/distance.py
  function compute_distance_matrix (line 8) | def compute_distance_matrix(input1, input2, metric="euclidean"):
  function euclidean_squared_distance (line 46) | def euclidean_squared_distance(input1, input2):
  function cosine_distance (line 64) | def cosine_distance(input1, input2):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py
  class AlexNet (line 13) | class AlexNet(Backbone):
    method __init__ (line 15) | def __init__(self):
    method forward (line 45) | def forward(self, x):
  function init_pretrained_weights (line 52) | def init_pretrained_weights(model, model_url):
  function alexnet (line 58) | def alexnet(pretrained=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py
  class Backbone (line 4) | class Backbone(nn.Module):
    method __init__ (line 6) | def __init__(self):
    method forward (line 9) | def forward(self):
    method out_features (line 13) | def out_features(self):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py
  function build_backbone (line 6) | def build_backbone(name, verbose=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py
  class FeatureExtractor (line 13) | class FeatureExtractor(Backbone):
    method __init__ (line 15) | def __init__(self):
    method _check_input (line 30) | def _check_input(self, x):
    method forward (line 36) | def forward(self, x):
  function cnn_digit5_m3sda (line 51) | def cnn_digit5_m3sda(**kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsdg.py
  class Convolution (line 10) | class Convolution(nn.Module):
    method __init__ (line 12) | def __init__(self, c_in, c_out):
    method forward (line 17) | def forward(self, x):
  class ConvNet (line 21) | class ConvNet(Backbone):
    method __init__ (line 23) | def __init__(self, c_hidden=64):
    method _check_input (line 32) | def _check_input(self, x):
    method forward (line 38) | def forward(self, x):
  function cnn_digitsdg (line 52) | def cnn_digitsdg(**kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsingle.py
  class CNN (line 14) | class CNN(Backbone):
    method __init__ (line 16) | def __init__(self):
    method _check_input (line 25) | def _check_input(self, x):
    method forward (line 31) | def forward(self, x):
  function cnn_digitsingle (line 53) | def cnn_digitsingle(**kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/model.py
  class MBConvBlock (line 14) | class MBConvBlock(nn.Module):
    method __init__ (line 26) | def __init__(self, block_args, global_params, image_size=None):
    method forward (line 98) | def forward(self, inputs, drop_connect_rate=None):
    method set_swish (line 137) | def set_swish(self, memory_efficient=True):
  class EfficientNet (line 142) | class EfficientNet(Backbone):
    method __init__ (line 155) | def __init__(self, blocks_args=None, global_params=None):
    method set_swish (line 240) | def set_swish(self, memory_efficient=True):
    method extract_features (line 246) | def extract_features(self, inputs):
    method forward (line 264) | def forward(self, inputs):
    method from_name (line 281) | def from_name(cls, model_name, override_params=None):
    method from_pretrained (line 289) | def from_pretrained(
    method get_image_size (line 302) | def get_image_size(cls, model_name):
    method _check_model_name_is_valid (line 308) | def _check_model_name_is_valid(cls, model_name):
    method _change_in_channels (line 316) | def _change_in_channels(model, in_channels):
  function build_efficientnet (line 327) | def build_efficientnet(name, pretrained):
  function efficientnet_b0 (line 335) | def efficientnet_b0(pretrained=True, **kwargs):
  function efficientnet_b1 (line 340) | def efficientnet_b1(pretrained=True, **kwargs):
  function efficientnet_b2 (line 345) | def efficientnet_b2(pretrained=True, **kwargs):
  function efficientnet_b3 (line 350) | def efficientnet_b3(pretrained=True, **kwargs):
  function efficientnet_b4 (line 355) | def efficientnet_b4(pretrained=True, **kwargs):
  function efficientnet_b5 (line 360) | def efficientnet_b5(pretrained=True, **kwargs):
  function efficientnet_b6 (line 365) | def efficientnet_b6(pretrained=True, **kwargs):
  function efficientnet_b7 (line 370) | def efficientnet_b7(pretrained=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/utils.py
  class SwishImplementation (line 56) | class SwishImplementation(torch.autograd.Function):
    method forward (line 59) | def forward(ctx, i):
    method backward (line 65) | def backward(ctx, grad_output):
  class MemoryEfficientSwish (line 71) | class MemoryEfficientSwish(nn.Module):
    method forward (line 73) | def forward(self, x):
  class Swish (line 77) | class Swish(nn.Module):
    method forward (line 79) | def forward(self, x):
  function round_filters (line 83) | def round_filters(filters, global_params):
  function round_repeats (line 98) | def round_repeats(repeats, global_params):
  function drop_connect (line 106) | def drop_connect(inputs, p, training):
  function get_same_padding_conv2d (line 121) | def get_same_padding_conv2d(image_size=None):
  function get_width_and_height_from_size (line 130) | def get_width_and_height_from_size(x):
  function calculate_output_image_size (line 140) | def calculate_output_image_size(input_image_size, stride):
  class Conv2dDynamicSamePadding (line 156) | class Conv2dDynamicSamePadding(nn.Conv2d):
    method __init__ (line 159) | def __init__(
    method forward (line 176) | def forward(self, x):
  class Conv2dStaticSamePadding (line 203) | class Conv2dStaticSamePadding(nn.Conv2d):
    method __init__ (line 206) | def __init__(
    method forward (line 238) | def forward(self, x):
  class Identity (line 252) | class Identity(nn.Module):
    method __init__ (line 254) | def __init__(self, ):
    method forward (line 257) | def forward(self, input):
  function efficientnet_params (line 266) | def efficientnet_params(model_name):
  class BlockDecoder (line 284) | class BlockDecoder(object):
    method _decode_block_string (line 288) | def _decode_block_string(block_string):
    method _encode_block_string (line 317) | def _encode_block_string(block):
    method decode (line 334) | def decode(string_list):
    method encode (line 348) | def encode(blocks_args):
  function efficientnet (line 361) | def efficientnet(
  function get_model_params (line 399) | def get_model_params(model_name, override_params):
  function load_pretrained_weights (line 461) | def load_pretrained_weights(model, model_name, load_fc=True, advprop=Fal...

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py
  function _make_divisible (line 13) | def _make_divisible(v, divisor, min_value=None):
  class ConvBNReLU (line 33) | class ConvBNReLU(nn.Sequential):
    method __init__ (line 35) | def __init__(
  class InvertedResidual (line 54) | class InvertedResidual(nn.Module):
    method __init__ (line 56) | def __init__(self, inp, oup, stride, expand_ratio):
    method forward (line 81) | def forward(self, x):
  class MobileNetV2 (line 88) | class MobileNetV2(Backbone):
    method __init__ (line 90) | def __init__(
    method _forward_impl (line 178) | def _forward_impl(self, x):
    method forward (line 185) | def forward(self, x):
  function init_pretrained_weights (line 189) | def init_pretrained_weights(model, model_url):
  function mobilenetv2 (line 213) | def mobilenetv2(pretrained=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py
  class PreActBlock (line 8) | class PreActBlock(nn.Module):
    method __init__ (line 11) | def __init__(self, in_planes, planes, stride=1):
    method forward (line 38) | def forward(self, x):
  class PreActBottleneck (line 47) | class PreActBottleneck(nn.Module):
    method __init__ (line 50) | def __init__(self, in_planes, planes, stride=1):
    method forward (line 79) | def forward(self, x):
  class PreActResNet (line 89) | class PreActResNet(Backbone):
    method __init__ (line 91) | def __init__(self, block, num_blocks):
    method _make_layer (line 105) | def _make_layer(self, block, planes, num_blocks, stride):
    method forward (line 113) | def forward(self, x):
  function preact_resnet18 (line 134) | def preact_resnet18(**kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py
  function conv3x3 (line 16) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 28) | class BasicBlock(nn.Module):
    method __init__ (line 31) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 41) | def forward(self, x):
  class Bottleneck (line 60) | class Bottleneck(nn.Module):
    method __init__ (line 63) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 84) | def forward(self, x):
  class ResNet (line 107) | class ResNet(Backbone):
    method __init__ (line 109) | def __init__(
    method _make_layer (line 147) | def _make_layer(self, block, planes, blocks, stride=1):
    method _init_params (line 169) | def _init_params(self):
    method featuremaps (line 188) | def featuremaps(self, x):
    method forward (line 204) | def forward(self, x):
  function init_pretrained_weights (line 210) | def init_pretrained_weights(model, model_url):
  function resnet18 (line 227) | def resnet18(pretrained=True, **kwargs):
  function resnet34 (line 237) | def resnet34(pretrained=True, **kwargs):
  function resnet50 (line 247) | def resnet50(pretrained=True, **kwargs):
  function resnet101 (line 257) | def resnet101(pretrained=True, **kwargs):
  function resnet152 (line 267) | def resnet152(pretrained=True, **kwargs):
  function resnet18_ms_l123 (line 282) | def resnet18_ms_l123(pretrained=True, **kwargs):
  function resnet18_ms_l12 (line 299) | def resnet18_ms_l12(pretrained=True, **kwargs):
  function resnet18_ms_l1 (line 316) | def resnet18_ms_l1(pretrained=True, **kwargs):
  function resnet50_ms_l123 (line 333) | def resnet50_ms_l123(pretrained=True, **kwargs):
  function resnet50_ms_l12 (line 350) | def resnet50_ms_l12(pretrained=True, **kwargs):
  function resnet50_ms_l1 (line 367) | def resnet50_ms_l1(pretrained=True, **kwargs):
  function resnet101_ms_l123 (line 384) | def resnet101_ms_l123(pretrained=True, **kwargs):
  function resnet101_ms_l12 (line 401) | def resnet101_ms_l12(pretrained=True, **kwargs):
  function resnet101_ms_l1 (line 418) | def resnet101_ms_l1(pretrained=True, **kwargs):
  function resnet18_efdmix_l123 (line 440) | def resnet18_efdmix_l123(pretrained=True, **kwargs):
  function resnet18_efdmix_l12 (line 457) | def resnet18_efdmix_l12(pretrained=True, **kwargs):
  function resnet18_efdmix_l1 (line 474) | def resnet18_efdmix_l1(pretrained=True, **kwargs):
  function resnet50_efdmix_l123 (line 491) | def resnet50_efdmix_l123(pretrained=True, **kwargs):
  function resnet50_efdmix_l12 (line 508) | def resnet50_efdmix_l12(pretrained=True, **kwargs):
  function resnet50_efdmix_l1 (line 525) | def resnet50_efdmix_l1(pretrained=True, **kwargs):
  function resnet101_efdmix_l123 (line 542) | def resnet101_efdmix_l123(pretrained=True, **kwargs):
  function resnet101_efdmix_l12 (line 559) | def resnet101_efdmix_l12(pretrained=True, **kwargs):
  function resnet101_efdmix_l1 (line 576) | def resnet101_efdmix_l1(pretrained=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py
  function channel_shuffle (line 21) | def channel_shuffle(x, groups):
  class InvertedResidual (line 36) | class InvertedResidual(nn.Module):
    method __init__ (line 38) | def __init__(self, inp, oup, stride):
    method depthwise_conv (line 98) | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
    method forward (line 103) | def forward(self, x):
  class ShuffleNetV2 (line 115) | class ShuffleNetV2(Backbone):
    method __init__ (line 117) | def __init__(self, stages_repeats, stages_out_channels, **kwargs):
    method featuremaps (line 162) | def featuremaps(self, x):
    method forward (line 171) | def forward(self, x):
  function init_pretrained_weights (line 177) | def init_pretrained_weights(model, model_url):
  function shufflenet_v2_x0_5 (line 201) | def shufflenet_v2_x0_5(pretrained=True, **kwargs):
  function shufflenet_v2_x1_0 (line 209) | def shufflenet_v2_x1_0(pretrained=True, **kwargs):
  function shufflenet_v2_x1_5 (line 217) | def shufflenet_v2_x1_5(pretrained=True, **kwargs):
  function shufflenet_v2_x2_0 (line 225) | def shufflenet_v2_x2_0(pretrained=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py
  class VGG (line 24) | class VGG(Backbone):
    method __init__ (line 26) | def __init__(self, features, init_weights=True):
    method forward (line 45) | def forward(self, x):
    method _initialize_weights (line 51) | def _initialize_weights(self):
  function make_layers (line 67) | def make_layers(cfg, batch_norm=False):
  function _vgg (line 133) | def _vgg(arch, cfg, batch_norm, pretrained):
  function vgg16 (line 146) | def vgg16(pretrained=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py
  class BasicBlock (line 12) | class BasicBlock(nn.Module):
    method __init__ (line 14) | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
    method forward (line 49) | def forward(self, x):
  class NetworkBlock (line 61) | class NetworkBlock(nn.Module):
    method __init__ (line 63) | def __init__(
    method _make_layer (line 71) | def _make_layer(
    method forward (line 86) | def forward(self, x):
  class WideResNet (line 90) | class WideResNet(Backbone):
    method __init__ (line 92) | def __init__(self, depth, widen_factor, dropRate=0.0):
    method forward (line 133) | def forward(self, x):
  function wide_resnet_28_2 (line 144) | def wide_resnet_28_2(**kwargs):
  function wide_resnet_16_4 (line 149) | def wide_resnet_16_4(**kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/build.py
  function build_head (line 6) | def build_head(name, verbose=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py
  class MLP (line 7) | class MLP(nn.Module):
    method __init__ (line 9) | def __init__(
    method forward (line 44) | def forward(self, x):
  function mlp (line 49) | def mlp(**kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/build.py
  function build_network (line 6) | def build_network(name, verbose=True, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py
  function init_network_weights (line 12) | def init_network_weights(model, init_type="normal", gain=0.02):
  function get_norm_layer (line 45) | def get_norm_layer(norm_type="instance"):
  class ResnetBlock (line 61) | class ResnetBlock(nn.Module):
    method __init__ (line 63) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
    method build_conv_block (line 69) | def build_conv_block(
    method forward (line 111) | def forward(self, x):
  class LocNet (line 115) | class LocNet(nn.Module):
    method __init__ (line 118) | def __init__(
    method forward (line 152) | def forward(self, x):
  class FCN (line 163) | class FCN(nn.Module):
    method __init__ (line 166) | def __init__(
    method init_loc_layer (line 236) | def init_loc_layer(self):
    method stn (line 244) | def stn(self, x):
    method forward (line 250) | def forward(self, x, lmda=1.0, return_p=False, return_stn_output=False):
  function fcn_3x32_gctx (line 283) | def fcn_3x32_gctx(**kwargs):
  function fcn_3x64_gctx (line 291) | def fcn_3x64_gctx(**kwargs):
  function fcn_3x32_gctx_stn (line 299) | def fcn_3x32_gctx_stn(image_size=32, **kwargs):
  function fcn_3x64_gctx_stn (line 316) | def fcn_3x64_gctx_stn(image_size=224, **kwargs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py
  function cross_entropy (line 5) | def cross_entropy(input, target, label_smooth=0, reduction="mean"):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py
  class _DSBN (line 4) | class _DSBN(nn.Module):
    method __init__ (line 13) | def __init__(self, num_features, n_domain, bn_type):
    method select_bn (line 28) | def select_bn(self, domain_idx=0):
    method forward (line 32) | def forward(self, x):
  class DSBN1d (line 36) | class DSBN1d(_DSBN):
    method __init__ (line 38) | def __init__(self, num_features, n_domain):
  class DSBN2d (line 42) | class DSBN2d(_DSBN):
    method __init__ (line 44) | def __init__(self, num_features, n_domain):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py
  function deactivate_efdmix (line 7) | def deactivate_efdmix(m):
  function activate_efdmix (line 12) | def activate_efdmix(m):
  function random_efdmix (line 17) | def random_efdmix(m):
  function crossdomain_efdmix (line 22) | def crossdomain_efdmix(m):
  function run_without_efdmix (line 28) | def run_without_efdmix(model):
  function run_with_efdmix (line 38) | def run_with_efdmix(model, mix=None):
  class EFDMix (line 53) | class EFDMix(nn.Module):
    method __init__ (line 60) | def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
    method __repr__ (line 76) | def __repr__(self):
    method set_activation_status (line 81) | def set_activation_status(self, status=True):
    method update_mix_method (line 84) | def update_mix_method(self, mix="random"):
    method forward (line 87) | def forward(self, x):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py
  function deactivate_mixstyle (line 7) | def deactivate_mixstyle(m):
  function activate_mixstyle (line 12) | def activate_mixstyle(m):
  function random_mixstyle (line 17) | def random_mixstyle(m):
  function crossdomain_mixstyle (line 22) | def crossdomain_mixstyle(m):
  function run_without_mixstyle (line 28) | def run_without_mixstyle(model):
  function run_with_mixstyle (line 38) | def run_with_mixstyle(model, mix=None):
  class MixStyle (line 53) | class MixStyle(nn.Module):
    method __init__ (line 60) | def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
    method __repr__ (line 76) | def __repr__(self):
    method set_activation_status (line 81) | def set_activation_status(self, status=True):
    method update_mix_method (line 84) | def update_mix_method(self, mix="random"):
    method forward (line 87) | def forward(self, x):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py
  function mixup (line 4) | def mixup(x1, x2, y1, y2, beta, preserve_order=False):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py
  class MaximumMeanDiscrepancy (line 6) | class MaximumMeanDiscrepancy(nn.Module):
    method __init__ (line 8) | def __init__(self, kernel_type="rbf", normalize=False):
    method forward (line 13) | def forward(self, x, y):
    method linear_mmd (line 28) | def linear_mmd(self, x, y):
    method poly_mmd (line 35) | def poly_mmd(self, x, y, alpha=1.0, c=2.0, d=2):
    method rbf_mmd (line 45) | def rbf_mmd(self, x, y):
    method rbf_kernel_mixture (line 60) | def rbf_kernel_mixture(exponent, sigmas=[1, 5, 10]):
    method remove_self_distance (line 68) | def remove_self_distance(distmat):
    method euclidean_squared_distance (line 76) | def euclidean_squared_distance(x, y):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py
  class OptimalTransport (line 6) | class OptimalTransport(nn.Module):
    method distance (line 9) | def distance(batch1, batch2, dist_metric="cosine"):
  class SinkhornDivergence (line 35) | class SinkhornDivergence(OptimalTransport):
    method __init__ (line 38) | def __init__(
    method forward (line 51) | def forward(self, x, y):
    method transport_cost (line 58) | def transport_cost(self, x, y, return_pi=False):
    method sinkhorn_iterate (line 69) | def sinkhorn_iterate(C, eps, max_iter, thre):
  class MinibatchEnergyDistance (line 103) | class MinibatchEnergyDistance(SinkhornDivergence):
    method __init__ (line 105) | def __init__(
    method forward (line 119) | def forward(self, x, y):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py
  class _ReverseGrad (line 5) | class _ReverseGrad(Function):
    method forward (line 8) | def forward(ctx, input, grad_scaling):
    method backward (line 13) | def backward(ctx, grad_output):
  class ReverseGrad (line 21) | class ReverseGrad(nn.Module):
    method forward (line 29) | def forward(self, x, grad_scaling=1.0):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py
  class Sequential2 (line 4) | class Sequential2(nn.Sequential):
    method forward (line 9) | def forward(self, *inputs):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py
  class _TransNorm (line 5) | class _TransNorm(nn.Module):
    method __init__ (line 19) | def __init__(
    method resnet_running_stats (line 36) | def resnet_running_stats(self):
    method reset_parameters (line 42) | def reset_parameters(self):
    method _check_input (line 46) | def _check_input(self, x):
    method _compute_alpha (line 49) | def _compute_alpha(self, mean_s, var_s, mean_t, var_t):
    method forward (line 57) | def forward(self, input):
  class TransNorm1d (line 121) | class TransNorm1d(_TransNorm):
    method _check_input (line 123) | def _check_input(self, x):
  class TransNorm2d (line 131) | class TransNorm2d(_TransNorm):
    method _check_input (line 133) | def _check_input(self, x):

FILE: Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py
  function sharpen_prob (line 5) | def sharpen_prob(p, temperature=2):
  function reverse_index (line 16) | def reverse_index(data, label):
  function shuffle_index (line 22) | def shuffle_index(data, label):
  function create_onehot (line 28) | def create_onehot(label, num_classes):
  function sigmoid_rampup (line 41) | def sigmoid_rampup(current, rampup_length):
  function linear_rampup (line 54) | def linear_rampup(current, rampup_length):
  function ema_model_update (line 66) | def ema_model_update(model, ema_model, alpha):
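
The index lists only the signatures of `sigmoid_rampup(current, rampup_length)` and `linear_rampup(current, rampup_length)`, not their bodies. As an illustration of what such ramp-up schedules usually compute, here is a minimal sketch. It assumes the `exp(-5 * (1 - t)^2)` shape commonly used in consistency-regularization training; the constant `5.0` and the clamping behaviour are assumptions, not taken from the repository.

```python
import math


def linear_rampup(current, rampup_length):
    """Ramp linearly from 0 to 1 over `rampup_length` steps, then hold at 1."""
    if rampup_length <= 0 or current >= rampup_length:
        return 1.0
    return max(current, 0.0) / rampup_length


def sigmoid_rampup(current, rampup_length):
    """Ramp from exp(-5) to 1 along exp(-5 * (1 - t)^2), t = current / rampup_length.

    The constant 5.0 is an assumed, conventional choice.
    """
    if rampup_length <= 0:
        return 1.0
    # Clamp the step into [0, rampup_length] before normalising.
    t = min(max(current, 0.0), rampup_length) / rampup_length
    return math.exp(-5.0 * (1.0 - t) ** 2)
```

A schedule like this is typically multiplied into the weight of an unsupervised loss term so that the term contributes little early in training.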

FILE: Dassl.ProGrad.pytorch/dassl/optim/lr_scheduler.py
  class _BaseWarmupScheduler (line 10) | class _BaseWarmupScheduler(_LRScheduler):
    method __init__ (line 12) | def __init__(
    method get_lr (line 24) | def get_lr(self):
    method step (line 27) | def step(self, epoch=None):
  class ConstantWarmupScheduler (line 35) | class ConstantWarmupScheduler(_BaseWarmupScheduler):
    method __init__ (line 37) | def __init__(
    method get_lr (line 51) | def get_lr(self):
  class LinearWarmupScheduler (line 57) | class LinearWarmupScheduler(_BaseWarmupScheduler):
    method __init__ (line 59) | def __init__(
    method get_lr (line 73) | def get_lr(self):
  function build_lr_scheduler (line 83) | def build_lr_scheduler(optimizer, optim_cfg):

FILE: Dassl.ProGrad.pytorch/dassl/optim/optimizer.py
  function build_optimizer (line 13) | def build_optimizer(model, optim_cfg):

FILE: Dassl.ProGrad.pytorch/dassl/optim/radam.py
  class RAdam (line 18) | class RAdam(Optimizer):
    method __init__ (line 20) | def __init__(
    method __setstate__ (line 47) | def __setstate__(self, state):
    method step (line 50) | def step(self, closure=None):
  class PlainRAdam (line 133) | class PlainRAdam(Optimizer):
    method __init__ (line 135) | def __init__(
    method __setstate__ (line 162) | def __setstate__(self, state):
    method step (line 165) | def step(self, closure=None):
  class AdamW (line 234) | class AdamW(Optimizer):
    method __init__ (line 236) | def __init__(
    method __setstate__ (line 267) | def __setstate__(self, state):
    method step (line 270) | def step(self, closure=None):

FILE: Dassl.ProGrad.pytorch/dassl/utils/logger.py
  class Logger (line 11) | class Logger:
    method __init__ (line 27) | def __init__(self, fpath=None):
    method __del__ (line 34) | def __del__(self):
    method __enter__ (line 37) | def __enter__(self):
    method __exit__ (line 40) | def __exit__(self, *args):
    method write (line 43) | def write(self, msg):
    method flush (line 48) | def flush(self):
    method close (line 54) | def close(self):
  function setup_logger (line 60) | def setup_logger(output=None):

FILE: Dassl.ProGrad.pytorch/dassl/utils/meters.py
  class AverageMeter (line 7) | class AverageMeter:
    method __init__ (line 17) | def __init__(self, ema=False):
    method reset (line 25) | def reset(self):
    method update (line 31) | def update(self, val, n=1):
  class MetricMeter (line 45) | class MetricMeter:
    method __init__ (line 58) | def __init__(self, delimiter="\t"):
    method update (line 62) | def update(self, input_dict):
    method __str__ (line 76) | def __str__(self):
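
To make the `AverageMeter` interface above concrete (`__init__(ema=False)`, `reset()`, `update(val, n=1)`), the following is a minimal sketch of a running-average meter with an optional exponential moving average. The attribute names and the decay factor `0.9` are assumptions chosen for illustration; the repository's own implementation is not shown in this index and may differ.

```python
class AverageMeter:
    """Tracks the latest value and a running (or exponential moving) average."""

    def __init__(self, ema=False):
        self.ema = ema
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.avg = 0.0    # running or EMA average
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight seen so far

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.ema and self.count > n:
            # Assumed decay of 0.9; the first update seeds the average below.
            self.avg = 0.9 * self.avg + 0.1 * val
        else:
            self.avg = self.sum / self.count
```

Here `n` is the number of samples that `val` was averaged over (for example the batch size), so the plain average is weighted by sample count.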

FILE: Dassl.ProGrad.pytorch/dassl/utils/registry.py
  class Registry (line 7) | class Registry:
    method __init__ (line 32) | def __init__(self, name):
    method _do_register (line 36) | def _do_register(self, name, obj, force=False):
    method register (line 45) | def register(self, obj=None, force=False):
    method get (line 59) | def get(self, name):
    method registered_names (line 68) | def registered_names(self):
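
The method list above (`_do_register`, `register`, `get`, `registered_names`) matches the usual name-to-object registry pattern, which lets builders such as `build_head(name, ...)` look up a constructor by string. The sketch below shows that pattern under stated assumptions: the internal `_obj_map` dictionary, the error messages, and the `BACKBONES` / `tiny_backbone` names are hypothetical and not taken from the repository.

```python
class Registry:
    """Maps string names to registered callables or classes."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}  # assumed internal storage

    def _do_register(self, name, obj, force=False):
        if name in self._obj_map and not force:
            raise KeyError(f'"{name}" is already registered in "{self._name}"')
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        if obj is None:
            # Decorator form: @REG.register() or @REG.register(force=True)
            def wrapper(fn_or_class):
                self._do_register(fn_or_class.__name__, fn_or_class, force=force)
                return fn_or_class

            return wrapper
        # Plain call form: REG.register(obj)
        self._do_register(obj.__name__, obj, force=force)

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f'"{name}" not found in "{self._name}" registry')
        return self._obj_map[name]

    def registered_names(self):
        return list(self._obj_map.keys())


BACKBONES = Registry("BACKBONE")  # hypothetical registry instance


@BACKBONES.register()
def tiny_backbone(**kwargs):
    """Hypothetical factory that just echoes its configuration."""
    return {"name": "tiny_backbone", **kwargs}
```

With this in place, `BACKBONES.get("tiny_backbone")(depth=2)` looks the factory up by name and calls it, which is the kind of indirection a config-driven builder relies on.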

FILE: Dassl.ProGrad.pytorch/dassl/utils/tools.py
  function mkdir_if_missing (line 34) | def mkdir_if_missing(dirname):
  function check_isfile (line 44) | def check_isfile(fpath):
  function read_json (line 59) | def read_json(fpath):
  function write_json (line 66) | def write_json(obj, fpath):
  function set_random_seed (line 73) | def set_random_seed(seed):
  function download_url (line 80) | def download_url(url, dst):
  function read_image (line 111) | def read_image(path):
  function collect_env_info (line 134) | def collect_env_info():
  function listdir_nohidden (line 146) | def listdir_nohidden(path, sort=False):
  function get_most_similar_str_to_a_from_b (line 159) | def get_most_similar_str_to_a_from_b(a, b):
  function check_availability (line 176) | def check_availability(requested, available):
  function tolist_if_not (line 192) | def tolist_if_not(x):

FILE: Dassl.ProGrad.pytorch/dassl/utils/torchtools.py
  function save_checkpoint (line 27) | def save_checkpoint(
  function load_checkpoint (line 85) | def load_checkpoint(fpath):
  function resume_from_checkpoint (line 126) | def resume_from_checkpoint(fdir, model, optimizer=None, scheduler=None):
  function adjust_learning_rate (line 168) | def adjust_learning_rate(
  function set_bn_to_eval (line 194) | def set_bn_to_eval(m):
  function open_all_layers (line 203) | def open_all_layers(model):
  function open_specified_layers (line 214) | def open_specified_layers(model, open_layers):
  function count_num_param (line 254) | def count_num_param(model):
  function load_pretrained_weights (line 266) | def load_pretrained_weights(model, weight_path):
  function init_network_weights (line 323) | def init_network_weights(model, init_type="normal", gain=0.02):

FILE: Dassl.ProGrad.pytorch/datasets/da/cifar_stl.py
  function extract_and_save_image (line 47) | def extract_and_save_image(dataset, save_dir, discard, label2name):
  function download_and_prepare (line 70) | def download_and_prepare(name, root, discarded_label, label2name):

FILE: Dassl.ProGrad.pytorch/datasets/da/digit5.py
  function mkdir_if_missing (line 9) | def mkdir_if_missing(directory):
  function extract_and_save (line 14) | def extract_and_save(data, label, save_dir):
  function load_mnist (line 28) | def load_mnist(data_dir, raw_data_dir):
  function load_mnist_m (line 41) | def load_mnist_m(data_dir, raw_data_dir):
  function load_svhn (line 54) | def load_svhn(data_dir, raw_data_dir):
  function load_syn (line 66) | def load_syn(data_dir, raw_data_dir):
  function load_usps (line 79) | def load_usps(data_dir, raw_data_dir):
  function main (line 98) | def main(data_dir):

FILE: Dassl.ProGrad.pytorch/datasets/dg/cifar_c.py
  function extract_and_save (line 15) | def extract_and_save(images, labels, level, dst):
  function main (line 29) | def main(npy_folder):

FILE: Dassl.ProGrad.pytorch/datasets/ssl/cifar10_cifar100_svhn.py
  function extract_and_save_image (line 8) | def extract_and_save_image(dataset, save_dir):
  function download_and_prepare (line 24) | def download_and_prepare(name, root):

FILE: Dassl.ProGrad.pytorch/datasets/ssl/stl10.py
  function extract_and_save_image (line 8) | def extract_and_save_image(dataset, save_dir):
  function download_and_prepare (line 27) | def download_and_prepare(root):

FILE: Dassl.ProGrad.pytorch/setup.py
  function readme (line 6) | def readme():
  function find_version (line 12) | def find_version():
  function numpy_include (line 19) | def numpy_include():
  function get_requirements (line 27) | def get_requirements(filename='requirements.txt'):

FILE: Dassl.ProGrad.pytorch/tools/parse_test_res.py
  function compute_ci95 (line 60) | def compute_ci95(res):
  function parse_function (line 64) | def parse_function(*metrics, directory="", args=None, end_signal=None):
  function main (line 126) | def main(args, end_signal):
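
The name `compute_ci95(res)` suggests a 95% confidence interval computed over results from several runs (for example, accuracies across seeds). A minimal sketch of the standard normal-approximation half-width follows. Using the population standard deviation and the constant `1.96` are assumptions; the repository's function body is not part of this index.

```python
import math
import statistics


def compute_ci95(res):
    """Half-width of a normal-approximation 95% CI for the mean of `res`.

    Uses the population standard deviation (an assumption) divided by sqrt(n).
    """
    return 1.96 * statistics.pstdev(res) / math.sqrt(len(res))
```

The result is reported as `mean ± compute_ci95(res)`; with identical results across runs the half-width is zero.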

FILE: Dassl.ProGrad.pytorch/tools/replace_text.py
  function is_python_file (line 12) | def is_python_file(filename):
  function update_file (line 17) | def update_file(filename, text_to_search, replacement_text):
  function recursive_update (line 24) | def recursive_update(directory, text_to_search, replacement_text):
  function main (line 38) | def main():

FILE: Dassl.ProGrad.pytorch/tools/train.py
  function print_args (line 9) | def print_args(args, cfg):
  function reset_cfg (line 23) | def reset_cfg(cfg, args):
  function extend_cfg (line 55) | def extend_cfg(cfg):
  function setup_cfg (line 69) | def setup_cfg(args):
  function main (line 92) | def main(args):

FILE: ProGrad.public/clip/clip.py
  function _download (line 43) | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
  function _transform (line 86) | def _transform(n_px):
  function available_models (line 97) | def available_models() -> List[str]:
  function load (line 102) | def load(name: str,
  function tokenize (line 216) | def tokenize(texts: Union[str, List[str]],

FILE: ProGrad.public/clip/model.py
  class Bottleneck (line 10) | class Bottleneck(nn.Module):
    method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 44) | def forward(self, x: torch.Tensor):
  class AttentionPool2d (line 60) | class AttentionPool2d(nn.Module):
    method __init__ (line 61) | def __init__(self,
    method forward (line 75) | def forward(self, x):
  class ModifiedResNet (line 106) | class ModifiedResNet(nn.Module):
    method __init__ (line 113) | def __init__(self,
    method _make_layer (line 157) | def _make_layer(self, planes, blocks, stride=1):
    method forward (line 166) | def forward(self, x):
  class LayerNorm (line 185) | class LayerNorm(nn.LayerNorm):
    method forward (line 187) | def forward(self, x: torch.Tensor):
  class QuickGELU (line 193) | class QuickGELU(nn.Module):
    method forward (line 194) | def forward(self, x: torch.Tensor):
  class ResidualAttentionBlock (line 198) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 199) | def __init__(self,
    method attention (line 214) | def attention(self, x: torch.Tensor):
    method forward (line 221) | def forward(self, x: torch.Tensor):
  class Transformer (line 227) | class Transformer(nn.Module):
    method __init__ (line 228) | def __init__(self,
    method forward (line 241) | def forward(self, x: torch.Tensor):
  class VisionTransformer (line 245) | class VisionTransformer(nn.Module):
    method __init__ (line 246) | def __init__(self, input_resolution: int, patch_size: int, width: int,
    method forward (line 268) | def forward(self, x: torch.Tensor):
  class CLIP (line 293) | class CLIP(nn.Module):
    method __init__ (line 294) | def __init__(
    method initialize_parameters (line 345) | def initialize_parameters(self):
    method build_attention_mask (line 379) | def build_attention_mask(self):
    method dtype (line 388) | def dtype(self):
    method encode_image (line 391) | def encode_image(self, image):
    method encode_text (line 394) | def encode_text(self, text):
    method forward (line 411) | def forward(self, image, text):
  function convert_weights (line 430) | def convert_weights(model: nn.Module):
  function build_model (line 456) | def build_model(state_dict: dict):

FILE: ProGrad.public/clip/simple_tokenizer.py
  function default_bpe (line 11) | def default_bpe():
  function bytes_to_unicode (line 17) | def bytes_to_unicode():
  function get_pairs (line 43) | def get_pairs(word):
  function basic_clean (line 55) | def basic_clean(text):
  function whitespace_clean (line 61) | def whitespace_clean(text):
  class SimpleTokenizer (line 67) | class SimpleTokenizer(object):
    method __init__ (line 68) | def __init__(self, bpe_path: str = default_bpe()):
    method bpe (line 90) | def bpe(self, token):
    method encode (line 133) | def encode(self, text):
    method decode (line 143) | def decode(self, tokens):

FILE: ProGrad.public/datasets/caltech101.py
  class Caltech101 (line 20) | class Caltech101(DatasetBase):
    method __init__ (line 24) | def __init__(self, cfg):

FILE: ProGrad.public/datasets/dtd.py
  class DescribableTextures (line 12) | class DescribableTextures(DatasetBase):
    method __init__ (line 16) | def __init__(self, cfg):
    method read_and_split_data (line 66) | def read_and_split_data(image_dir,

FILE: ProGrad.public/datasets/eurosat.py
  class EuroSAT (line 25) | class EuroSAT(DatasetBase):
    method __init__ (line 29) | def __init__(self, cfg):
    method update_classname (line 79) | def update_classname(self, dataset_old):

FILE: ProGrad.public/datasets/fgvc_aircraft.py
  class FGVCAircraft (line 11) | class FGVCAircraft(DatasetBase):
    method __init__ (line 15) | def __init__(self, cfg):
    method read_data (line 65) | def read_data(self, cname2lab, split_file):

FILE: ProGrad.public/datasets/food101.py
  class Food101 (line 12) | class Food101(DatasetBase):
    method __init__ (line 16) | def __init__(self, cfg):

FILE: ProGrad.public/datasets/imagenet.py
  class ImageNet (line 12) | class ImageNet(DatasetBase):
    method __init__ (line 16) | def __init__(self, cfg):
    method read_classnames (line 70) | def read_classnames(text_file):
    method read_data (line 84) | def read_data(self, classnames, split_dir):

FILE: ProGrad.public/datasets/imagenet_a.py
  class ImageNetA (line 12) | class ImageNetA(DatasetBase):
    method __init__ (line 20) | def __init__(self, cfg):
    method read_data (line 32) | def read_data(self, classnames):

FILE: ProGrad.public/datasets/imagenet_r.py
  class ImageNetR (line 12) | class ImageNetR(DatasetBase):
    method __init__ (line 20) | def __init__(self, cfg):
    method read_data (line 32) | def read_data(self, classnames):

FILE: ProGrad.public/datasets/imagenet_sketch.py
  class ImageNetSketch (line 10) | class ImageNetSketch(DatasetBase):
    method __init__ (line 18) | def __init__(self, cfg):
    method read_data (line 30) | def read_data(self, classnames):

FILE: ProGrad.public/datasets/imagenetv2.py
  class ImageNetV2 (line 10) | class ImageNetV2(DatasetBase):
    method __init__ (line 18) | def __init__(self, cfg):
    method read_data (line 31) | def read_data(self, classnames):

FILE: ProGrad.public/datasets/oxford_flowers.py
  class OxfordFlowers (line 14) | class OxfordFlowers(DatasetBase):
    method __init__ (line 18) | def __init__(self, cfg):
    method read_data (line 70) | def read_data(self):

FILE: ProGrad.public/datasets/oxford_pets.py
  class OxfordPets (line 12) | class OxfordPets(DatasetBase):
    method __init__ (line 16) | def __init__(self, cfg):
    method read_data (line 66) | def read_data(self, split_file):
    method split_trainval (line 87) | def split_trainval(trainval, p_val=0.2):
    method save_split (line 110) | def save_split(train, val, test, filepath, path_prefix):
    method read_split (line 133) | def read_split(filepath, path_prefix):
    method subsample_classes (line 153) | def subsample_classes(*args, subsample="all"):

FILE: ProGrad.public/datasets/stanford_cars.py
  class StanfordCars (line 12) | class StanfordCars(DatasetBase):
    method __init__ (line 16) | def __init__(self, cfg):
    method read_data (line 72) | def read_data(self, image_dir, anno_file, meta_file):

FILE: ProGrad.public/datasets/sun397.py
  class SUN397 (line 11) | class SUN397(DatasetBase):
    method __init__ (line 15) | def __init__(self, cfg):
    method read_data (line 74) | def read_data(self, cname2lab, text_file):

FILE: ProGrad.public/datasets/ucf101.py
  class UCF101 (line 12) | class UCF101(DatasetBase):
    method __init__ (line 16) | def __init__(self, cfg):
    method read_data (line 78) | def read_data(self, cname2lab, text_file):

FILE: ProGrad.public/interpret_prompt.py
  function load_clip_to_cpu (line 10) | def load_clip_to_cpu(backbone_name="RN50"):

FILE: ProGrad.public/lpclip/feat_extractor.py
  function print_args (line 34) | def print_args(args, cfg):
  function reset_cfg (line 48) | def reset_cfg(cfg, args):
  function extend_cfg (line 65) | def extend_cfg(cfg):
  function setup_cfg (line 86) | def setup_cfg(args):
  function main (line 106) | def main(args):

FILE: ProGrad.public/lpclip/linear_probe.py
  function binary_search (line 84) | def binary_search(c_left, c_right, seed, step, test_acc_step_list):

FILE: ProGrad.public/lpclip/linear_probe_transfer.py
  function binary_search (line 103) | def binary_search(c_left, c_right, seed, step, test_acc_step_list):

FILE: ProGrad.public/parse_test_res.py
  function compute_ci95 (line 60) | def compute_ci95(res):
  function parse_function (line 64) | def parse_function(*metrics, directory="", args=None, end_signal=None):
  function main (line 126) | def main(args, end_signal):

FILE: ProGrad.public/train.py
  function print_args (line 34) | def print_args(args, cfg):
  function reset_cfg (line 48) | def reset_cfg(cfg, args):
  function extend_cfg (line 80) | def extend_cfg(cfg):
  function setup_cfg (line 118) | def setup_cfg(args):
  function main (line 141) | def main(args):

FILE: ProGrad.public/trainers/cocoop.py
  function load_clip_to_cpu (line 21) | def load_clip_to_cpu(cfg):
  class TextEncoder (line 39) | class TextEncoder(nn.Module):
    method __init__ (line 40) | def __init__(self, clip_model):
    method forward (line 48) | def forward(self, prompts, tokenized_prompts):
  class PromptLearner (line 63) | class PromptLearner(nn.Module):
    method __init__ (line 64) | def __init__(self, cfg, classnames, clip_model):
    method construct_prompts (line 135) | def construct_prompts(self, ctx, prefix, suffix, label=None):
    method forward (line 156) | def forward(self, im_features):
  class CustomCLIP (line 199) | class CustomCLIP(nn.Module):
    method __init__ (line 200) | def __init__(self, cfg, classnames, clip_model):
    method forward (line 209) | def forward(self, image, label=None):
  class CoCoOp (line 235) | class CoCoOp(TrainerX):
    method check_cfg (line 236) | def check_cfg(self, cfg):
    method build_model (line 239) | def build_model(self):
    method forward_backward (line 290) | def forward_backward(self, batch):
    method parse_batch_train (line 318) | def parse_batch_train(self, batch):
    method load_model (line 325) | def load_model(self, directory, epoch=None):

FILE: ProGrad.public/trainers/coop.py
  function load_clip_to_cpu (line 19) | def load_clip_to_cpu(cfg):
  class TextEncoder (line 37) | class TextEncoder(nn.Module):
    method __init__ (line 38) | def __init__(self, clip_model):
    method forward (line 46) | def forward(self, prompts, tokenized_prompts):
  class PromptLearner (line 61) | class PromptLearner(nn.Module):
    method __init__ (line 62) | def __init__(self, cfg, classnames, clip_model):
    method forward (line 138) | def forward(self):
  class CustomCLIP (line 227) | class CustomCLIP(nn.Module):
    method __init__ (line 228) | def __init__(self, cfg, classnames, clip_model):
    method forward (line 237) | def forward(self, image):
  class CoOp (line 256) | class CoOp(TrainerX):
    method check_cfg (line 262) | def check_cfg(self, cfg):
    method build_model (line 265) | def build_model(self):
    method forward_backward (line 306) | def forward_backward(self, batch):
    method parse_batch_train (line 333) | def parse_batch_train(self, batch):
    method load_model (line 340) | def load_model(self, directory, epoch=None):

FILE: ProGrad.public/trainers/prograd.py
  function load_clip_to_cpu (line 24) | def load_clip_to_cpu(cfg):
  class TextEncoder (line 42) | class TextEncoder(nn.Module):
    method __init__ (line 43) | def __init__(self, clip_model):
    method forward (line 51) | def forward(self, prompts, tokenized_prompts):
  class PromptLearner (line 66) | class PromptLearner(nn.Module):
    method __init__ (line 67) | def __init__(self, cfg, classnames, clip_model):
    method forward (line 134) | def forward(self):
  class CLIP (line 220) | class CLIP(nn.Module):
    method __init__ (line 221) | def __init__(self, cfg, classnames):
    method forward (line 240) | def forward(self, image):
  class CustomCLIP (line 252) | class CustomCLIP(nn.Module):
    method __init__ (line 253) | def __init__(self, cfg, classnames, clip_model):
    method forward (line 262) | def forward(self, image):
  class ProGradLoss (line 280) | class ProGradLoss(_Loss):
    method __init__ (line 281) | def __init__(self, T):
    method forward (line 285) | def forward(self, stu_logits, tea_logits, label):
  class ProGrad (line 297) | class ProGrad(TrainerX):
    method check_cfg (line 300) | def check_cfg(self, cfg):
    method build_model (line 303) | def build_model(self):
    method forward_backward (line 360) | def forward_backward(self, batch):
    method parse_batch_train (line 396) | def parse_batch_train(self, batch):
    method load_model (line 403) | def load_model(self, directory, epoch=None):

FILE: ProGrad.public/trainers/zsclip.py
  class ZeroshotCLIP (line 36) | class ZeroshotCLIP(TrainerX):
    method build_model (line 37) | def build_model(self):
    method model_inference (line 59) | def model_inference(self, image):
  class ZeroshotCLIP2 (line 69) | class ZeroshotCLIP2(ZeroshotCLIP):
    method build_model (line 75) | def build_model(self):
Condensed preview: 244 files, each showing path, character count, and a content snippet.
[
  {
    "path": "Dassl.ProGrad.pytorch/.flake8",
    "chars": 577,
    "preview": "[flake8]\nignore =\n    # At least two spaces before inline comment\n    E261,\n    # Line lengths are recommended to be no "
  },
  {
    "path": "Dassl.ProGrad.pytorch/.gitignore",
    "chars": 1888,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "Dassl.ProGrad.pytorch/.isort.cfg",
    "chars": 315,
    "preview": "[isort]\nline_length=79\nmulti_line_output=6\nlength_sort=true\nknown_standard_library=numpy,setuptools\nknown_myself=dassl\nk"
  },
  {
    "path": "Dassl.ProGrad.pytorch/.style.yapf",
    "chars": 222,
    "preview": "[style]\nBASED_ON_STYLE = pep8\nBLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true\nSPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN ="
  },
  {
    "path": "Dassl.ProGrad.pytorch/DATASETS.md",
    "chars": 7114,
    "preview": "# How to Install Datasets\n\n`$DATA` denotes the location where datasets are installed, e.g.\n\n```\n$DATA/\n|–– office31/\n|––"
  },
  {
    "path": "Dassl.ProGrad.pytorch/LICENSE",
    "chars": 1064,
    "preview": "MIT License\n\nCopyright (c) 2020 Kaiyang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof"
  },
  {
    "path": "Dassl.ProGrad.pytorch/README.md",
    "chars": 22588,
    "preview": "# Dassl\n\n## Introduction\n\nDassl is a [PyTorch](https://pytorch.org) toolbox initially developed for our project [Domain "
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/README.md",
    "chars": 331,
    "preview": "The `datasets/` folder contains dataset-specific config files which define the standard protocols (e.g., image size, dat"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/cifar_stl.yaml",
    "chars": 111,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:\n  NAME: \"CIFARSTL\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/digit5.yaml",
    "chars": 186,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n  TRANSFORMS: [\"normalize\"]\n\nDATASET:"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/domainnet.yaml",
    "chars": 162,
    "preview": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"DomainNet\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/mini_domainnet.yaml",
    "chars": 163,
    "preview": "INPUT:\n  SIZE: (96, 96)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"miniDomainNe"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/office31.yaml",
    "chars": 225,
    "preview": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"Office31\"\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/office_home.yaml",
    "chars": 56,
    "preview": "INPUT:\n  SIZE: (224, 224)\n\nDATASET:\n  NAME: \"OfficeHome\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/da/visda17.yaml",
    "chars": 185,
    "preview": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"center_crop\", \"normalize\"]\n\nDATASET:\n  NAME: \"VisDA17\"\n\nMODEL:\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/cifar100_c.yaml",
    "chars": 260,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/cifar10_c.yaml",
    "chars": 259,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/digit_single.yaml",
    "chars": 190,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/digits_dg.yaml",
    "chars": 184,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5, 0.5]\n\nDATASET:"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/office_home_dg.yaml",
    "chars": 185,
    "preview": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"OfficeHome"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/pacs.yaml",
    "chars": 177,
    "preview": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"PACS\"\n\nMOD"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/dg/vlcs.yaml",
    "chars": 177,
    "preview": "INPUT:\n  SIZE: (224, 224)\n  TRANSFORMS: [\"random_flip\", \"random_translation\", \"normalize\"]\n\nDATASET:\n  NAME: \"VLCS\"\n\nMOD"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar10.yaml",
    "chars": 255,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/cifar100.yaml",
    "chars": 275,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/stl10.yaml",
    "chars": 249,
    "preview": "INPUT:\n  SIZE: (96, 96)\n  TRANSFORMS: [\"random_flip\", \"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/datasets/ssl/svhn.yaml",
    "chars": 255,
    "preview": "INPUT:\n  SIZE: (32, 32)\n  TRANSFORMS: [\"random_crop\", \"normalize\"]\n  PIXEL_MEAN: [0.5, 0.5, 0.5]\n  PIXEL_STD: [0.5, 0.5,"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/dael/digit5.yaml",
    "chars": 318,
    "preview": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 256\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/dael/domainnet.yaml",
    "chars": 341,
    "preview": "DATALOADER:\n  NUM_WORKERS: 4\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TRAIN_U:\n    SAME_AS_X: "
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/dael/mini_domainnet.yaml",
    "chars": 344,
    "preview": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 192\n  TRAIN_U:\n    SAME_AS_X:"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/digit5.yaml",
    "chars": 247,
    "preview": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 256\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/domainnet.yaml",
    "chars": 245,
    "preview": "DATALOADER:\n  NUM_WORKERS: 4\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TRAIN_U:\n    SAME_AS_X: "
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/m3sda/mini_domainnet.yaml",
    "chars": 248,
    "preview": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 192\n  TRAIN_U:\n    SAME_AS_X:"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/digit5.yaml",
    "chars": 161,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 256\n  TEST:\n    BATCH_SIZE: 256\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05\n  STEPSIZE: [30]"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/mini_domainnet.yaml",
    "chars": 162,
    "preview": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATCH_SIZE: 128\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.00"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/office31.yaml",
    "chars": 135,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 32\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.002\n  STEPSIZE: [20]\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/da/source_only/visda17.yaml",
    "chars": 183,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 32\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.0001\n  STEPSIZE: [2]\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/dael/digits_dg.yaml",
    "chars": 242,
    "preview": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 120\n  TEST:\n    BATCH_SIZE: 100\n\nOPTIM:\n  NAME"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/dael/office_home_dg.yaml",
    "chars": 275,
    "preview": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TEST:\n    BATCH_SIZE: 100\n\nOPTIM:\n  NAME:"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/dael/pacs.yaml",
    "chars": 275,
    "preview": "DATALOADER:\n  TRAIN_X:\n    SAMPLER: \"RandomDomainSampler\"\n    BATCH_SIZE: 30\n  TEST:\n    BATCH_SIZE: 100\n\nOPTIM:\n  NAME:"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/digits_dg.yaml",
    "chars": 258,
    "preview": "INPUT:\n  PIXEL_MEAN: [0., 0., 0.]\n  PIXEL_STD: [1., 1., 1.]\n\nDATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATC"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/office_home_dg.yaml",
    "chars": 272,
    "preview": "INPUT:\n  PIXEL_MEAN: [0., 0., 0.]\n  PIXEL_STD: [1., 1., 1.]\n\nDATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 16\n  TEST:\n    BATCH"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/ddaig/pacs.yaml",
    "chars": 272,
    "preview": "INPUT:\n  PIXEL_MEAN: [0., 0., 0.]\n  PIXEL_STD: [1., 1., 1.]\n\nDATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 16\n  TEST:\n    BATCH"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/digits_dg.yaml",
    "chars": 178,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.05"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/mini_domainnet.yaml",
    "chars": 162,
    "preview": "DATALOADER:\n  NUM_WORKERS: 8\n  TRAIN_X:\n    BATCH_SIZE: 128\n  TEST:\n    BATCH_SIZE: 128\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.00"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/office_home_dg.yaml",
    "chars": 161,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.001"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/dg/vanilla/pacs.yaml",
    "chars": 161,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 64\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nOPTIM:\n  NAME: \"sgd\"\n  LR: 0.001"
  },
  {
    "path": "Dassl.ProGrad.pytorch/configs/trainers/ssl/fixmatch/cifar10.yaml",
    "chars": 373,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 64\n  TRAIN_U:\n    SAME_AS_X: False\n    BATCH_SIZE: 448\n  TEST:\n    BATCH_SIZE: 50"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/__init__.py",
    "chars": 443,
    "preview": "\"\"\"\nDassl\n------\nPyTorch toolbox for domain adaptation and semi-supervised learning.\n\nURL: https://github.com/KaiyangZho"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/config/__init__.py",
    "chars": 96,
    "preview": "from .defaults import _C as cfg_default\n\n\ndef get_cfg_default():\n    return cfg_default.clone()\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/config/defaults.py",
    "chars": 8752,
    "preview": "from yacs.config import CfgNode as CN\n\n###########################\n# Config definition\n###########################\n\n_C ="
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/__init__.py",
    "chars": 54,
    "preview": "from .data_manager import DataManager, DatasetWrapper\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/data_manager.py",
    "chars": 8040,
    "preview": "import torch\nimport torchvision.transforms as T\nfrom PIL import Image\nfrom torch.utils.data import Dataset as TorchDatas"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/__init__.py",
    "chars": 180,
    "preview": "from .build import DATASET_REGISTRY, build_dataset  # isort:skip\nfrom .base_dataset import Datum, DatasetBase  # isort:s"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/base_dataset.py",
    "chars": 6252,
    "preview": "import os\nimport random\nimport os.path as osp\nimport tarfile\nimport zipfile\nfrom collections import defaultdict\nimport g"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/build.py",
    "chars": 368,
    "preview": "from dassl.utils import Registry, check_availability\n\nDATASET_REGISTRY = Registry(\"DATASET\")\n\n\ndef build_dataset(cfg):\n "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/__init__.py",
    "chars": 229,
    "preview": "from .digit5 import Digit5\nfrom .visda17 import VisDA17\nfrom .cifarstl import CIFARSTL\nfrom .office31 import Office31\nfr"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/cifarstl.py",
    "chars": 2309,
    "preview": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_datase"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/digit5.py",
    "chars": 4109,
    "preview": "import random\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/domainnet.py",
    "chars": 2304,
    "preview": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REG"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/mini_domainnet.py",
    "chars": 1978,
    "preview": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REG"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/office31.py",
    "chars": 1952,
    "preview": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_datase"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/office_home.py",
    "chars": 2013,
    "preview": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_datase"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/da/visda17.py",
    "chars": 1765,
    "preview": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REG"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/__init__.py",
    "chars": 198,
    "preview": "from .pacs import PACS\nfrom .vlcs import VLCS\nfrom .cifar_c import CIFAR10C, CIFAR100C\nfrom .digits_dg import DigitsDG\nf"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/cifar_c.py",
    "chars": 3248,
    "preview": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_datase"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digit_single.py",
    "chars": 4315,
    "preview": "import os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_datase"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/digits_dg.py",
    "chars": 3220,
    "preview": "import glob\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ."
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/office_home_dg.py",
    "chars": 1564,
    "preview": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom .digits_dg import DigitsDG\nfrom ..base_dataset import D"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/pacs.py",
    "chars": 3174,
    "preview": "import os.path as osp\n\nfrom ..build import DATASET_REGISTRY\nfrom ..base_dataset import Datum, DatasetBase\n\n\n@DATASET_REG"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/dg/vlcs.py",
    "chars": 1895,
    "preview": "import glob\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY\nfrom ."
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/__init__.py",
    "chars": 85,
    "preview": "from .svhn import SVHN\nfrom .cifar import CIFAR10, CIFAR100\nfrom .stl10 import STL10\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/cifar.py",
    "chars": 3252,
    "preview": "import math\nimport random\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_R"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/stl10.py",
    "chars": 2712,
    "preview": "import numpy as np\nimport os.path as osp\n\nfrom dassl.utils import listdir_nohidden\n\nfrom ..build import DATASET_REGISTRY"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/datasets/ssl/svhn.py",
    "chars": 361,
    "preview": "from .cifar import CIFAR10\nfrom ..build import DATASET_REGISTRY\n\n\n@DATASET_REGISTRY.register()\nclass SVHN(CIFAR10):\n    "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/samplers.py",
    "chars": 6660,
    "preview": "import copy\nimport numpy as np\nimport random\nfrom collections import defaultdict\nfrom torch.utils.data.sampler import Sa"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/__init__.py",
    "chars": 40,
    "preview": "from .transforms import build_transform\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/autoaugment.py",
    "chars": 11840,
    "preview": "\"\"\"\nSource: https://github.com/DeepVoltaire/AutoAugment\n\"\"\"\nimport numpy as np\nimport random\nfrom PIL import Image, Imag"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/randaugment.py",
    "chars": 8715,
    "preview": "\"\"\"\nCredit to\n1) https://github.com/ildoonet/pytorch-randaugment\n2) https://github.com/kakaobrain/fast-autoaugment\n\"\"\"\ni"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/data/transforms/transforms.py",
    "chars": 10226,
    "preview": "import numpy as np\nimport random\nimport torch\nfrom PIL import Image\nfrom torchvision.transforms import (\n    Resize, Com"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/__init__.py",
    "chars": 215,
    "preview": "from .build import TRAINER_REGISTRY, build_trainer  # isort:skip\nfrom .trainer import TrainerX, TrainerXU, TrainerBase, "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/build.py",
    "chars": 368,
    "preview": "from dassl.utils import Registry, check_availability\n\nTRAINER_REGISTRY = Registry(\"TRAINER\")\n\n\ndef build_trainer(cfg):\n "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/__init__.py",
    "chars": 241,
    "preview": "from .mcd import MCD\nfrom .mme import MME\nfrom .adda import ADDA\nfrom .dael import DAEL\nfrom .dann import DANN\nfrom .ada"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/adabn.py",
    "chars": 986,
    "preview": "import torch\n\nfrom dassl.utils import check_isfile\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\n\n\n@TRAINER_REGIS"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/adda.py",
    "chars": 2847,
    "preview": "import copy\nimport torch\nimport torch.nn as nn\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dassl.u"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/dael.py",
    "chars": 7522,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom dassl.data import DataManager\nfrom dassl.optim import build_optimizer, build_lr"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/dann.py",
    "chars": 2661,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/m3sda.py",
    "chars": 6194,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/mcd.py",
    "chars": 3593,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/mme.py",
    "chars": 2693,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/self_ensembling.py",
    "chars": 2516,
    "preview": "import copy\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metric"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/da/source_only.py",
    "chars": 996,
    "preview": "from torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import com"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/__init__.py",
    "chars": 114,
    "preview": "from .ddaig import DDAIG\nfrom .daeldg import DAELDG\nfrom .vanilla import Vanilla\nfrom .crossgrad import CrossGrad\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/crossgrad.py",
    "chars": 2896,
    "preview": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dass"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/daeldg.py",
    "chars": 5550,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom dassl.data import DataManager\nfrom dassl.optim import build_optimizer, build_lr"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/ddaig.py",
    "chars": 3709,
    "preview": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.optim import build_optimizer, build_lr_scheduler\nfrom dass"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/dg/vanilla.py",
    "chars": 884,
    "preview": "from torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerX\nfrom dassl.metrics import comp"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/__init__.py",
    "chars": 165,
    "preview": "from .entmin import EntMin\nfrom .fixmatch import FixMatch\nfrom .mixmatch import MixMatch\nfrom .mean_teacher import MeanT"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/entmin.py",
    "chars": 1162,
    "preview": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metri"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/fixmatch.py",
    "chars": 4063,
    "preview": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.data import DataManager\nfrom dassl.engine import TRAINER_R"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/mean_teacher.py",
    "chars": 1757,
    "preview": "import copy\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metric"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/mixmatch.py",
    "chars": 3119,
    "preview": "import torch\nfrom torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.model"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/ssl/sup_baseline.py",
    "chars": 930,
    "preview": "from torch.nn import functional as F\n\nfrom dassl.engine import TRAINER_REGISTRY, TrainerXU\nfrom dassl.metrics import com"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/engine/trainer.py",
    "chars": 24201,
    "preview": "import json\nimport time\nimport numpy as np\nimport os.path as osp\nimport datetime\nfrom collections import OrderedDict\nimp"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/evaluation/__init__.py",
    "chars": 123,
    "preview": "from .build import build_evaluator, EVALUATOR_REGISTRY  # isort:skip\n\nfrom .evaluator import EvaluatorBase, Classificati"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/evaluation/build.py",
    "chars": 410,
    "preview": "from dassl.utils import Registry, check_availability\n\nEVALUATOR_REGISTRY = Registry(\"EVALUATOR\")\n\n\ndef build_evaluator(c"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/evaluation/evaluator.py",
    "chars": 3850,
    "preview": "import numpy as np\nimport os.path as osp\nfrom collections import OrderedDict, defaultdict\nimport torch\nfrom sklearn.metr"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/metrics/__init__.py",
    "chars": 138,
    "preview": "from .accuracy import compute_accuracy\nfrom .distance import (\n    cosine_distance, compute_distance_matrix, euclidean_s"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/metrics/accuracy.py",
    "chars": 972,
    "preview": "def compute_accuracy(output, target, topk=(1, )):\n    \"\"\"Computes the accuracy over the k top predictions for\n    the sp"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/metrics/distance.py",
    "chars": 2247,
    "preview": "\"\"\"\nSource: https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport torch\nfrom torch.nn import functional as F\n\n\ndef "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/__init__.py",
    "chars": 163,
    "preview": "from .head import HEAD_REGISTRY, build_head\nfrom .network import NETWORK_REGISTRY, build_network\nfrom .backbone import B"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/__init__.py",
    "chars": 1175,
    "preview": "from .build import build_backbone, BACKBONE_REGISTRY  # isort:skip\nfrom .backbone import Backbone  # isort:skip\n\nfrom .v"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/alexnet.py",
    "chars": 1908,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\n\nfrom .build import BACKBONE_REGISTRY\nfrom "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/backbone.py",
    "chars": 336,
    "preview": "import torch.nn as nn\n\n\nclass Backbone(nn.Module):\n\n    def __init__(self):\n        super().__init__()\n\n    def forward("
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/build.py",
    "chars": 358,
    "preview": "from dassl.utils import Registry, check_availability\n\nBACKBONE_REGISTRY = Registry(\"BACKBONE\")\n\n\ndef build_backbone(name"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digit5_m3sda.py",
    "chars": 1819,
    "preview": "\"\"\"\nReference\n\nhttps://github.com/VisionLearningGroup/VisionLearningGroup.github.io/tree/master/M3SDA\n\"\"\"\nimport torch.n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsdg.py",
    "chars": 1630,
    "preview": "import torch.nn as nn\nfrom torch.nn import functional as F\n\nfrom dassl.utils import init_network_weights\n\nfrom .build im"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/cnn_digitsingle.py",
    "chars": 1265,
    "preview": "\"\"\"\nThis model is built based on\nhttps://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py\n\"\"\"\nimport t"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/__init__.py",
    "chars": 370,
    "preview": "\"\"\"\nSource: https://github.com/lukemelas/EfficientNet-PyTorch.\n\"\"\"\n__version__ = \"0.6.4\"\nfrom .model import (\n    Effici"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/model.py",
    "chars": 13021,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom .utils import (\n    Swish, MemoryEfficientS"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/efficientnet/utils.py",
    "chars": 15997,
    "preview": "\"\"\"\nThis file contains helper functions for building the model and for loading model parameters.\nThese helper functions "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/mobilenetv2.py",
    "chars": 6728,
    "preview": "import torch.utils.model_zoo as model_zoo\nfrom torch import nn\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone impo"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/preact_resnet18.py",
    "chars": 4014,
    "preview": "import torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbo"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/resnet.py",
    "chars": 14605,
    "preview": "import torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone imp"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/shufflenetv2.py",
    "chars": 7007,
    "preview": "\"\"\"\nCode source: https://github.com/pytorch/vision\n\"\"\"\nimport torch\nimport torch.utils.model_zoo as model_zoo\nfrom torch"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/vgg.py",
    "chars": 3934,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .build import BACKBONE_REGISTRY\nfrom .backbone import Backbone\n\ntry:\n    from t"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/backbone/wide_resnet.py",
    "chars": 4403,
    "preview": "\"\"\"\nModified from https://github.com/xternalz/WideResNet-pytorch\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn."
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/head/__init__.py",
    "chars": 81,
    "preview": "from .build import build_head, HEAD_REGISTRY  # isort:skip\n\nfrom .mlp import mlp\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/head/build.py",
    "chars": 326,
    "preview": "from dassl.utils import Registry, check_availability\n\nHEAD_REGISTRY = Registry(\"HEAD\")\n\n\ndef build_head(name, verbose=Tr"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/head/mlp.py",
    "chars": 1202,
    "preview": "import functools\nimport torch.nn as nn\n\nfrom .build import HEAD_REGISTRY\n\n\nclass MLP(nn.Module):\n\n    def __init__(\n    "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/network/__init__.py",
    "chars": 164,
    "preview": "from .build import build_network, NETWORK_REGISTRY  # isort:skip\n\nfrom .ddaig_fcn import (\n    fcn_3x32_gctx, fcn_3x64_g"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/network/build.py",
    "chars": 346,
    "preview": "from dassl.utils import Registry, check_availability\n\nNETWORK_REGISTRY = Registry(\"NETWORK\")\n\n\ndef build_network(name, v"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/network/ddaig_fcn.py",
    "chars": 9613,
    "preview": "\"\"\"\nCredit to: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix\n\"\"\"\nimport functools\nimport torch\nimport torch.nn"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/__init__.py",
    "chars": 647,
    "preview": "from .mmd import MaximumMeanDiscrepancy\nfrom .dsbn import DSBN1d, DSBN2d\nfrom .mixup import mixup\nfrom .efdmix import (\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/cross_entropy.py",
    "chars": 1033,
    "preview": "import torch\nfrom torch.nn import functional as F\n\n\ndef cross_entropy(input, target, label_smooth=0, reduction=\"mean\"):\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/dsbn.py",
    "chars": 1159,
    "preview": "import torch.nn as nn\n\n\nclass _DSBN(nn.Module):\n    \"\"\"Domain Specific Batch Normalization.\n\n    Args:\n        num_featu"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/efdmix.py",
    "chars": 3159,
    "preview": "import random\nfrom contextlib import contextmanager\nimport torch\nimport torch.nn as nn\n\n\ndef deactivate_efdmix(m):\n    i"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/mixstyle.py",
    "chars": 3170,
    "preview": "import random\nfrom contextlib import contextmanager\nimport torch\nimport torch.nn as nn\n\n\ndef deactivate_mixstyle(m):\n   "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/mixup.py",
    "chars": 774,
    "preview": "import torch\n\n\ndef mixup(x1, x2, y1, y2, beta, preserve_order=False):\n    \"\"\"Mixup.\n\n    Args:\n        x1 (torch.Tensor)"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/mmd.py",
    "chars": 3082,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\nclass MaximumMeanDiscrepancy(nn.Module):\n\n    "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/optimal_transport.py",
    "chars": 4663,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\nclass OptimalTransport(nn.Module):\n\n    @stati"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/reverse_grad.py",
    "chars": 845,
    "preview": "import torch.nn as nn\nfrom torch.autograd import Function\n\n\nclass _ReverseGrad(Function):\n\n    @staticmethod\n    def for"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/sequential2.py",
    "chars": 427,
    "preview": "import torch.nn as nn\n\n\nclass Sequential2(nn.Sequential):\n    \"\"\"An alternative sequential container to nn.Sequential,\n "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/transnorm.py",
    "chars": 4669,
    "preview": "import torch\nimport torch.nn as nn\n\n\nclass _TransNorm(nn.Module):\n    \"\"\"Transferable normalization.\n\n    Reference:\n   "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/modeling/ops/utils.py",
    "chars": 1977,
    "preview": "import numpy as np\nimport torch\n\n\ndef sharpen_prob(p, temperature=2):\n    \"\"\"Sharpening probability with a temperature.\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/__init__.py",
    "chars": 84,
    "preview": "from .optimizer import build_optimizer\nfrom .lr_scheduler import build_lr_scheduler\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/lr_scheduler.py",
    "chars": 4225,
    "preview": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport torch\nfrom torch.optim.lr_scheduler import "
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/optimizer.py",
    "chars": 3654,
    "preview": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport warnings\nimport torch\nimport torch.nn as nn"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/optim/radam.py",
    "chars": 11648,
    "preview": "\"\"\"\nImported from: https://github.com/LiyuanLucasLiu/RAdam\n\nhttps://arxiv.org/abs/1908.03265\n\n@article{liu2019radam,\n  t"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/__init__.py",
    "chars": 115,
    "preview": "from .tools import *\nfrom .logger import *\nfrom .meters import *\nfrom .registry import *\nfrom .torchtools import *\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/logger.py",
    "chars": 1703,
    "preview": "import os\nimport sys\nimport time\nimport os.path as osp\n\nfrom .tools import mkdir_if_missing\n\n__all__ = [\"Logger\", \"setup"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/meters.py",
    "chars": 2163,
    "preview": "from collections import defaultdict\nimport torch\n\n__all__ = [\"AverageMeter\", \"MetricMeter\"]\n\n\nclass AverageMeter:\n    \"\""
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/registry.py",
    "chars": 1709,
    "preview": "\"\"\"\nModified from https://github.com/facebookresearch/fvcore\n\"\"\"\n__all__ = [\"Registry\"]\n\n\nclass Registry:\n    \"\"\"A regis"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/tools.py",
    "chars": 4685,
    "preview": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport os\nimport sys\nimport json\nimport time\nimpor"
  },
  {
    "path": "Dassl.ProGrad.pytorch/dassl/utils/torchtools.py",
    "chars": 10547,
    "preview": "\"\"\"\nModified from https://github.com/KaiyangZhou/deep-person-reid\n\"\"\"\nimport pickle\nimport shutil\nimport os.path as osp\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/da/cifar_stl.py",
    "chars": 2383,
    "preview": "import sys\nimport pprint as pp\nimport os.path as osp\nfrom torchvision.datasets import STL10, CIFAR10\n\nfrom dassl.utils i"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/da/digit5.py",
    "chars": 3771,
    "preview": "import os\nimport numpy as np\nimport os.path as osp\nimport argparse\nfrom PIL import Image\nfrom scipy.io import loadmat\n\n\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/da/visda17.sh",
    "chars": 801,
    "preview": "# ------------------------------------------------------------------------\n# ROOT is the root directory where you put yo"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/dg/cifar_c.py",
    "chars": 2165,
    "preview": "\"\"\"\nThis script\n- creates a folder named \"cifar10_c\" under the same directory as 'CIFAR-10-C'\n- extracts images from .np"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/ssl/cifar10_cifar100_svhn.py",
    "chars": 1528,
    "preview": "import sys\nimport os.path as osp\nfrom torchvision.datasets import SVHN, CIFAR10, CIFAR100\n\nfrom dassl.utils import mkdir"
  },
  {
    "path": "Dassl.ProGrad.pytorch/datasets/ssl/stl10.py",
    "chars": 1171,
    "preview": "import sys\nimport os.path as osp\nfrom torchvision.datasets import STL10\n\nfrom dassl.utils import mkdir_if_missing\n\n\ndef "
  },
  {
    "path": "Dassl.ProGrad.pytorch/linter.sh",
    "chars": 150,
    "preview": "echo \"Running isort\"\nisort -y -sp .\necho \"Done\"\n\necho \"Running yapf\"\nyapf -i -r -vv -e build .\necho \"Done\"\n\necho \"Runnin"
  },
  {
    "path": "Dassl.ProGrad.pytorch/requirements.txt",
    "chars": 94,
    "preview": "flake8==3.7.9\nyapf==0.29.0\nisort==4.3.21\nyacs\ngdown\ntb-nightly\nfuture\nscipy\nscikit-learn\ntqdm\n"
  },
  {
    "path": "Dassl.ProGrad.pytorch/setup.py",
    "chars": 1237,
    "preview": "import numpy as np\nimport os.path as osp\nfrom setuptools import setup, find_packages\n\n\ndef readme():\n    with open('READ"
  },
  {
    "path": "Dassl.ProGrad.pytorch/tools/parse_test_res.py",
    "chars": 4662,
    "preview": "\"\"\"\nGoal\n---\n1. Read test results from log.txt files\n2. Compute mean and std across different folders (seeds)\n\nUsage\n---"
  },
  {
    "path": "Dassl.ProGrad.pytorch/tools/replace_text.py",
    "chars": 1957,
    "preview": "\"\"\"\nReplace text in python files.\n\"\"\"\nimport glob\nimport os.path as osp\nimport argparse\nimport fileinput\n\nEXTENSION = \"."
  },
  {
    "path": "Dassl.ProGrad.pytorch/tools/train.py",
    "chars": 4800,
    "preview": "import argparse\nimport torch\n\nfrom dassl.utils import setup_logger, set_random_seed, collect_env_info\nfrom dassl.config "
  },
  {
    "path": "ProGrad.public/.gitignore",
    "chars": 1826,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "ProGrad.public/DATASETS.md",
    "chars": 9126,
    "preview": "# How to install datasets\n\nWe suggest putting all datasets under the same folder (say `$DATA`) to ease management and fo"
  },
  {
    "path": "ProGrad.public/LICENSE",
    "chars": 1069,
    "preview": "MIT License\n\nCopyright (c) 2021 Kaiyang Zhou\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "ProGrad.public/README.md",
    "chars": 7286,
    "preview": "# How to Run\n\n## GPU memory needed\n\nAll the experiments is able to run on a single graphic card. However, **if you want "
  },
  {
    "path": "ProGrad.public/clip/__init__.py",
    "chars": 20,
    "preview": "from .clip import *\n"
  },
  {
    "path": "ProGrad.public/clip/clip.py",
    "chars": 8854,
    "preview": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom typing import Union, List\n\nimport torch\nfrom PIL import Imag"
  },
  {
    "path": "ProGrad.public/clip/model.py",
    "chars": 19281,
    "preview": "from collections import OrderedDict\nfrom typing import Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.fun"
  },
  {
    "path": "ProGrad.public/clip/simple_tokenizer.py",
    "chars": 5021,
    "preview": "import gzip\nimport html\nimport os\nfrom functools import lru_cache\n\nimport ftfy\nimport regex as re\n\n\n@lru_cache()\ndef def"
  },
  {
    "path": "ProGrad.public/configs/datasets/caltech101.yaml",
    "chars": 30,
    "preview": "DATASET:\n  NAME: \"Caltech101\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/dtd.yaml",
    "chars": 39,
    "preview": "DATASET:\n  NAME: \"DescribableTextures\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/eurosat.yaml",
    "chars": 27,
    "preview": "DATASET:\n  NAME: \"EuroSAT\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/fgvc_aircraft.yaml",
    "chars": 32,
    "preview": "DATASET:\n  NAME: \"FGVCAircraft\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/food101.yaml",
    "chars": 27,
    "preview": "DATASET:\n  NAME: \"Food101\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet.yaml",
    "chars": 28,
    "preview": "DATASET:\n  NAME: \"ImageNet\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet_a.yaml",
    "chars": 29,
    "preview": "DATASET:\n  NAME: \"ImageNetA\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet_r.yaml",
    "chars": 29,
    "preview": "DATASET:\n  NAME: \"ImageNetR\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenet_sketch.yaml",
    "chars": 34,
    "preview": "DATASET:\n  NAME: \"ImageNetSketch\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/imagenetv2.yaml",
    "chars": 30,
    "preview": "DATASET:\n  NAME: \"ImageNetV2\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/oxford_flowers.yaml",
    "chars": 32,
    "preview": "DATASET:\n  NAME: \"OxfordFlowers\""
  },
  {
    "path": "ProGrad.public/configs/datasets/oxford_pets.yaml",
    "chars": 29,
    "preview": "DATASET:\n  NAME: \"OxfordPets\""
  },
  {
    "path": "ProGrad.public/configs/datasets/stanford_cars.yaml",
    "chars": 32,
    "preview": "DATASET:\n  NAME: \"StanfordCars\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/sun397.yaml",
    "chars": 26,
    "preview": "DATASET:\n  NAME: \"SUN397\"\n"
  },
  {
    "path": "ProGrad.public/configs/datasets/ucf101.yaml",
    "chars": 26,
    "preview": "DATASET:\n  NAME: \"UCF101\"\n"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/rn50_c4_ep10_batch1_ctxv1.yaml",
    "chars": 578,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/rn50_ep100_init.yaml",
    "chars": 580,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/rn50_ep50.yaml",
    "chars": 579,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml",
    "chars": 581,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml",
    "chars": 580,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml",
    "chars": 592,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml",
    "chars": 580,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 1\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50.yaml",
    "chars": 510,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTE"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50_ep100.yaml",
    "chars": 511,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTE"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50_ep50.yaml",
    "chars": 509,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 100\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTE"
  },
  {
    "path": "ProGrad.public/configs/trainers/CoOp/rn50_val.yaml",
    "chars": 342,
    "preview": "DATALOADER:\n  TRAIN_X:\n    BATCH_SIZE: 32\n  TEST:\n    BATCH_SIZE: 32\n  NUM_WORKERS: 8\n\nINPUT:\n  SIZE: (224, 224)\n  INTER"
  }
]

// ... and 44 more files (download for full content)

About this extraction

This page contains the full source code of the BeierZhu/Prompt-align GitHub repository, extracted and formatted as plain text. The extraction includes 244 files (606.4 KB), approximately 166.1k tokens, and a symbol index with 911 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
