Repository: microsoft/robustdg
Branch: master
Commit: 3eee1730ae9e
Files: 125
Total size: 3.1 MB

Directory structure:
gitextract_exyigezx/

├── .github/
│   └── workflows/
│       └── python-package.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.rst
├── SECURITY.md
├── algorithms/
│   ├── __init__.py
│   ├── algo.py
│   ├── csd.py
│   ├── dann.py
│   ├── erm.py
│   ├── erm_match.py
│   ├── hybrid.py
│   ├── irm.py
│   ├── match_dg.py
│   └── mmd.py
├── azure_scripts/
│   ├── chest.yaml
│   ├── chest_ctr.yaml
│   ├── chest_ctr_spur.yaml
│   ├── chest_matchdg.yaml
│   ├── chest_matchdg_spur.yaml
│   ├── chest_spur.yaml
│   ├── fmnist.yaml
│   ├── irm_fashion.yaml
│   ├── irm_mnist.yaml
│   ├── mnist.yaml
│   ├── mnist_ctr.yaml
│   ├── mnist_ctr_spur.yaml
│   ├── mnist_spur.yaml
│   ├── pacs.yaml
│   ├── pacs_art_painting.yaml
│   ├── pacs_cartoon.yaml
│   ├── pacs_ctr.yaml
│   ├── pacs_erm.yaml
│   ├── pacs_hybrid.yaml
│   ├── pacs_matchdg.yaml
│   ├── pacs_perfect.yaml
│   ├── pacs_photo.yaml
│   ├── pacs_random.yaml
│   ├── pacs_sketch.yaml
│   └── setup_data_mnist.yaml
├── chestxray_download.txt
├── data/
│   ├── __init__.py
│   ├── adult_loader.py
│   ├── chestxray_loader.py
│   ├── chestxray_loader_aug.py
│   ├── chestxray_loader_match_eval.py
│   ├── data_gen_domainbed.py
│   ├── data_gen_mnist.py
│   ├── data_loader.py
│   ├── mnist_loader.py
│   ├── mnist_loader_match_eval.py
│   ├── mnist_loader_match_eval_spur.py
│   ├── mnist_loader_spur.py
│   ├── pacs_loader.py
│   ├── pacs_loader_aug.py
│   ├── pacs_loader_match_eval.py
│   ├── slab_loader.py
│   └── slab_loader_spur.py
├── data_gen_syn.py
├── docs/
│   ├── _config.yml
│   └── notebooks/
│       ├── ChestXRay_Translate.ipynb
│       ├── Preprocess.ipynb
│       ├── Spur_Rotated_MNIST.ipynb
│       ├── beta/
│       │   ├── HParam_Plots.ipynb
│       │   ├── adult_dataset.ipynb
│       │   └── mnist_results.ipynb
│       ├── helper_plots.ipynb
│       ├── privacy_plots.ipynb
│       ├── reproduce_results.ipynb
│       └── robustdg_getting_started.ipynb
├── evaluation/
│   ├── attribute_attack.py
│   ├── base_eval.py
│   ├── feat_eval.py
│   ├── logit_hist.py
│   ├── match_eval.py
│   ├── per_domain_acc.py
│   ├── privacy_attack.py
│   ├── privacy_entropy.py
│   ├── privacy_loss_attack.py
│   ├── slab_feat_eval.py
│   └── t_sne.py
├── misc_scripts/
│   ├── adult.txt
│   └── logit_plot_slab.py
├── models/
│   ├── alexnet.py
│   ├── densenet.py
│   ├── domain_bed_mnist.py
│   ├── fc.py
│   ├── lenet.py
│   ├── resnet.py
│   └── slab.py
├── reproduce_scripts/
│   ├── cxray_plot.py
│   ├── cxray_run.py
│   ├── mnist_mdg_ctr_run.py
│   ├── mnist_plot.py
│   ├── mnist_run.py
│   ├── pacs_run.py
│   ├── reproduce_rmnist_domainbed.py
│   ├── reproduce_rmnist_lenet.py
│   ├── reproduce_slab.py
│   ├── slab-plot.py
│   ├── slab-run.py
│   └── slab-tune.py
├── requirements.txt
├── requirements_new.txt
├── test.py
├── test_slab.py
├── train.py
└── utils/
    ├── __init__.py
    ├── attribute_attack.py
    ├── bnlearn_data.py
    ├── helper.py
    ├── match_function.py
    ├── privacy_attack.py
    ├── scripts/
    │   ├── __init__.py
    │   ├── data_utils.py
    │   ├── ensemble.py
    │   ├── gendata.py
    │   ├── gpu_utils.py
    │   ├── lms_utils.py
    │   ├── mnistcifar_utils.py
    │   ├── ptb_utils.py
    │   ├── synth_models.py
    │   └── utils.py
    └── slab_data.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/python-package.yml
================================================
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python package

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.5, 3.6, 3.7, 3.8]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install flake8 pytest
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        pytest


================================================
FILE: .gitignore
================================================
#Data dir
data/datasets/

#python environment
amt_envs/
matchdg-env/

#Results dir
results/

#extra-files
#*.sh

#philly_tools
.ptignore
.amltignore
pt/
amlt/
#*.yaml
.ptconfig
.amltconfig

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns


================================================
FILE: LICENSE
================================================
    MIT License

    Copyright (c) Microsoft Corporation.

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all
    copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.


================================================
FILE: README.rst
================================================
Toolkit for Building Robust ML models that generalize to unseen domains (RobustDG)
==================================================================================
`Divyat Mahajan <https://divyat09.github.io/>`_, 
`Shruti Tople <https://www.microsoft.com/en-us/research/people/shtople/>`_, 
`Amit Sharma <http://www.amitsharma.in>`_

`Privacy & Causal Learning (ICML 2020) <https://arxiv.org/abs/1909.12732>`_ | `MatchDG: Causal View of DG (ICML 2021) <http://proceedings.mlr.press/v139/mahajan21b.html>`_ | `Privacy & DG Connection paper <https://arxiv.org/abs/2110.03369>`_

For machine learning models to be reliable, they need to generalize to data
beyond the train distribution. In addition, ML models should be robust to
privacy attacks like membership inference and domain knowledge-based attacks like adversarial attacks.

To advance research in building robust and generalizable models, we are
releasing a toolkit for building and evaluating ML models, *RobustDG*. RobustDG contains implementations of domain
generalization algorithms and includes evaluation benchmarks based
on out-of-distribution accuracy and robustness to membership privacy attacks. We will be adding evaluation for adversarial attacks and more privacy attacks soon. 

It is easily extendable. Add your own DG algorithms and evaluate them on different benchmarks.


Installation
------------
To use the command-line interface of RobustDG, clone this repo and add the folder to your system's PATH (or alternatively, run the commands from the RobustDG root directory). 

**Load dataset**

Let's first load the Rotated MNIST dataset in a format suitable for the resnet18 architecture.

.. code:: shell

    python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000

**Train and evaluate ML model**

The following commands train and evaluate the MatchDG method on the Rotated MNIST dataset.

.. code:: shell


    python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --match_func_aug_case 1
    
    python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25
    
    python test.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25 --test_metric acc
    
    python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score    
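
The four commands above form a small pipeline: a contrastive training run (``matchdg_ctr``), a second training run (``matchdg_erm``) that passes ``--ctr_*`` flags and so appears to rely on the first, and two evaluation runs. A minimal sketch of a driver script follows. The helper names ``build_pipeline`` and ``run_pipeline`` are hypothetical and not part of RobustDG; every flag is copied from the commands above.

```python
import shlex
import subprocess

# Flag strings copied verbatim from the README commands above.
CTR_COMMON = "--dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos"
ERM_COMMON = ("--dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 "
              "--ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 "
              "--ctr_model_name resnet18 --epochs 25")


def build_pipeline():
    """Return the four README commands, in order, as argument lists."""
    steps = [
        ("train.py", CTR_COMMON + " --epochs 50 --batch_size 64 --match_func_aug_case 1"),
        ("train.py", ERM_COMMON),
        ("test.py", ERM_COMMON + " --test_metric acc"),
        ("test.py", CTR_COMMON + " --test_metric match_score"),
    ]
    return [["python", script] + shlex.split(flags) for script, flags in steps]


def run_pipeline(dry_run=True):
    """Print each command; execute it as well when dry_run is False."""
    for cmd in build_pipeline():
        print(" ".join(cmd))
        if not dry_run:
            # Assumes the current directory is the RobustDG root; stops at the first failing step.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```

With ``dry_run=True`` the script only prints the commands, which is a cheap way to check the flag sets before starting a long training run.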


Demo
----

A quick introduction to using this repository is available in the `Getting Started notebook <https://github.com/microsoft/robustdg/blob/master/docs/notebooks/robustdg_getting_started.ipynb>`_.

If you are interested in reproducing results from the MatchDG paper, check out the `Reproducing results notebook <https://github.com/microsoft/robustdg/blob/master/docs/notebooks/reproduce_results.ipynb>`_. 

Roadmap
-------

* Support for more domain generalization algorithms like CSD and IRM. If you are an author of a DG algorithm and would like to contribute, please raise a pull request `here <https://github.com/microsoft/robustdg/pulls>`_ or get in touch.

* More evaluation metrics based on adversarial attacks and on privacy attacks such as model inversion. If you'd like to see an evaluation metric implemented, please raise an issue `here <https://github.com/microsoft/robustdg/issues>`_.

Contributing
--------------

This project welcomes contributions and suggestions.  Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the `Microsoft Open Source Code of Conduct <https://opensource.microsoft.com/codeofconduct/>`_.
For more information see the `Code of Conduct FAQ <https://opensource.microsoft.com/codeofconduct/faq/>`_ or
contact `opencode@microsoft.com <mailto:opencode@microsoft.com>`_ with any additional questions or comments.


================================================
FILE: SECURITY.md
================================================
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

  * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
  * Full paths of source file(s) related to the manifestation of the issue
  * The location of the affected source code (tag/branch/commit or direct URL)
  * Any special configuration required to reproduce the issue
  * Step-by-step instructions to reproduce the issue
  * Proof-of-concept or exploit code (if possible)
  * Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->

================================================
FILE: algorithms/__init__.py
================================================


================================================
FILE: algorithms/algo.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json
import os
from more_itertools import chunked

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils


from utils.match_function import get_matched_pairs

print('From Inside the Algo Class: ', sys.argv[0])


def get_noise_multiplier(
    target_epsilon: float,
    target_delta: float,
    sample_rate: float,
    epochs: int,
    alphas: [float],
    sigma_min: float = 0.01,
    sigma_max: float = 10.0,
) -> float:
    r"""
    Computes the noise level sigma to reach a total budget of (target_epsilon, target_delta)
    at the end of epochs, with a given sample_rate

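    Args:
        target_epsilon: the privacy budget's epsilon
        target_delta: the privacy budget's delta
        sample_rate: the sampling rate (usually batch_size / n_data)
        epochs: the number of epochs to run
        alphas: the list of orders at which to compute RDP
        sigma_min: lower bound of the search interval for sigma
        sigma_max: initial upper bound of the search interval; it is doubled
            until the resulting epsilon is at most target_epsilon

    Returns:
        The noise level sigma to ensure privacy budget of (target_epsilon, target_delta)

    Raises:
        ValueError: if the upper bound of the search grows beyond 2000

    """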
    
    from opacus import privacy_analysis
    
    eps = float("inf")
    while eps > target_epsilon:
        sigma_max = 2 * sigma_max
        rdp = privacy_analysis.compute_rdp(
            sample_rate, sigma_max, epochs / sample_rate, alphas
        )
        eps = privacy_analysis.get_privacy_spent(alphas, rdp, target_delta)[0]
        if sigma_max > 2000:
            raise ValueError("The privacy budget is too low.")

    while sigma_max - sigma_min > 0.01:
        sigma = (sigma_min + sigma_max) / 2
        rdp = privacy_analysis.compute_rdp(
            sample_rate, sigma, epochs / sample_rate, alphas
        )
        eps = privacy_analysis.get_privacy_spent(alphas, rdp, target_delta)[0]

        if eps < target_epsilon:
            sigma_max = sigma
        else:
            sigma_min = sigma

    return sigma
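

# --- Illustrative sketch (not part of the original module) -------------------
# get_noise_multiplier() above searches for sigma in two phases: it first
# doubles the upper bound until the privacy cost fits the budget, then bisects
# the interval [sigma_min, sigma_max] until it is at most 0.01 wide.
# The hypothetical helper below restates that search without opacus.
# Assumptions: `eps_of_sigma` is a stand-in for the RDP accounting and only
# needs to decrease as sigma grows; `tol` is a name introduced here for the
# 0.01 threshold. Unlike the original, which returns the last midpoint, this
# sketch returns the upper bound, which is known to meet the budget.
def _bisect_sigma_sketch(eps_of_sigma, target_epsilon, sigma_min=0.01, sigma_max=10.0, tol=0.01):
    # Phase 1: grow the upper bound until eps(sigma_max) <= target_epsilon
    eps = float("inf")
    while eps > target_epsilon:
        sigma_max = 2 * sigma_max
        eps = eps_of_sigma(sigma_max)
        if sigma_max > 2000:
            raise ValueError("The privacy budget is too low.")

    # Phase 2: bisect, keeping eps(sigma_max) within budget at every step
    while sigma_max - sigma_min > tol:
        sigma = (sigma_min + sigma_max) / 2
        if eps_of_sigma(sigma) < target_epsilon:
            sigma_max = sigma
        else:
            sigma_min = sigma

    return sigma_max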


class BaseAlgo():
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, run, cuda):
        
        
        self.args= args
        self.train_dataset= train_dataset['data_loader']
        if args.method_name == 'matchdg_ctr':
            self.val_dataset= val_dataset
        else:
            self.val_dataset= val_dataset['data_loader']
        self.test_dataset= test_dataset['data_loader']
        
        self.train_domains= train_dataset['domain_list']
        self.total_domains= train_dataset['total_domains']
        self.domain_size= train_dataset['base_domain_size'] 
        self.training_list_size= train_dataset['domain_size_list']
        
        self.base_res_dir= base_res_dir
        self.run= run
        self.cuda= cuda
        
        self.post_string= str(self.args.penalty_ws) + '_' + str(self.args.penalty_diff_ctr) + '_' + str(self.args.match_case) + '_' + str(self.args.match_interrupt) + '_' + str(self.args.match_flag) + '_' + str(self.run) + '_' + self.args.pos_metric + '_' + self.args.model_name
        
        self.phi= self.get_model()
        self.opt= self.get_opt()
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=25)    
        
        self.final_acc=[]
        self.val_acc=[]
        self.train_acc=[]
        
        # Differentially Private Noise
        if self.args.dp_noise:
            self.privacy_engine= self.get_dp_noise()
    
    def get_model(self):
        
        if self.args.model_name == 'lenet':
            from models.lenet import LeNet5
            phi= LeNet5()

        if self.args.model_name == 'slab':
            from models.slab import SlabClf
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                fc_layer=0
            else:
                fc_layer= self.args.fc_layer
            phi= SlabClf(self.args.slab_data_dim, self.args.out_classes, fc_layer)
            
        if self.args.model_name == 'fc':
            from models.fc import FC
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                fc_layer=0
            else:
                fc_layer= self.args.fc_layer
            phi= FC(self.args.out_classes, fc_layer)
            
        if self.args.model_name == 'domain_bed_mnist':
            from models.domain_bed_mnist import DomainBed
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                fc_layer=0
            else:
                fc_layer= self.args.fc_layer                        
            phi= DomainBed(self.args.img_c, fc_layer)
            
        if self.args.model_name == 'alexnet':
            from models.alexnet import alexnet
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                fc_layer=0
            else:
                fc_layer= self.args.fc_layer            
            phi= alexnet(self.args.model_name, self.args.out_classes, fc_layer, 
                            self.args.img_c, self.args.pre_trained, self.args.os_env)
            
        if 'resnet' in self.args.model_name:
            from models.resnet import get_resnet
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                fc_layer=0
            else:
                fc_layer= self.args.fc_layer
            phi= get_resnet(self.args.model_name, self.args.out_classes, fc_layer, 
                            self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env)
            
        if 'densenet' in self.args.model_name:
            from models.densenet import get_densenet
            if self.args.method_name in ['csd', 'matchdg_ctr']:
                fc_layer=0
            else:
                fc_layer= self.args.fc_layer
            phi= get_densenet(self.args.model_name, self.args.out_classes, fc_layer, 
                            self.args.img_c, self.args.pre_trained, self.args.os_env)
            
        print('Model Architecture: ', self.args.model_name)
        phi=phi.to(self.cuda)
        return phi
    
    def save_model(self):
        # Store the weights of the model
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model_' + self.post_string + '.pth')
        
        # Store the validation and test accuracy over the training epochs
        np.save( self.base_res_dir + '/Val_Acc_' + self.post_string + '.npy', np.array(self.val_acc) )
        np.save( self.base_res_dir + '/Test_Acc_' + self.post_string + '.npy', np.array(self.final_acc))
    
    def get_opt(self):
        if self.args.opt == 'sgd':
            opt= optim.SGD([
                         {'params': filter(lambda p: p.requires_grad, self.phi.parameters()) }, 
                ], lr= self.args.lr, weight_decay= self.args.weight_decay, momentum= 0.9,  nesterov=True )        
        elif self.args.opt == 'adam':
            opt= optim.Adam([
                        {'params': filter(lambda p: p.requires_grad, self.phi.parameters())},
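                ], lr= self.args.lr)
        else:
            # Fail early with a clear message instead of an UnboundLocalError on return
            raise ValueError('Unsupported optimizer: ' + str(self.args.opt))
        
        return opt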

    
    def get_match_function(self, inferred_match, phi):
        
        data_matched, domain_data, _= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, phi, self.args.match_case, self.args.perfect_match, inferred_match )
                
#         #Start initially with randomly defined batch; else find the local approximate batch
#         if epoch > 0:                    
#             inferred_match=1
#             if self.args.match_flag:
#                 data_matched, domain_data, _= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
#             else:
#                 temp_1, temp_2, _= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )                
#         else:
#             inferred_match=0
#             data_matched, domain_data, _= get_matched_pairs( self.args, self.cuda, self.train_dataset, self.domain_size, self.total_domains, self.training_list_size, self.phi, self.args.match_case, perfect_match, inferred_match )
        
        
        # Randomly Shuffle the list of matched data indices and divide as per batch sizes
        random.shuffle(data_matched)
        data_matched= list(chunked(data_matched, self.args.batch_size))
        
        return data_matched, domain_data

    def get_match_function_batch(self, batch_idx):
        curr_data_matched= self.data_matched[batch_idx]
        curr_batch_size= len(curr_data_matched)

        data_match_tensor=[]
        label_match_tensor=[]
        for idx in range(curr_batch_size):
            data_temp=[]
            label_temp= []
            for d_i in range(len(curr_data_matched[idx])):
                key= random.choice( curr_data_matched[idx][d_i] )
                data_temp.append(self.domain_data[d_i]['data'][key])
                label_temp.append(self.domain_data[d_i]['label'][key])
            
            data_match_tensor.append( torch.stack(data_temp) )
            label_match_tensor.append( torch.stack(label_temp) )                    

        data_match_tensor= torch.stack( data_match_tensor ) 
        label_match_tensor= torch.stack( label_match_tensor )
#         print('Shape: ', data_match_tensor.shape, label_match_tensor.shape)
        
        return data_match_tensor, label_match_tensor, curr_batch_size
    
    def get_test_accuracy(self, case):
        import opacus
        
        if self.args.dp_noise:
            opacus.autograd_grad_sample.disable_hooks()
            #self.privacy_engine.module.disable_hooks()
        
        #Test Env Code
        test_acc= 0.0
        test_size=0
        if case == 'val':
            dataset= self.val_dataset
        elif case == 'test':
            dataset= self.test_dataset

        for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(dataset):
            with torch.no_grad():
                
                self.opt.zero_grad()
#                 print(x_e.shape)
#                 print(torch.cuda.memory_allocated())                
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)

                #Forward Pass
                out= self.phi(x_e)                
                
                test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item()
                test_size+= y_e.shape[0]
                
                # To avoid CUDA memory issues
                if self.args.dp_noise:
                    self.opt.zero_grad()

        print(' Accuracy: ', case, 100*test_acc/test_size )         
                
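        # Re-enable the hooks only if they were disabled at the start of this method
        if self.args.dp_noise:
            #self.privacy_engine.module.enable_hooks()
            opacus.autograd_grad_sample.enable_hooks()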
        return 100*test_acc/test_size
    
    def get_dp_noise(self):
        
        print('Privacy Engine')
        print('Total Domains: ', self.total_domains, ' Domain Size ', self.domain_size, ' Batch Size ', self.args.batch_size)
        
        from opacus.dp_model_inspector import DPModelInspector
        from opacus.utils import module_modification
        
        inspector = DPModelInspector()        
#         self.phi = module_modification.convert_batchnorm_modules(self.phi) 
        inspector.validate(self.phi)
        
        MAX_GRAD_NORM = 5.0
        DELTA = 1.0/(self.total_domains*self.domain_size)
        BATCH_SIZE = self.args.batch_size * self.total_domains
        SAMPLE_RATE = BATCH_SIZE /(self.total_domains*self.domain_size)
        DEFAULT_ALPHAS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))
                
        NOISE_MULTIPLIER = get_noise_multiplier(self.args.dp_epsilon, DELTA, SAMPLE_RATE, self.args.epochs, DEFAULT_ALPHAS)
        
        print("Target Epsilon: ", self.args.dp_epsilon)
        print(f"Using sigma={NOISE_MULTIPLIER} and C={MAX_GRAD_NORM}")
        
                
        from opacus import PrivacyEngine        
        privacy_engine = PrivacyEngine(
            self.phi,
            batch_size= BATCH_SIZE,
            sample_size= self.total_domains*self.domain_size,
            noise_multiplier=NOISE_MULTIPLIER,
            max_grad_norm=MAX_GRAD_NORM,
        )
        
        if self.args.dp_attach_opt:
            print('Standard DP Training with finite epsilon')
            privacy_engine.attach(self.opt)
        else:
            print('DP Training with infinite epsilon')
            
        return privacy_engine
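

# --- Illustrative sketch (not part of the original module) -------------------
# BaseAlgo.get_dp_noise() above derives its accounting inputs from three
# numbers: the per-domain batch size, the number of domains, and the
# per-domain dataset size. The hypothetical helper below restates that
# arithmetic on plain numbers so it can be checked without opacus. The helper
# name and the example sizes are assumptions made for illustration only.
def _dp_accounting_sketch(batch_size, total_domains, domain_size):
    n_samples = total_domains * domain_size        # size of the full training set
    effective_batch = batch_size * total_domains   # one batch draws from every domain
    return {
        'delta': 1.0 / n_samples,
        'batch_size': effective_batch,
        'sample_rate': effective_batch / n_samples,
        # Same RDP orders as DEFAULT_ALPHAS in get_dp_noise()
        'alphas': [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
    }

# Example with made-up sizes: _dp_accounting_sketch(16, 5, 2000) uses
# n_samples = 10000, so delta = 1/10000, the effective batch size is 80,
# and the sample rate is 80/10000.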

================================================
FILE: algorithms/csd.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity


class CSD(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda)         
        
        # H_dim as per the feature layer dimension of ResNet-18
        H_dim= self.args.rep_dim
        self.K, m, self.num_classes = 1, H_dim, self.args.out_classes 
        num_domains = self.total_domains

        # Allocate the CSD parameters on the device chosen for this run (self.cuda) rather than a hard-coded 'cuda:0'
        self.sms = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, m, self.num_classes], dtype=torch.float, device=self.cuda), requires_grad=True)
        self.sm_biases = torch.nn.Parameter(torch.normal(0, 1e-1, size=[self.K+1, self.num_classes], dtype=torch.float, device=self.cuda), requires_grad=True)
    
        self.embs = torch.nn.Parameter(torch.normal(mean=0., std=1e-1, size=[num_domains, self.K], dtype=torch.float, device=self.cuda), requires_grad=True)
        self.cs_wt = torch.nn.Parameter(torch.normal(mean=.1, std=1e-4, size=[], dtype=torch.float, device=self.cuda), requires_grad=True)

        self.opt= optim.SGD([
                         {'params': filter(lambda p: p.requires_grad, self.phi.parameters()) },
                         {'params': self.sms },
                         {'params': self.sm_biases },
                         {'params': self.embs },
                         {'params': self.cs_wt }
                ], lr= self.args.lr, weight_decay= 5e-4, momentum= 0.9,  nesterov=True )          
        
        self.criterion = torch.nn.CrossEntropyLoss()
        
    def forward(self, x, y, di, eval_case=0):
        x = self.phi(x)        
        w_c, b_c = self.sms[0, :, :], self.sm_biases[0, :]
        logits_common = torch.matmul(x, w_c) + b_c
       
        if eval_case:
            return logits_common
 
        domains= di
        c_wts = torch.matmul(domains, self.embs)
    
        # Prepend the shared weight cs_wt to the per-domain weights (B x K), giving B x (K+1)
        batch_size = x.shape[0]
        c_wts = torch.cat((torch.ones((batch_size, 1), dtype=torch.float).to(self.cuda)*self.cs_wt, c_wts), 1)
        c_wts = torch.tanh(c_wts).to(self.cuda)
        w_d, b_d = torch.einsum("bk,krl->brl", c_wts, self.sms), torch.einsum("bk,kl->bl", c_wts, self.sm_biases)
        logits_specialized = torch.einsum("brl,br->bl", w_d, x) + b_d

        specific_loss = self.criterion(logits_specialized, y)
        class_loss = self.criterion(logits_common, y)

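        # Orthogonality penalty: for each class, penalize the off-diagonal entries of the
        # Gram matrix between the common and the specialized softmax components
        sms = self.sms
        diag_tensor = torch.stack([torch.eye(self.K+1).to(self.cuda) for _ in range(self.num_classes)], dim=0)
        cps = torch.stack([torch.matmul(sms[:, :, c], torch.transpose(sms[:, :, c], 0, 1)) for c in range(self.num_classes)], dim=0)
        orth_loss = torch.mean((1-diag_tensor)*(cps - diag_tensor)**2)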

        loss = class_loss + specific_loss + orth_loss 
        return loss, class_loss, logits_common
    
    def epoch_callback(self, nepoch, final=False):
        if nepoch % 100 == 0:
            print (self.embs, torch.norm(self.sms[0]), torch.norm(self.sms[1]))
                          
    def train(self):
        
        self.max_epoch=-1
        self.max_val_acc=0.0
        for epoch in range(self.args.epochs):   
            
            penalty_erm=0
            penalty_csd=0
            train_acc= 0.0
            train_size=0
    
            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
        #         print('Batch Idx: ', batch_idx)

                self.opt.zero_grad()
                loss_e= torch.tensor(0.0).to(self.cuda)
                
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                
                #Forward Pass
                csd_loss, erm_loss, out= self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=0)
                loss_e+= csd_loss
                penalty_csd += float(loss_e)
                penalty_erm += float(erm_loss)

                #Backprop
                loss_e.backward(retain_graph=False)
                self.opt.step()
                
                del csd_loss
                del loss_e
                torch.cuda.empty_cache()
        
                train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
                train_size+= y_e.shape[0]
                
   
            print('Train Loss Basic : ',  penalty_erm, penalty_csd - penalty_erm )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)
            
            #Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )
            
            #Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )
            
            #Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )

            #Save the model if this is the best epoch so far as per validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc=self.val_acc[-1]
                self.max_epoch= epoch
                self.save_model()                
                
            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])



    def get_test_accuracy(self, case):
        
        #Test Env Code
        test_acc= 0.0
        test_size=0
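        if case == 'val':
            dataset= self.val_dataset
        elif case == 'test':
            dataset= self.test_dataset
        else:
            # Fail early instead of hitting an undefined `dataset` below
            raise ValueError("case must be 'val' or 'test', got: " + str(case))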

        for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(dataset):
            with torch.no_grad():
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)

                #Forward Pass
                out= self.forward(x_e, y_e, d_e.to(self.cuda), eval_case=1)
                
                test_acc+= torch.sum( torch.argmax(out, dim=1) == y_e ).item()
                test_size+= y_e.shape[0]
                
        print(' Accuracy: ', case,  100*test_acc/test_size )         
        
        return 100*test_acc/test_size
    
    def save_model(self):
        # Store the weights of the model
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model_' + self.post_string + '.pth')
        # Store the parameters
        torch.save(self.sms, self.base_res_dir + '/Sms_' + self.post_string + ".pt")
        torch.save(self.sm_biases, self.base_res_dir + '/SmBiases_' + self.post_string + ".pt")

================================================
FILE: algorithms/dann.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json
import time

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils
import torch.autograd as autograd

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity

class DANN(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 
        
        self.conditional = bool(self.args.conditional)
        self.class_balance = False        
        
        self.featurizer = self.phi.feat_net
        self.classifier = self.phi.fc
        self.discriminator = self.phi.disc
        self.class_embeddings = self.phi.embedding
        
        self.grad_penalty= self.args.grad_penalty
        self.lambda_= self.args.penalty_ws
        self.d_steps_per_g_step= self.args.d_steps_per_g_step
        self.initial_lr= 0.01
        
        # Optimizers
        self.disc_opt = torch.optim.SGD(
            (list(self.discriminator.parameters()) + 
                list(self.class_embeddings.parameters())),
            lr=self.initial_lr,
            weight_decay=5e-4)

        self.gen_opt = torch.optim.SGD(
            (list(self.featurizer.parameters()) + 
                list(self.classifier.parameters())),
            lr=self.initial_lr,
            weight_decay=5e-4)     
        
    def train(self):
        
        self.max_epoch=-1
        self.max_val_acc=0.0
        for epoch in range(self.args.epochs):   
                    
            penalty_erm=0
            penalty_dann=0
            train_acc= 0.0
            train_size=0
                    
            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
        #         print('Batch Idx: ', batch_idx)

                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                d_e= torch.argmax(d_e, dim=1).to(self.cuda)
        
                all_x = x_e
                all_y = y_e
                all_z = self.featurizer(all_x)
                if self.conditional:
                    disc_input = all_z + self.class_embeddings(all_y)
                else:
                    disc_input = all_z
                disc_out = self.discriminator(disc_input)
                disc_labels = d_e        
            
                if self.class_balance:
                    y_counts = F.one_hot(all_y).sum(dim=0)
                    weights = 1. / (y_counts[all_y] * y_counts.shape[0]).float()
                    disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
                    disc_loss = (weights * disc_loss).sum()
                else:
                    disc_loss = F.cross_entropy(disc_out, disc_labels)

                disc_softmax = F.softmax(disc_out, dim=1)
                input_grad = autograd.grad(disc_softmax[:, disc_labels].sum(),
                    [disc_input], create_graph=True)[0]
                grad_penalty = (input_grad**2).sum(dim=1).mean(dim=0)
                                
                #Disc Loss
                disc_loss += self.grad_penalty * grad_penalty

                #Gen Loss
                all_preds = self.classifier(all_z)
                classifier_loss = F.cross_entropy(all_preds, all_y)
                gen_loss = (classifier_loss +
                            (self.lambda_ * -disc_loss))

                penalty_erm += float(classifier_loss)
                penalty_dann += float(disc_loss)
                
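                # The alternation is keyed on the epoch index, not the batch index:
                # d_steps_per_g discriminator epochs are followed by one generator epoch
                d_steps_per_g = self.d_steps_per_g_step
                if (epoch % (1+d_steps_per_g) < d_steps_per_g):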
                    self.disc_opt.zero_grad()
                    disc_loss.backward()
                    self.disc_opt.step()
                else:
                    self.disc_opt.zero_grad()
                    self.gen_opt.zero_grad()
                    gen_loss.backward()
                    self.gen_opt.step()
                
                del classifier_loss
                del gen_loss 
                del disc_loss
                torch.cuda.empty_cache()
                
                #Forward Pass (accuracy bookkeeping only, so no autograd graph is needed)
                with torch.no_grad():
                    features = self.featurizer(x_e)
                    out = self.classifier(features)
                train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
                train_size+= y_e.shape[0]
                        
   
            print('Train Loss Basic : ',  penalty_erm, penalty_dann )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)
            
            #Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )
            
            #Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )
            
            #Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )
            
            #Save the model if this is the best epoch so far as per validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc=self.val_acc[-1]
                self.max_epoch= epoch
                self.save_model()
                                
            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])

================================================
FILE: algorithms/erm.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity

class Erm(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 
              
    def train(self):
        
        self.max_epoch=-1
        self.max_val_acc=0.0        
        for epoch in range(self.args.epochs):   
            
            if epoch ==0 or (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
                data_match_tensor, label_match_tensor= self.get_match_function(epoch)
            
            penalty_erm=0
            penalty_ws=0
            train_acc= 0.0
            train_size=0
    
            perm = torch.randperm(data_match_tensor.size(0))            
            data_match_tensor_split= torch.split(data_match_tensor[perm], self.args.batch_size, dim=0)
            label_match_tensor_split= torch.split(label_match_tensor[perm], self.args.batch_size, dim=0)
            print('Split Matched Data: ', len(data_match_tensor_split), data_match_tensor_split[0].shape, len(label_match_tensor_split))
    
            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e) in enumerate(self.train_dataset):
        #         print('Batch Idx: ', batch_idx)

                self.opt.zero_grad()
                loss_e= torch.tensor(0.0).to(self.cuda)
                
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                d_e= torch.argmax(d_e, dim=1).numpy()
                
                #Forward Pass
                out= self.phi(x_e)
                erm_loss= F.cross_entropy(out, y_e.long()).to(self.cuda)
                loss_e+= erm_loss
                penalty_erm += float(loss_e)

                #Backprop
                loss_e.backward(retain_graph=False)
                self.opt.step()
                
                del erm_loss
                del loss_e
                torch.cuda.empty_cache()
        
                train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
                train_size+= y_e.shape[0]
                
   
            print('Train Loss Basic : ',  penalty_erm )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)
            
            #Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )
            
            #Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )
            
            #Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )

            #Save the model if this is the best epoch so far as per validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc=self.val_acc[-1]
                self.max_epoch= epoch
                self.save_model()
            
        # Save the model's weights post training
        self.save_model()

================================================
FILE: algorithms/erm_match.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json
import time

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity

class ErmMatch(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 

                
    def train(self):
        
        self.max_epoch= -1
        self.max_val_acc= 0.0
        for epoch in range(self.args.epochs):   
            
            if epoch ==0:
                inferred_match= 0
                self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)                
            elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                inferred_match= 1
                self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)
        
            penalty_erm=0
            penalty_ws=0
            train_acc= 0.0
            train_size=0
            
            total_grad_norm= []
            
            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
                
#                 self.opt.zero_grad()
                loss_e= torch.tensor(0.0).to(self.cuda)
                
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                d_e= torch.argmax(d_e, dim=1).numpy()
                

                wasserstein_loss=torch.tensor(0.0).to(self.cuda)
                erm_loss= torch.tensor(0.0).to(self.cuda) 
                if epoch > self.args.penalty_s:
                    # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
                    total_batch_size= len(self.data_matched)
                    if batch_idx >= total_batch_size:
                        break
                    
                    # Sample batch from matched data points
                    data_match_tensor, label_match_tensor, curr_batch_size= self.get_match_function_batch(batch_idx)
                        
                    data_match= data_match_tensor.to(self.cuda)
                    data_match= data_match.flatten(start_dim=0, end_dim=1)
                    feat_match= self.phi( data_match )
#                     print(feat_match.shape)
            
                    label_match= label_match_tensor.to(self.cuda)
                    label_match= torch.squeeze( label_match.flatten(start_dim=0, end_dim=1) )
                
                    erm_loss+= F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                    penalty_erm+= float(erm_loss)           
                    loss_e += erm_loss
                    
                    train_acc+= torch.sum(torch.argmax(feat_match, dim=1) == label_match ).item()
                    train_size+= label_match.shape[0]
                        
                    # Creating tensor of shape ( domain size, total domains, feat size )
                    feat_match= torch.stack(torch.split(feat_match, len(self.train_domains)))                    
            
                    #Positive Match Loss
                    pos_match_counter=0
                    for d_i in range(feat_match.shape[1]):
        #                 if d_i != base_domain_idx:
        #                     continue
                        for d_j in range(feat_match.shape[1]):
                            if d_j > d_i:                        
                                if self.args.pos_metric == 'l2':
                                    wasserstein_loss+= torch.sum( torch.sum( (feat_match[:, d_i, :] - feat_match[:, d_j, :])**2, dim=1 ) ) 
                                elif self.args.pos_metric == 'l1':
                                    wasserstein_loss+= torch.sum( torch.sum( torch.abs(feat_match[:, d_i, :] - feat_match[:, d_j, :]), dim=1 ) )        
                                elif self.args.pos_metric == 'cos':
                                    wasserstein_loss+= torch.sum( cosine_similarity( feat_match[:, d_i, :], feat_match[:, d_j, :] ) )

                                pos_match_counter += feat_match.shape[0]

                    wasserstein_loss = wasserstein_loss / pos_match_counter
                    penalty_ws+= float(wasserstein_loss)                            
                
                    if epoch >= self.args.match_interrupt and self.args.match_flag==1:
                        loss_e += ( self.args.penalty_ws*( epoch - self.args.penalty_s - self.args.match_interrupt )/(self.args.epochs - self.args.penalty_s - self.args.match_interrupt) )*wasserstein_loss
                    else:
                        loss_e += ( self.args.penalty_ws*( epoch- self.args.penalty_s )/(self.args.epochs - self.args.penalty_s) )*wasserstein_loss

                        
                loss_e.backward(retain_graph=False)
                
                if self.args.dp_noise and self.args.dp_attach_opt:
                    if batch_idx % 10 == 9:
                        self.opt.step()
                        self.opt.zero_grad()
                    else:
                        self.opt.virtual_step()                        
                else:                    
                    self.opt.step()
                    self.opt.zero_grad()
                    
                #Gradient Norm Computation
#                 batch_grad_norm=0.0
#                 for p in self.phi.parameters():
#                     param_norm = p.grad.detach().data.norm(2)
#                     batch_grad_norm += param_norm.item() ** 2
#                 batch_grad_norm = batch_grad_norm ** 0.5
#                 total_grad_norm.append( batch_grad_norm )
    
#                 del out
                del erm_loss
                del wasserstein_loss 
                del loss_e
                torch.cuda.empty_cache()
                        

#             print('Gradient Norm: ', total_grad_norm)
                    
            print('Train Loss Basic : ',  penalty_erm, penalty_ws )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)
            
            #Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )
            
            #Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )
            
            #Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )
            
            #Save the model if this is the best epoch so far as per validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc=self.val_acc[-1]
                self.max_epoch= epoch
                self.save_model()
                
            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])
            
            if epoch > 0 and self.args.model_name in ['domain_bed_mnist', 'lenet']:
                if self.args.model_name == 'lenet':
                    lr_schedule_step= 25
                elif self.args.model_name == 'domain_bed_mnist':
                    lr_schedule_step= 10
            
                if epoch % lr_schedule_step==0 :
                    # Halve the learning rate every lr_schedule_step epochs
                    lr=self.args.lr/(2**(epoch // lr_schedule_step))
                    print('Learning Rate Scheduling; New LR: ', lr)                
                    self.opt= optim.SGD([
                             {'params': filter(lambda p: p.requires_grad, self.phi.parameters()) }, 
                    ], lr= lr, weight_decay= self.args.weight_decay, momentum= 0.9,  nesterov=True )     
                

================================================
FILE: algorithms/hybrid.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json
import os

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity, get_dataloader
from utils.match_function import get_matched_pairs

class Hybrid(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 
        
        self.ctr_save_post_string= str(self.args.match_case) + '_' + str(self.args.match_interrupt) + '_' + str(self.args.match_flag) + '_' + str(self.run) + '_' + self.args.model_name
        self.ctr_load_post_string= str(self.args.ctr_match_case) + '_' + str(self.args.ctr_match_interrupt) + '_' + str(self.args.ctr_match_flag) + '_' + str(self.run) + '_' + self.args.ctr_model_name
                    
    def save_model_erm_phase(self, run):
        
        # Create the results directory if it does not exist yet
        os.makedirs(self.base_res_dir + '/' + self.ctr_load_post_string, exist_ok=True)
                
        # Store the weights of the model
        torch.save(self.phi.state_dict(), self.base_res_dir + '/' + self.ctr_load_post_string + '/Model_' + self.post_string + '_' + str(run) + '.pth')
    
    def init_erm_phase(self):
            
            if self.args.ctr_model_name == 'lenet':
                from models.lenet import LeNet5
                ctr_phi= LeNet5().to(self.cuda)
                
            if self.args.model_name == 'slab':
                from models.slab import SlabClf
                fc_layer=0
                ctr_phi= SlabClf(self.args.slab_data_dim, self.args.out_classes, fc_layer).to(self.cuda)
                
            if self.args.ctr_model_name == 'alexnet':
                from models.alexnet import alexnet
                ctr_phi= alexnet(self.args.out_classes, self.args.pre_trained, 'matchdg_ctr').to(self.cuda)                
            if self.args.ctr_model_name == 'fc':
                from models.fc import FC
                fc_layer=0
                ctr_phi= FC(self.args.out_classes, fc_layer).to(self.cuda)              
            if 'resnet' in self.args.ctr_model_name:
                from models.resnet import get_resnet
                fc_layer=0                
                ctr_phi= get_resnet(self.args.ctr_model_name, self.args.out_classes, fc_layer, self.args.img_c, self.args.pre_trained, self.args.os_env).to(self.cuda)
            if 'densenet' in self.args.ctr_model_name:
                from models.densenet import get_densenet
                fc_layer=0
                ctr_phi= get_densenet(self.args.ctr_model_name, self.args.out_classes, fc_layer, 
                                self.args.img_c, self.args.pre_trained, self.args.os_env).to(self.cuda)

                
            # Load MatchDG CTR phase model from the saved weights
            if self.args.os_env:
                base_res_dir=os.getenv('PT_DATA_DIR') + '/' + self.args.dataset_name + '/' + 'matchdg_ctr' + '/' + self.args.ctr_match_layer + '/' + 'train_' + str(self.args.train_domains)             
            else:
                base_res_dir="results/" + self.args.dataset_name + '/' + 'matchdg_ctr' + '/' + self.args.ctr_match_layer + '/' + 'train_' + str(self.args.train_domains)             
                                
            #TODO: Handle slab noise case in helper functions
            if self.args.dataset_name == 'slab':
                base_res_dir= base_res_dir + '/slab_noise_'  + str(self.args.slab_noise)
                
            save_path= base_res_dir + '/Model_' + self.ctr_load_post_string + '.pth'
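            # map_location keeps loading robust when the checkpoint was saved on a different device
            ctr_phi.load_state_dict( torch.load(save_path, map_location=self.cuda) )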
            ctr_phi.eval()

            #Inferred Match Case
            if self.args.match_case == -1:
                inferred_match=1
            # x% percentage match initial strategy 
            else:
                inferred_match=0                
                
            data_matched, domain_data= self.get_match_function(inferred_match, ctr_phi)

            return data_matched, domain_data
            
            
    def train(self):
        
        for run_erm in range(self.args.n_runs_matchdg_erm):  
            
            self.max_epoch=-1
            self.max_val_acc=0.0
            for epoch in range(self.args.epochs):    
                
                if epoch ==0:
                    self.data_matched, self.domain_data= self.init_erm_phase()
                elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                    inferred_match= 1
                    # Store the refreshed matches under the same attributes that are set at epoch 0
                    self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)

                penalty_erm=0
                penalty_erm_extra=0
                penalty_ws=0
                penalty_aug=0
                train_acc= 0.0
                train_size=0

                #Batch iteration over single epoch
                for batch_idx, (x_e, x_org_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
            #         print('Batch Idx: ', batch_idx)

                    self.opt.zero_grad()
                    loss_e= torch.tensor(0.0).to(self.cuda)

                    x_e= x_e.to(self.cuda)
                    x_org_e= x_org_e.to(self.cuda)
                    y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                    d_e= torch.argmax(d_e, dim=1).numpy()

                    #Forward Pass
                    out= self.phi(x_e)
                    erm_loss_extra= F.cross_entropy(out, y_e.long()).to(self.cuda)
                    penalty_erm_extra += float(erm_loss_extra)
                    
                    #Perfect Match on Augmentations
                    out_org= self.phi(x_org_e)
#                     diff_indices= out != out_org
#                     out= out[diff_indices]
#                     out_org= out_org[diff_indices]
                    augmentation_loss=torch.tensor(0.0).to(self.cuda)
                    if self.args.pos_metric == 'l2':
                        augmentation_loss+= torch.sum( torch.sum( (out -out_org)**2, dim=1 ) ) 
                    elif self.args.pos_metric == 'l1':
                        augmentation_loss+= torch.sum( torch.sum( torch.abs(out -out_org), dim=1 ) )        
                    elif self.args.pos_metric == 'cos':
                        augmentation_loss+= torch.sum( cosine_similarity( out, out_org ) )

                    augmentation_loss = augmentation_loss / out.shape[0]
#                     print('Augmented Images Fraction: ', out.shape, self.args.batch_size, augmentation_loss)
                    penalty_aug+= float(augmentation_loss)                            

                    wasserstein_loss=torch.tensor(0.0).to(self.cuda)
                    erm_loss= torch.tensor(0.0).to(self.cuda) 
                    if epoch > self.args.penalty_s:                    
                        # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
                        total_batch_size= len(self.data_matched)
                        if batch_idx >= total_batch_size:
                            break
                            
                        # Sample batch from matched data points
                        data_match_tensor, label_match_tensor, curr_batch_size= self.get_match_function_batch(batch_idx)                               
                        data_match= data_match_tensor.to(self.cuda)
                        data_match= data_match.flatten(start_dim=0, end_dim=1)
                        feat_match= self.phi( data_match )

                        label_match= label_match_tensor.to(self.cuda)
                        label_match= torch.squeeze( label_match.flatten(start_dim=0, end_dim=1) )

                        erm_loss+= F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                        penalty_erm+= float(erm_loss) 
                        
                        train_acc+= torch.sum(torch.argmax(feat_match, dim=1) == label_match ).item()
                        train_size+= label_match.shape[0]                        

                        # Creating tensor of shape ( domain size, total domains, feat size )
                        feat_match= torch.stack(torch.split(feat_match, len(self.train_domains)))                    
                        label_match= torch.stack(torch.split(label_match, len(self.train_domains)))

                        #Positive Match Loss
                        pos_match_counter=0
                        for d_i in range(feat_match.shape[1]):
            #                 if d_i != base_domain_idx:
            #                     continue
                            for d_j in range(feat_match.shape[1]):
                                if d_j > d_i:                        
                                    if self.args.pos_metric == 'l2':
                                        wasserstein_loss+= torch.sum( torch.sum( (feat_match[:, d_i, :] - feat_match[:, d_j, :])**2, dim=1 ) ) 
                                    elif self.args.pos_metric == 'l1':
                                        wasserstein_loss+= torch.sum( torch.sum( torch.abs(feat_match[:, d_i, :] - feat_match[:, d_j, :]), dim=1 ) )        
                                    elif self.args.pos_metric == 'cos':
                                        wasserstein_loss+= torch.sum( cosine_similarity( feat_match[:, d_i, :], feat_match[:, d_j, :] ) )

                                    pos_match_counter += feat_match.shape[0]

                        wasserstein_loss = wasserstein_loss / pos_match_counter
                        penalty_ws+= float(wasserstein_loss)                            


                        loss_e += ( self.args.penalty_ws*( epoch- self.args.penalty_s )/(self.args.epochs - self.args.penalty_s) )*wasserstein_loss
                        loss_e += self.args.penalty_aug*augmentation_loss
                        loss_e += erm_loss
                        loss_e += erm_loss_extra
                        

                    loss_e.backward(retain_graph=False)
                    self.opt.step()

                    del erm_loss_extra
                    del erm_loss
                    del wasserstein_loss 
                    del loss_e
                    torch.cuda.empty_cache()

                print('Train Loss Basic : ', penalty_erm_extra, penalty_aug, penalty_erm, penalty_ws )
                print('Train Acc Env : ', 100*train_acc/train_size )
                print('Done Training for epoch: ', epoch)    
                
                #Val Dataset Accuracy
                self.val_acc.append( self.get_test_accuracy('val') )

                #Test Dataset Accuracy
                self.final_acc.append( self.get_test_accuracy('test') )                    
                
                                
                #Save the model if the current epoch is the best as per validation accuracy
                if self.val_acc[-1] > self.max_val_acc:
                    self.max_val_acc= self.val_acc[-1]
                    self.max_epoch= epoch
                    self.save_model_erm_phase(run_erm)
                    
                        
#                     if epoch > 0:
#                         #GPU
#                         cuda= torch.device("cuda:" + str(self.args.cuda_device))
#                         if cuda:
#                             kwargs = {'num_workers': 1, 'pin_memory': False} 
#                         else:
#                             kwargs= {}
                        
#                         train_dataset_temp= get_dataloader( self.args, self.run, self.args.train_domains, 'train', 1, kwargs )
#                         val_dataset_temp= get_dataloader( self.args, self.run, self.args.train_domains, 'val', 1, kwargs )
#                         test_dataset_temp= get_dataloader( self.args, self.run, self.args.test_domains, 'test', 1, kwargs )

#                         from evaluation.match_eval import MatchEval
#                         test_method= MatchEval(
#                                            self.args, train_dataset_temp, val_dataset_temp,
#                                            test_dataset_temp, self.base_res_dir, 
#                                            self.run, self.cuda
#                                           )   
#                         #Compute test metrics: Mean Rank
#                         test_method.phi= self.phi
#                         test_method.get_metric_eval()
#                         print('Match Function: ', test_method.metric_score)


#                     from evaluation.privacy_attack import PrivacyAttack
#                     test_method= PrivacyAttack(
#                                        self.args, train_dataset_temp, val_dataset_temp,
#                                        test_dataset_temp, self.base_res_dir, 
#                                        self.run, self.cuda
#                                          )        
#                     #Compute test metrics: Mean Rank
#                     test_method.phi= self.phi
#                     test_method.get_metric_eval()
#                     print('MIA: ', test_method.metric_score)

#                     from evaluation.privacy_entropy import PrivacyEntropy
#                     test_method= PrivacyEntropy(
#                                        self.args, train_dataset_temp, val_dataset_temp,
#                                        test_dataset_temp, self.base_res_dir, 
#                                        self.run, self.cuda
#                                          )                        
#                     #Compute test metrics: Mean Rank
#                     test_method.phi= self.phi
#                     test_method.get_metric_eval()
#                     print('Entropy: ', test_method.metric_score)

                
                print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])

                
                

================================================
FILE: algorithms/irm.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity, compute_irm_penalty

class Irm(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 
              
    def train(self):
        
        self.max_epoch=-1
        self.max_val_acc=0.0
        for epoch in range(self.args.epochs):   
            
            if epoch ==0:
                inferred_match= 0                
                self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)
            elif (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
                inferred_match= 1
                self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)
            
            penalty_erm=0
            penalty_irm=0
            train_acc= 0.0
            train_size=0
        
            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
        #         print('Batch Idx: ', batch_idx)

                self.opt.zero_grad()
                loss_e= torch.tensor(0.0).to(self.cuda)
                
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                d_e= torch.argmax(d_e, dim=1).numpy()
                
                #Forward Pass
                out= self.phi(x_e)                
                
                irm_loss=torch.tensor(0.0).to(self.cuda)
                erm_loss= torch.tensor(0.0).to(self.cuda) 
                
                # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split         
                total_batch_size= len(self.data_matched)
                if batch_idx >= total_batch_size:
                    break                    
                
                # Sample batch from matched data points
                data_match_tensor, label_match_tensor, curr_batch_size= self.get_match_function_batch(batch_idx)
                    
                data_match= data_match_tensor.to(self.cuda)
                data_match= data_match.flatten(start_dim=0, end_dim=1)
                feat_match= self.phi( data_match )
            
                label_match= label_match_tensor.to(self.cuda)
                label_match= torch.squeeze( label_match.flatten(start_dim=0, end_dim=1) )
                
                erm_loss+= F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                penalty_erm+= float(erm_loss)                
                loss_e += erm_loss                
                
                train_acc+= torch.sum(torch.argmax(feat_match, dim=1) == label_match ).item()
                train_size+= label_match.shape[0]                
                        
                # Creating tensor of shape ( domain size, total domains, feat size )
                feat_match= torch.stack(torch.split(feat_match, len(self.train_domains)))                    
                label_match= torch.stack(torch.split(label_match, len(self.train_domains)))

                #IRM Penalty
                domain_counter=0
                for d_i in range(feat_match.shape[1]):
                    irm_loss+= compute_irm_penalty( feat_match[:, d_i, :], label_match[:, d_i], self.cuda )
                    domain_counter+=1

                irm_loss = irm_loss/domain_counter
                penalty_irm+= float(irm_loss)                                            
                
                #IRM Penalty to be minimized only after threshold epoch
                if epoch > self.args.penalty_s:
                    loss_e += self.args.penalty_irm*irm_loss
                    if self.args.penalty_irm > 1.0:
                        # Rescale the entire loss to keep gradients in a reasonable range
                        loss_e /= self.args.penalty_irm

                loss_e.backward(retain_graph=False)
                self.opt.step()
                
                del erm_loss
                del irm_loss 
                del loss_e
                torch.cuda.empty_cache()
           
            print('Train Loss Basic : ',  penalty_erm, penalty_irm )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)
            
            #Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )
            
            #Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )
            
            #Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )
            
            #Save the model if the current epoch is the best as per validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc=self.val_acc[-1]
                self.max_epoch= epoch
                self.save_model()
                
            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])



================================================
FILE: algorithms/match_dg.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json
import os

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity
from utils.match_function import get_matched_pairs

class MatchDG(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda, ctr_phase=1):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 
        
        self.ctr_phase= ctr_phase
        self.ctr_save_post_string= str(self.args.match_case) + '_' + str(self.args.match_interrupt) + '_' + str(self.args.match_flag) + '_' + str(self.run) + '_' + self.args.model_name
        self.ctr_load_post_string= str(self.args.ctr_match_case) + '_' + str(self.args.ctr_match_interrupt) + '_' + str(self.args.ctr_match_flag) + '_' + str(self.run) + '_' + self.args.ctr_model_name
        
    def train(self):
        # Initialise and call train functions depending on the method's phase
        if self.ctr_phase:
            self.train_ctr_phase()
        else:
            self.train_erm_phase()
            
    def save_model_ctr_phase(self, epoch):
        # Store the weights of the model
        torch.save(self.phi.state_dict(), self.base_res_dir + '/Model_' + self.ctr_save_post_string + '.pth')

    def save_model_erm_phase(self, run):
        
        if not os.path.exists(self.base_res_dir + '/' + self.ctr_load_post_string):
            os.makedirs(self.base_res_dir + '/' + self.ctr_load_post_string)         
                
        # Store the weights of the model
        torch.save(self.phi.state_dict(), self.base_res_dir + '/' + self.ctr_load_post_string + '/Model_' + self.post_string + '_' + str(run) + '.pth')
    
    def init_erm_phase(self):
            
            if self.args.ctr_model_name == 'lenet':
                from models.lenet import LeNet5
                ctr_phi= LeNet5().to(self.cuda)
                
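            # NOTE: this branch and the 'domain_bed_mnist' branch below check args.model_name,
            # whereas the remaining branches check args.ctr_model_name
            if self.args.model_name == 'slab':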
                from models.slab import SlabClf
                fc_layer=0
                ctr_phi= SlabClf(self.args.slab_data_dim, self.args.out_classes, fc_layer).to(self.cuda)
                                
            if self.args.model_name == 'domain_bed_mnist':
                from models.domain_bed_mnist import DomainBed
                fc_layer=0
                ctr_phi= DomainBed(self.args.img_c, fc_layer).to(self.cuda)
                
            if self.args.ctr_model_name == 'alexnet':
                from models.alexnet import alexnet
                ctr_phi= alexnet(self.args.out_classes, self.args.pre_trained, 'matchdg_ctr').to(self.cuda)                
            if self.args.ctr_model_name == 'fc':
                from models.fc import FC
                fc_layer=0
                ctr_phi= FC(self.args.out_classes, fc_layer).to(self.cuda)              
                
            if 'resnet' in self.args.ctr_model_name:
                from models.resnet import get_resnet
                fc_layer=0                
                ctr_phi= get_resnet(self.args.ctr_model_name, self.args.out_classes, fc_layer, self.args.img_c, self.args.pre_trained, self.args.dp_noise, self.args.os_env).to(self.cuda)
                
            if 'densenet' in self.args.ctr_model_name:
                from models.densenet import get_densenet
                fc_layer=0
                ctr_phi= get_densenet(self.args.ctr_model_name, self.args.out_classes, fc_layer, 
                                self.args.img_c, self.args.pre_trained, self.args.os_env).to(self.cuda)

                
            # Load MatchDG CTR phase model from the saved weights
            if self.args.os_env:
                base_res_dir=os.getenv('PT_DATA_DIR') + '/' + self.args.dataset_name + '/' + 'matchdg_ctr' + '/' + self.args.ctr_match_layer + '/' + 'train_' + str(self.args.train_domains)             
            else:
                base_res_dir="results/" + self.args.dataset_name + '/' + 'matchdg_ctr' + '/' + self.args.ctr_match_layer + '/' + 'train_' + str(self.args.train_domains)                

            #TODO: Handle slab noise case in helper functions
            if self.args.dataset_name == 'slab':
                base_res_dir= base_res_dir + '/slab_noise_'  + str(self.args.slab_noise)
                
            save_path= base_res_dir + '/Model_' + self.ctr_load_post_string + '.pth'
            ctr_phi.load_state_dict( torch.load(save_path) )
            ctr_phi.eval()

            #Inferred Match Case
            if self.args.match_case == -1:
                inferred_match=1
            # Otherwise use the initial match strategy (x% match)
            else:
                inferred_match=0                
                
            data_matched, domain_data= self.get_match_function(inferred_match, ctr_phi)

            return data_matched, domain_data
            
    def train_ctr_phase(self):
        
        self.max_epoch= -1
        self.max_val_score= 0.0
        for epoch in range(self.args.epochs):    
            
            if epoch ==0:
                inferred_match= 0                
                self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)
            elif (epoch % self.args.match_interrupt == 0 and self.args.match_flag):
                inferred_match= 1
                self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)
            
            penalty_same_ctr=0
            penalty_diff_ctr=0
            penalty_same_hinge=0
            penalty_diff_hinge=0           
            train_acc= 0.0
            train_size=0
        
            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
        #         print('Batch Idx: ', batch_idx)

                self.opt.zero_grad()
                loss_e= torch.tensor(0.0).to(self.cuda)            

                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                d_e= torch.argmax(d_e, dim=1).numpy()

                same_ctr_loss = torch.tensor(0.0).to(self.cuda)
                diff_ctr_loss = torch.tensor(0.0).to(self.cuda)
                same_hinge_loss = torch.tensor(0.0).to(self.cuda)
                diff_hinge_loss = torch.tensor(0.0).to(self.cuda)
                
                if epoch > self.args.penalty_s:
                    # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
                    total_batch_size= len(self.data_matched)
                    if batch_idx >= total_batch_size:
                        break
                    
                    # Sample batch from matched data points
                    data_match_tensor, label_match_tensor, curr_batch_size= self.get_match_function_batch(batch_idx)
                        
                    data_match= data_match_tensor.to(self.cuda)
                    data_match= data_match.flatten(start_dim=0, end_dim=1)
                    feat_match= self.phi( data_match )
            
                    label_match= label_match_tensor.to(self.cuda)
                    label_match= torch.squeeze( label_match.flatten(start_dim=0, end_dim=1) )                    
                    
                    # Creating tensor of shape ( domain size, total domains, feat size )
                    feat_match= torch.stack(torch.split(feat_match, len(self.train_domains)))                    
                    label_match= torch.stack(torch.split(label_match, len(self.train_domains)))

                    # Contrastive Loss
                    same_neg_counter=1
                    diff_neg_counter=1
                    for y_c in range(self.args.out_classes):

                        pos_indices= label_match[:, 0] == y_c
                        neg_indices= label_match[:, 0] != y_c
                        pos_feat_match= feat_match[pos_indices]
                        neg_feat_match= feat_match[neg_indices]

#                         if pos_feat_match.shape[0] > neg_feat_match.shape[0]:
#                             print('Weird! Positive Matches are more than the negative matches?', pos_feat_match.shape[0], neg_feat_match.shape[0])

                        # Debug output: number of positive / negative matches for label y_c
                        print(pos_feat_match.shape[0], neg_feat_match.shape[0], y_c)

                        # If no instances of label y_c in the current batch then continue
                        if pos_feat_match.shape[0] ==0 or neg_feat_match.shape[0] == 0:
                            continue

                        # Iterating over anchors from different domains
                        for d_i in range(pos_feat_match.shape[1]):
                            if torch.sum( torch.isnan(neg_feat_match) ):
                                print('Non Reshaped X2 is Nan')
                                sys.exit()

                            diff_neg_feat_match= neg_feat_match.view(  neg_feat_match.shape[0]*neg_feat_match.shape[1], neg_feat_match.shape[2] )

                            if torch.sum( torch.isnan(diff_neg_feat_match) ):
                                print('Reshaped X2 is Nan')
                                sys.exit()

                            neg_dist= embedding_dist( pos_feat_match[:, d_i, :], diff_neg_feat_match[:, :], self.args.pos_metric, self.args.tau, xent=True)     
                            if torch.sum(torch.isnan(neg_dist)):
                                print('Neg Dist Nan')
                                sys.exit()

                            # Iterating pos dist for current anchor
                            for d_j in range(pos_feat_match.shape[1]):
                                if d_i != d_j:
                                    pos_dist= 1.0 - embedding_dist( pos_feat_match[:, d_i, :], pos_feat_match[:, d_j, :], self.args.pos_metric )
                                    pos_dist= pos_dist / self.args.tau
                                    if torch.sum(torch.isnan(pos_dist)):
                                        print('Pos Dist Nan')
                                        sys.exit()

                                    if torch.sum( torch.isnan( torch.log( torch.exp(pos_dist) + neg_dist ) ) ):
                                        print('Xent Nan')
                                        sys.exit()

    #                                 print( 'Pos Dist', pos_dist )
    #                                 print( 'Log Dist ', torch.log( torch.exp(pos_dist) + neg_dist ))
                                    diff_hinge_loss+= -1*torch.sum( pos_dist - torch.log( torch.exp(pos_dist) + neg_dist ) )                                 
                                    diff_ctr_loss+= torch.sum(neg_dist)
                                    diff_neg_counter+= pos_dist.shape[0]

                    same_ctr_loss = same_ctr_loss / same_neg_counter
                    diff_ctr_loss = diff_ctr_loss / diff_neg_counter
                    same_hinge_loss = same_hinge_loss / same_neg_counter
                    diff_hinge_loss = diff_hinge_loss / diff_neg_counter      

                    penalty_same_ctr+= float(same_ctr_loss)
                    penalty_diff_ctr+= float(diff_ctr_loss)
                    penalty_same_hinge+= float(same_hinge_loss)
                    penalty_diff_hinge+= float(diff_hinge_loss)
                
                    loss_e += ( ( epoch- self.args.penalty_s )/(self.args.epochs -self.args.penalty_s) )*diff_hinge_loss
                        
                if not loss_e.requires_grad:
                    continue
                    
                loss_e.backward(retain_graph=False)
                self.opt.step()
                
                del same_ctr_loss
                del diff_ctr_loss
                del same_hinge_loss
                del diff_hinge_loss
                torch.cuda.empty_cache()
   
            print('Train Loss Ctr : ', penalty_same_ctr, penalty_diff_ctr, penalty_same_hinge, penalty_diff_hinge)
            print('Done Training for epoch: ', epoch)
                        
            if (epoch+1)%5 == 0:
                                
                from evaluation.match_eval import MatchEval
                test_method= MatchEval(
                                   self.args, self.train_dataset, self.val_dataset,
                                   self.test_dataset, self.base_res_dir, 
                                   self.run, self.cuda
                                  )   
                #Compute test metrics: Mean Rank
                test_method.phi= self.phi
                test_method.get_metric_eval()
                                
                # Save the model's weights if the TopK Perfect Match Score improves
                if test_method.metric_score['TopK Perfect Match Score'] > self.max_val_score:
                    self.max_val_score= test_method.metric_score['TopK Perfect Match Score']
                    self.max_epoch= epoch
                    self.save_model_ctr_phase(epoch)

                print('Current Best Epoch: ', self.max_epoch, ' with TopK Overlap: ', self.max_val_score)                
            
            
    def train_erm_phase(self):
        
        for run_erm in range(self.args.n_runs_matchdg_erm):   
            
            self.max_epoch= -1
            self.max_val_acc= 0.0
            for epoch in range(self.args.epochs):    
                
                if epoch ==0:
                    self.data_matched, self.domain_data= self.init_erm_phase()
                elif epoch % self.args.match_interrupt == 0 and self.args.match_flag:
                    inferred_match= 1
                    self.data_matched, self.domain_data= self.get_match_function(inferred_match, self.phi)
                    
                penalty_erm=0
                penalty_erm_extra=0
                penalty_ws=0
                train_acc= 0.0
                train_size=0

                #Batch iteration over single epoch
                for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
            #         print('Batch Idx: ', batch_idx)

                    self.opt.zero_grad()
                    loss_e= torch.tensor(0.0).to(self.cuda)

                    x_e= x_e.to(self.cuda)
                    y_e= torch.argmax(y_e, dim=1).to(self.cuda)
                    d_e= torch.argmax(d_e, dim=1).numpy()

                    #Forward Pass
                    out= self.phi(x_e)
                    erm_loss_extra= F.cross_entropy(out, y_e.long()).to(self.cuda)
                    penalty_erm_extra += float(erm_loss_extra)

                    wasserstein_loss=torch.tensor(0.0).to(self.cuda)
                    erm_loss= torch.tensor(0.0).to(self.cuda) 
                    if epoch > self.args.penalty_s:
                        # To cover the varying size of the last batch for data_match_tensor_split, label_match_tensor_split
                        total_batch_size= len(self.data_matched)
                        if batch_idx >= total_batch_size:
                            break
                            
                        # Sample batch from matched data points
                        data_match_tensor, label_match_tensor, curr_batch_size= self.get_match_function_batch(batch_idx)                        
                        data_match= data_match_tensor.to(self.cuda)
                        data_match= data_match.flatten(start_dim=0, end_dim=1)
                        feat_match= self.phi( data_match )

                        label_match= label_match_tensor.to(self.cuda)
                        label_match= torch.squeeze( label_match.flatten(start_dim=0, end_dim=1) )

                        erm_loss+= F.cross_entropy(feat_match, label_match.long()).to(self.cuda)
                        penalty_erm+= float(erm_loss) 
                        
                        train_acc+= torch.sum(torch.argmax(feat_match, dim=1) == label_match ).item()
                        train_size+= label_match.shape[0]                        

                        # Creating tensor of shape ( domain size, total domains, feat size )
                        feat_match= torch.stack(torch.split(feat_match, len(self.train_domains)))                    
                        label_match= torch.stack(torch.split(label_match, len(self.train_domains)))

                        #Positive Match Loss
                        pos_match_counter=0
                        for d_i in range(feat_match.shape[1]):
            #                 if d_i != base_domain_idx:
            #                     continue
                            for d_j in range(feat_match.shape[1]):
                                if d_j > d_i:                        
                                    if self.args.pos_metric == 'l2':
                                        wasserstein_loss+= torch.sum( torch.sum( (feat_match[:, d_i, :] - feat_match[:, d_j, :])**2, dim=1 ) ) 
                                    elif self.args.pos_metric == 'l1':
                                        wasserstein_loss+= torch.sum( torch.sum( torch.abs(feat_match[:, d_i, :] - feat_match[:, d_j, :]), dim=1 ) )        
                                    elif self.args.pos_metric == 'cos':
                                        wasserstein_loss+= torch.sum( cosine_similarity( feat_match[:, d_i, :], feat_match[:, d_j, :] ) )

                                    pos_match_counter += feat_match.shape[0]

                        wasserstein_loss = wasserstein_loss / pos_match_counter
                        penalty_ws+= float(wasserstein_loss)                            


                        loss_e += ( self.args.penalty_ws*( epoch- self.args.penalty_s )/(self.args.epochs - self.args.penalty_s) )*wasserstein_loss
                        loss_e += erm_loss
                        loss_e += erm_loss_extra

                    loss_e.backward(retain_graph=False)
                    self.opt.step()

                    del erm_loss_extra
                    del erm_loss
                    del wasserstein_loss 
                    del loss_e
                    torch.cuda.empty_cache()

                print('Train Loss Basic : ', penalty_erm_extra,  penalty_erm, penalty_ws )
                print('Train Acc Env : ', 100*train_acc/train_size )
                print('Done Training for epoch: ', epoch)    
                
                #Train Dataset Accuracy
                self.train_acc.append( 100*train_acc/train_size )
            
                #Val Dataset Accuracy
                self.val_acc.append( self.get_test_accuracy('val') )

                #Test Dataset Accuracy
                self.final_acc.append( self.get_test_accuracy('test') ) 
                
                #Save the model if the current epoch is the best as per validation accuracy
                if self.val_acc[-1] > self.max_val_acc:
                    self.max_val_acc= self.val_acc[-1]
                    self.max_epoch= epoch
                    self.save_model_erm_phase(run_erm)
                
                print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])

                if epoch > 0 and self.args.model_name in ['domain_bed_mnist', 'lenet']:
                    if self.args.model_name == 'lenet':
                        lr_schedule_step= 25
                    elif self.args.model_name == 'domain_bed_mnist':
                        lr_schedule_step= 10

                    if epoch % lr_schedule_step==0 :
                        lr=self.args.lr/(2**(int(epoch/lr_schedule_step)))
                        print('Learning Rate Scheduling; New LR: ', lr)                
                        self.opt= optim.SGD([
                                 {'params': filter(lambda p: p.requires_grad, self.phi.parameters()) }, 
                        ], lr= lr, weight_decay= self.args.weight_decay, momentum= 0.9,  nesterov=True )  

================================================
FILE: algorithms/mmd.py
================================================
import sys
import numpy as np
import argparse
import copy
import random
import json
import time

import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

from .algo import BaseAlgo
from utils.helper import l1_dist, l2_dist, embedding_dist, cosine_similarity

class MMD(BaseAlgo):
    def __init__(self, args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda):
        
        super().__init__(args, train_dataset, val_dataset, test_dataset, base_res_dir, post_string, cuda) 

        self.mmd_gamma= self.args.penalty_ws
        self.gaussian= bool(self.args.gaussian)
        self.conditional= bool(self.args.conditional)
        if self.gaussian: 
            self.kernel_type = "gaussian"
        else:
            self.kernel_type = "mean_cov"
        
        self.featurizer = self.phi.feat_net
        self.classifier = self.phi.fc
        
        print('Initial Params: ', )
        
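    def my_cdist(self, x1, x2):
        #Pairwise squared Euclidean distances between rows of x1 and x2:
        #||x1||^2 - 2 * x1 @ x2^T + ||x2||^2
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        #Clamp to guard against small negative values from floating point error
        return res.clamp_min_(1e-30)
    
    def gaussian_kernel(self, x, y, gamma=(0.001, 0.01, 0.1, 1, 10, 100,
                                           1000)):
        #Sum of Gaussian (RBF) kernels exp(-g * ||x - y||^2) over the bandwidths in gamma
        D = self.my_cdist(x, y)
        K = torch.zeros_like(D)

        for g in gamma:
            K.add_(torch.exp(D.mul(-g)))

        return K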

    def mmd(self, x, y):
        if self.kernel_type == "gaussian":
            #Biased MMD^2 estimate: E[k(x,x')] + E[k(y,y')] - 2*E[k(x,y)]
            Kxx = self.gaussian_kernel(x, x).mean()
            Kyy = self.gaussian_kernel(y, y).mean()
            Kxy = self.gaussian_kernel(x, y).mean()
            return Kxx + Kyy - 2 * Kxy
        else:
            #Match first and second moments: mean squared difference of the
            #feature means and of the feature covariance matrices
            mean_x = x.mean(0, keepdim=True)
            mean_y = y.mean(0, keepdim=True)
            cent_x = x - mean_x
            cent_y = y - mean_y
            cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
            cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)

            mean_diff = (mean_x - mean_y).pow(2).mean()
            cova_diff = (cova_x - cova_y).pow(2).mean()

            return mean_diff + cova_diff
    
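    def mmd_regularization(self, features, d, nmb):
        #Sum the MMD penalty over all unordered pairs of domains present in d.
        #Iterate over the actual domain ids rather than range(nmb), so that
        #non-contiguous ids in a batch do not produce empty feature slices.
        #nmb is kept in the signature for compatibility with existing callers.
        penalty= torch.tensor(0.0).to(self.cuda)
        domains= torch.unique(d)
        for i in range(len(domains)):
            for j in range(i + 1, len(domains)):
                f_i= features[ d == domains[i] ]
                f_j= features[ d == domains[j] ]
                penalty += self.mmd(f_i, f_j)
        return penalty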
        
    def train(self):
        
        self.max_epoch=-1
        self.max_val_acc=0.0
        for epoch in range(self.args.epochs):   
                    
            penalty_erm=0
            penalty_mmd=0
            train_acc= 0.0
            train_size=0
                    
            #Batch iteration over single epoch
            for batch_idx, (x_e, y_e ,d_e, idx_e, obj_e) in enumerate(self.train_dataset):
        #         print('Batch Idx: ', batch_idx)

                self.opt.zero_grad()
                loss_e= torch.tensor(0.0).to(self.cuda)
                
                x_e= x_e.to(self.cuda)
                y_e= torch.argmax(y_e, dim=1).to(self.cuda)
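                #Keep the domain labels on the same device as y_e, since they are indexed with masks built from y_e
                d_e= torch.argmax(d_e, dim=1).to(self.cuda)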
                
                #Forward Pass
                features = self.featurizer(x_e)
                out = self.classifier(features)                
                
                #ERM
                erm_loss= F.cross_entropy(out, y_e.long()).to(self.cuda)
                loss_e+= erm_loss
                penalty_erm += float(loss_e)
                
                #MMD
                mmd_loss=torch.tensor(0.0).to(self.cuda)
                match_domains= torch.unique(d_e)
                class_labels= torch.unique(y_e)
                nmb = len(match_domains)

                if self.conditional:
                    #Iterate over the class labels actually present in the batch
                    for y_c in class_labels:
                        features_c= features[ y_e == y_c ]
                        d_c= d_e[ y_e == y_c ]
                        if len(torch.unique(d_c)) != nmb:
                            print('*********************************')
                            print('Error: Some classes not distributed across all the domains; issues for class conditional methods')
                            continue
                        mmd_loss+= self.mmd_regularization(features_c, d_c, nmb)
                else:
                    mmd_loss+= self.mmd_regularization(features, d_e, nmb)            

                if nmb > 1:
                    mmd_loss /= (nmb * (nmb - 1) / 2)
                                
                penalty_mmd+= float(mmd_loss)
                
                #Backward Pass
                loss_e+= self.mmd_gamma*mmd_loss                
                loss_e.backward(retain_graph=False)
                self.opt.step()
                
                del erm_loss
                del mmd_loss 
                del loss_e
                torch.cuda.empty_cache()
                
                train_acc+= torch.sum(torch.argmax(out, dim=1) == y_e ).item()
                train_size+= y_e.shape[0]                
                        
   
            print('Train Loss Basic : ',  penalty_erm, penalty_mmd )
            print('Train Acc Env : ', 100*train_acc/train_size )
            print('Done Training for epoch: ', epoch)
            
            #Train Dataset Accuracy
            self.train_acc.append( 100*train_acc/train_size )
            
            #Val Dataset Accuracy
            self.val_acc.append( self.get_test_accuracy('val') )
            
            #Test Dataset Accuracy
            self.final_acc.append( self.get_test_accuracy('test') )
            
            #Save the model if current best epoch as per validation accuracy
            if self.val_acc[-1] > self.max_val_acc:
                self.max_val_acc=self.val_acc[-1]
                self.max_epoch= epoch
                self.save_model()
                                
            print('Current Best Epoch: ', self.max_epoch, ' with Test Accuracy: ', self.final_acc[self.max_epoch])

================================================
FILE: azure_scripts/chest.yaml
================================================
description: ChestXray Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
 

# list of jobs to run
jobs:
  # name must be unique across the jobs
# - name: erm_oracle
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains kaggle_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1


- name: erm
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1
  
- name: rand_match
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --model_name densenet121 --n_runs 3 --os_env 1
  
- name: csd
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray --method_name csd --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --rep_dim 1024 --os_env 1
   
- name: irm
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1

- name: irm-50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 50.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1

- name: irm-100
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 100.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1

# - name: perf_match
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name hybrid --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --penalty_aug 10.0 --os_env 1


# - name: erm_oracle_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1

# - name: erm_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: rand_match_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: csd_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name csd --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --rep_dim 1024 --os_env 1
   
# - name: irm_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1
  
 
# - name: erm_oracle_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains chex_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1

# - name: erm_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: rand_match_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: csd_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name csd --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --rep_dim 1024 --os_env 1
   
# - name: irm_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1




================================================
FILE: azure_scripts/chest_ctr.yaml
================================================
description: ChestXray Dataset Contrastive Learning

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run, we run 2 jobs in this example
jobs:
  # name must be unique across the jobs
# - name: kaggle_test
#   # one gpu
#   sku: G1
#   command:
#   - python train.py --dataset chestxray --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains nih_trans chex_trans --test_domains kaggle --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --os_env 1

- name: chex_test
  # one gpu
  sku: G1
  command:
  - python train.py --dataset chestxray --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains nih_trans kaggle_trans --test_domains chex_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --os_env 1

- name: nih_test
  # one gpu
  sku: G1
  command:
  - python train.py --dataset chestxray --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains chex_trans kaggle_trans --test_domains nih_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --os_env 1

================================================
FILE: azure_scripts/chest_ctr_spur.yaml
================================================
description: ChestXray Dataset Contrastive Learning

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run
jobs:
  # name must be unique across the jobs
- name: kaggle_test
  # one gpu
  sku: G1
  command:
  - python train.py --dataset chestxray_spur --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 100 --batch_size 64 --pos_metric cos --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --os_env 1

# - name: chex_test
#   # one gpu
#   sku: G1
#   command:
#   - python train.py --dataset chestxray --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains nih_trans kaggle_trans --test_domains chex_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --os_env 1

# - name: nih_test
#   # one gpu
#   sku: G1
#   command:
#   - python train.py --dataset chestxray --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains chex_trans kaggle_trans --test_domains nih_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name densenet121 --os_env 1

================================================
FILE: azure_scripts/chest_matchdg.yaml
================================================
description: Hyperparam sweep on ChestXray Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:

    - echo "--debug" && python train.py --dataset chestxray --method_name {method}  --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --model_name densenet121 --n_runs 2 --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug}
    
  type: grid
  max_trials: 100
  params:
    - name: penalty
      spec: discrete
#       values: [10.0, 50.0]
      values: [1.0]
    - name: penalty_aug
      spec: discrete
#       values: [10.0]
      values: [50.0]
    - name: lr
      spec: discrete
      values: [0.001]
#       values: [0.001, 0.0005]
    - name: method
      spec: discrete
      values: [hybrid]
#       values: [matchdg_erm, hybrid]
    - name: batch_size
      spec: discrete
#       values: [16, 32, 64]
      values: [16]
    - name: weight_decay
      spec: discrete
      values: [0.0005]
#       values: [0.0005, 0.001]

================================================
FILE: azure_scripts/chest_matchdg_spur.yaml
================================================
description: Hyperparam sweep on ChestXray Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:

    - echo "--debug" && python train.py --dataset chestxray_spur --method_name {method}  --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name densenet121 --model_name densenet121 --n_runs 2 --train_domains nih_trans chex_trans --test_domains kaggle_trans --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug}
    
  type: grid
  max_trials: 100
  params:
    - name: penalty
      spec: discrete
#       values: [10.0, 50.0]
#       values: [1.0]
      values: [50.0]
    - name: penalty_aug
      spec: discrete
      values: [1.0]
#       values: [100.0]
    - name: lr
      spec: discrete
      values: [0.001]
#       values: [0.001, 0.0005]
    - name: method
      spec: discrete
      values: [matchdg_erm]
#       values: [matchdg_erm, hybrid]
    - name: batch_size
      spec: discrete
#       values: [16, 32, 64]
      values: [16]
    - name: weight_decay
      spec: discrete
      values: [0.0005]
#       values: [0.0005, 0.001]

================================================
FILE: azure_scripts/chest_spur.yaml
================================================
description: ChestXray Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
 

# list of jobs to run
jobs:
  # name must be unique across the jobs
# - name: erm_oracle
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains kaggle_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1


- name: erm
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray_spur --method_name erm_match --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1
  
- name: rand_match
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray_spur --method_name erm_match --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 50.0 --model_name densenet121 --n_runs 3 --os_env 1
  
- name: csd
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray_spur --method_name csd --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --rep_dim 1024 --os_env 1
   
- name: irm
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset chestxray_spur --method_name irm --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 50.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1

# - name: irm-50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 50.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1

# - name: irm-100
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 100.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1

# - name: perf_match
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name hybrid --match_case 0.01 --train_domains nih_trans chex_trans  --test_domains kaggle_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --penalty_aug 10.0 --os_env 1


# - name: erm_oracle_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1

# - name: erm_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: rand_match_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: csd_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name csd --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --rep_dim 1024 --os_env 1
   
# - name: irm_nih
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains chex_trans kaggle_trans  --test_domains nih_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1
  
 
# - name: erm_oracle_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains chex_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1

# - name: erm_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: rand_match_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name erm_match --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --model_name densenet121 --n_runs 3 --os_env 1
  
# - name: csd_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name csd --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 0.0 --model_name densenet121 --n_runs 3 --rep_dim 1024 --os_env 1
   
# - name: irm_chex
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset chestxray --method_name irm --match_case 0.01 --train_domains nih_trans kaggle_trans  --test_domains chex_trans  --out_classes 2 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 40  --lr 0.001 --batch_size 16 --penalty_ws 10.0 --penalty_s 5 --model_name densenet121 --n_runs 3 --os_env 1




================================================
FILE: azure_scripts/fmnist.yaml
================================================
description: Fashion MNIST Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
 

# list of jobs to run
jobs:
  # name must be unique across the jobs

- name: erm
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --epochs 60  --os_env 1


- name: random_match
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --epochs 60  --os_env 1


# - name: approx_25
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset fashion_mnist --method_name erm_match --match_case 0.25 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: approx_50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset fashion_mnist --method_name erm_match --match_case 0.50 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: approx_75
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset fashion_mnist --method_name erm_match --match_case 0.75 --penalty_ws 0.1 --epochs 25  --os_env 1


- name: perfect
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name erm_match --match_case 1.0 --penalty_ws 0.1 --epochs 60  --os_env 1


- name: csd
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name csd --match_case 0.01 --penalty_ws 0.0 --rep_dim 512 --epochs 60  --os_env 1


- name: irm
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name irm --match_case 0.01 --penalty_irm 0.05 --penalty_s -1   --epochs 60  --os_env 1

  
- name: matchdg
  # one gpu
  sku: G1
  command: 
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 60 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3  --os_env 1

================================================
FILE: azure_scripts/irm_fashion.yaml
================================================
description: Hyperparam sweep on IRM Fashion MNIST

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:
    - python train.py --dataset fashion_mnist --method_name irm --match_case 0.01 --lr 0.01 --penalty_irm {penalty} --penalty_s {threshold} --epochs 60 --os_env 1
  type: grid
  max_trials: 60
  params:
    - name: penalty
      spec: discrete
      values: [0.05, 0.1, 0.5, 1.0, 5.0]
    - name: threshold
      spec: discrete
      values: [-1, 5, 15, 30, 45 ]
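# Illustrative note, not part of the original config. Assuming `type: grid`
# enumerates the Cartesian product of the discrete params above, this sweep has
# 5 (penalty) x 5 (threshold) = 25 combinations, which is below max_trials: 60.
# Assuming plain {name} substitution into the command template, the first
# combination (penalty=0.05, threshold=-1) would expand to:
#   python train.py --dataset fashion_mnist --method_name irm --match_case 0.01 --lr 0.01 --penalty_irm 0.05 --penalty_s -1 --epochs 60 --os_env 1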


================================================
FILE: azure_scripts/irm_mnist.yaml
================================================
description: Hyperparam sweep on IRM MNIST

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:
    - python train.py --dataset rot_mnist --method_name irm --match_case 0.01 --lr 0.01 --penalty_irm {penalty} --penalty_s {threshold} --epochs 60  --os_env 1
  type: grid
  max_trials: 60
  params:
    - name: penalty
      spec: discrete
      values: [0.05, 0.1, 0.5, 1.0, 5.0]
    - name: threshold
      spec: discrete
      values: [-1, 5, 15, 30, 45 ]


================================================
FILE: azure_scripts/mnist.yaml
================================================
description: Rotated MNIST Dataset

target:
  service: amlk8s
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  name: itpeusp100cl

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements_new.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
 

# list of jobs to run (commented-out entries are skipped)
jobs:
  # name must be unique across the jobs

# - name: erm
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --epochs 25  --os_env 1


# - name: random_match
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: approx_25
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.25 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: approx_50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.50 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: approx_75
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.75 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: perfect
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: csd
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name csd --match_case 0.01 --penalty_ws 0.0 --rep_dim 512 --epochs 25  --os_env 1


- name: irm
  # one gpu
  sku: G1
  command:
  - echo "--debug" &&  python data/data_gen.py rot_mnist lenet
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name irm --match_case 0.01 --penalty_irm 1.0 --penalty_s 5   --epochs 25  --os_env 1

  
# - name: matchdg
#   # one gpu
#   sku: G1
#   command: 
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 25 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3  --os_env 1

================================================
FILE: azure_scripts/mnist_ctr.yaml
================================================
description: MNIST Dataset Contrastive Learning

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run
jobs:
  # name must be unique across the jobs
- name: fmnist_ctr_standard
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 60 --batch_size 64 --pos_metric cos --os_env 1
  
- name: fmnist_ctr_perfect
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset fashion_mnist --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --epochs 60 --batch_size 64 --pos_metric cos --os_env 1
  
- name: fmnist_ctr_non_iterative
  # one gpu
  sku: G1
  command:
  - python train.py --dataset fashion_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 0 --epochs 60 --batch_size 256 --pos_metric cos --os_env 1
  
- name: rmnist_ctr_standard
  # one gpu
  sku: G1
  command:
  - python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 100 --batch_size 256 --pos_metric cos --train_domains 30 45 --os_env 1
  
- name: rmnist_ctr_perfect
  # one gpu
  sku: G1
  command:
  - python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --epochs 100 --batch_size 256 --pos_metric cos --os_env 1
  
- name: rmnist_ctr_non_iterative
  # one gpu
  sku: G1
  command:
  - python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 0 --epochs 100 --batch_size 256 --pos_metric cos --os_env 1


================================================
FILE: azure_scripts/mnist_ctr_spur.yaml
================================================
description: MNIST Dataset Contrastive Learning (spurious features)

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run (commented-out entries are skipped)
jobs:
  # name must be unique across the jobs
# - name: fmnist_ctr_standard
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset fashion_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 60 --batch_size 64 --pos_metric cos --os_env 1

  
- name: rmnist_ctr_standard
  # one gpu
  sku: G1
  command:
  - python train.py --dataset rot_mnist_spur --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 100 --batch_size 256 --pos_metric cos --img_c 3 --os_env 1



================================================
FILE: azure_scripts/mnist_spur.yaml
================================================
description: Rotated MNIST Dataset (spurious features)

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
 

# list of jobs to run (commented-out entries are skipped)
jobs:
  # name must be unique across the jobs

# - name: erm
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist_spur --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --epochs 25 --img_c 3  --os_env 1


# - name: random_match
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist_spur --method_name erm_match --match_case 0.01 --penalty_ws 10.0 --epochs 25 --img_c 3  --os_env 1


# - name: approx_25
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.25 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: approx_50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.50 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: approx_75
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist --method_name erm_match --match_case 0.75 --penalty_ws 0.1 --epochs 25  --os_env 1


# - name: perfect
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist_spur --method_name erm_match --match_case 1.0 --penalty_ws 10.0 --epochs 25  --img_c 3 --os_env 1


# - name: csd
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist_spur --method_name csd --match_case 0.01 --penalty_ws 0.0 --rep_dim 512 --epochs 25 --img_c 3  --os_env 1


# - name: irm
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset rot_mnist_spur --method_name irm --match_case 0.01 --penalty_irm 50.0 --penalty_s 5   --epochs 25 --img_c 3  --os_env 1

  
- name: matchdg
  # one gpu
  sku: G1
  command: 
  - echo "--debug" && python train.py --dataset rot_mnist_spur --method_name matchdg_erm --match_case -1 --penalty_ws 0.1 --epochs 25 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3 --img_c 3 --os_env 1

================================================
FILE: azure_scripts/pacs.yaml
================================================
description: PACS Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run
jobs:
  # name must be unique across the jobs
- name: photo
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1
  
- name: art_painting
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1
  
- name: cartoon
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1

- name: sketch
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1

- name: photo_random_match
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.5 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1
  
- name: art_painting_random_match
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.5 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1
  
- name: cartoon_random_match
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.5 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1

- name: sketch_random_match
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.5 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.005 --epochs 50 --os_env 1


================================================
FILE: azure_scripts/pacs_art_painting.yaml
================================================
description: Hyperparam sweep on PACS

target:
  service: amlk8s
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  name: itpeusp100cl
#   name:  itpeusp40cl

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements_new.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:
    
    # ERM, RandMatch
#     - echo "--debug" && python train.py --dataset pacs --method_name {method} --match_case 0.0 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model} 
    
    
    # MDG, Hybrid
    - echo "--debug" && python train.py --dataset pacs --method_name {method}  --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0  --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model} 

    
  type: grid
  max_trials: 100
  params:
    - name: penalty
      spec: discrete
#       values: [0.1, 0.5, 1.0, 5.0]
      values: [0.01, 1.0]
    - name: penalty_aug
      spec: discrete
#       values: [0.1, 1.0, 5.0, 10.0]
      values: [0.1, 1.0]
    - name: lr
      spec: discrete
      values: [0.001]
#       values: [0.01, 0.001, 0.0005]
    - name: model
      spec: discrete
      values: [alexnet]
#       values: [alexnet, resnet18, resnet50]
    - name: method
      spec: discrete
#       values: [erm_match, matchdg_erm, hybrid]
      values: [hybrid]
    - name: batch_size
      spec: discrete
      values: [16]
    - name: weight_decay
      spec: discrete
      values: [0.0005]
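# Illustrative note, not part of the original config. Assuming `type: grid`
# enumerates the Cartesian product of the active (uncommented) value lists, this
# sweep has 2 (penalty) x 2 (penalty_aug) x 1 x 1 x 1 x 1 x 1 = 4 combinations,
# well below max_trials: 100. Assuming plain {name} substitution, one combination
# would fill the templated flags of the active command as:
#   --method_name hybrid --lr 0.001 --batch_size 16 --weight_decay 0.0005 --penalty_ws 0.01 --penalty_aug 0.1 --model_name alexnet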
      

================================================
FILE: azure_scripts/pacs_cartoon.yaml
================================================
description: Hyperparam sweep on PACS

target:
  service: amlk8s
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  name: itpeusp100cl
#   name:  itpeusp40cl

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements_new.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:
    
    # ERM, RandMatch
#     - echo "--debug" && python train.py --dataset pacs --method_name {method} --match_case 0.0 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model} 
    
    
    # MDG, Hybrid
    - echo "--debug" && python train.py --dataset pacs --method_name {method}  --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0  --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model} 

    
  type: grid
  max_trials: 100
  params:
    - name: penalty
      spec: discrete
#       values: [0.1, 0.5, 1.0, 5.0]
      values: [0.01, 1.0]
    - name: penalty_aug
      spec: discrete
#       values: [0.1, 1.0, 5.0, 10.0]
      values: [0.1, 1.0]
    - name: lr
      spec: discrete
      values: [0.001]
#       values: [0.01, 0.001, 0.0005]
    - name: model
      spec: discrete
      values: [alexnet]
#       values: [alexnet, resnet18, resnet50]
    - name: method
      spec: discrete
#       values: [erm_match, matchdg_erm, hybrid]
      values: [hybrid]
    - name: batch_size
      spec: discrete
      values: [16]
    - name: weight_decay
      spec: discrete
      values: [0.0005]
      

================================================
FILE: azure_scripts/pacs_ctr.yaml
================================================
description: PACS Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run (commented-out entries are skipped)
jobs:
  # name must be unique across the jobs
- name: photo_ctr_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 256 --pos_metric cos --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --os_env 1
  
# - name: art_painting_ctr_r18
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 256 --pos_metric cos --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --os_env 1
  
- name: cartoon_ctr_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 256 --pos_metric cos --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --os_env 1

- name: sketch_ctr_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 256 --pos_metric cos --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet18 --os_env 1
  
  
- name: photo_ctr_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --os_env 1
  
- name: art_painting_ctr_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --os_env 1
  
- name: cartoon_ctr_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --os_env 1

- name: sketch_ctr_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --model_name resnet50 --os_env 1

================================================
FILE: azure_scripts/pacs_erm.yaml
================================================
description: PACS ERM Dataset

target:
  service: amlk8s
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  name: itpeusp100cl

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements_new.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run (commented-out entries are skipped)
jobs:
  # name must be unique across the jobs
# - name: photo_r18
#   # one gpu
#   sku: G1
#   command:
#   - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.0005 --epochs 50 --model_name resnet18 --weight_decay 0.001 --os_env 1
  
# - name: art_painting_r18
#   # one gpu
#   sku: G1
#   command:
#   - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.01 --epochs 50 --model_name resnet18 --weight_decay 0.001 --os_env 1
  
# - name: cartoon_r18
#   # one gpu
#   sku: G1
#   command:
#   - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.001 --epochs 50 --model_name resnet18 --weight_decay 0.001 --os_env 1

- name: sketch_r18
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.01 --epochs 50 --model_name resnet18 --weight_decay 0.001 --os_env 1


================================================
FILE: azure_scripts/pacs_hybrid.yaml
================================================
description: PACS Hybrid (MatchDG) Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run (commented-out entries are skipped)
jobs:
  # name must be unique across the jobs
# - name: photo_r18
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005  --batch_size 16 --weight_decay 0.001 --penalty_ws 0.1 --penalty_aug 0.1 --model_name resnet18
  
# - name: art_painting_r18
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.01 --penalty_aug 0.1 --model_name resnet18
  
# - name: cartoon_r18
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.01 --penalty_aug 0.1 --model_name resnet18

- name: sketch_r18_0.01
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.01 --penalty_aug 0.1 --model_name resnet18


- name: sketch_r18_0.1
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.1 --penalty_aug 0.1 --model_name resnet18


- name: sketch_r18_0.5
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.5 --penalty_aug 0.1 --model_name resnet18

# - name: photo_r50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005  --batch_size 16 --weight_decay 0.001 --penalty_ws 0.1 --penalty_aug 0.1 --model_name resnet50
  
# - name: art_painting_r50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.01 --penalty_aug 0.1 --model_name resnet50
  
# - name: cartoon_r50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.01 --penalty_aug 0.1 --model_name resnet50

# - name: sketch_r50
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.01 --penalty_aug 0.1 --model_name resnet50


================================================
FILE: azure_scripts/pacs_matchdg.yaml
================================================
description: PACS MatchDG Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run (one entry per held-out test domain and backbone)
jobs:
  # name must be unique across the jobs
# - name: photo_r18
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005  --batch_size 16 --weight_decay 0.001 --penalty_ws 0.1 --penalty_aug 0.1 --model_name resnet18
  
# - name: art_painting_r18
#   # one gpu
#   sku: G1
#   command:
#   - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.5 --penalty_aug 0.1 --model_name resnet18
  
- name: cartoon_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 1.0 --penalty_aug 0.1 --model_name resnet18

- name: sketch_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.01 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.5 --penalty_aug 0.1 --model_name resnet18

- name: photo_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005  --batch_size 16 --weight_decay 0.001 --penalty_ws 0.1 --penalty_aug 0.1 --model_name resnet50
  
- name: art_painting_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.1 --penalty_aug 0.1 --model_name resnet50
  
- name: cartoon_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005 --batch_size 16 --weight_decay 0.001 --penalty_ws 1.0 --penalty_aug 0.1 --model_name resnet50

- name: sketch_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name matchdg_erm --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.5 --penalty_aug 0.1 --model_name resnet50


================================================
FILE: azure_scripts/pacs_perfect.yaml
================================================
description: PACS MatchDG Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run (one entry per held-out test domain and backbone)
jobs:
  # name must be unique across the jobs
- name: photo_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005  --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet18
  
- name: art_painting_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet18
  
- name: cartoon_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet18

- name: sketch_r18
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.01 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet18

- name: photo_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005  --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet50
  
- name: art_painting_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet50
  
- name: cartoon_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.0005 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet50

- name: sketch_r50
  # one gpu
  sku: G1
  command:
  - echo "--debug" && python train.py --dataset pacs --method_name hybrid --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr 0.001 --batch_size 16 --weight_decay 0.001 --penalty_ws 0.0 --penalty_aug 0.1 --model_name resnet50


================================================
FILE: azure_scripts/pacs_photo.yaml
================================================
description: Hyperparam sweep on PACS

target:
  service: amlk8s
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  name: itpeusp100cl
#   name: itplabrl1cl1
#   name:  itpeusp40cl
  
environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements_new.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:
    
    # ERM, RandMatch
#     - echo "--debug" && python train.py --dataset pacs --method_name {method} --match_case 0.0 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model} 
    
    
    # MDG, Hybrid
    - echo "--debug" && python train.py --dataset pacs --method_name {method}  --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0  --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model} 

    
  type: grid
  max_trials: 100
  params:
    - name: penalty
      spec: discrete
#       values: [0.1, 0.5, 1.0, 5.0]
      values: [0.1]
    - name: penalty_aug
      spec: discrete
#       values: [0.1, 1.0, 5.0, 10.0]
      values: [0.1, 1.0]
    - name: lr
      spec: discrete
      values: [0.0005]
#       values: [0.01, 0.001, 0.0005]
    - name: model
      spec: discrete
      values: [alexnet]
#       values: [alexnet, resnet18, resnet50]
    - name: method
      spec: discrete
#       values: [erm_match, matchdg_erm, hybrid]
      values: [hybrid]
    - name: batch_size
      spec: discrete
      values: [16]
    - name: weight_decay
      spec: discrete
      values: [0.0005]
      

================================================
FILE: azure_scripts/pacs_random.yaml
================================================
description: PACS Random Match Dataset

target:
  service: philly
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  cluster: rr1

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
    
# list of jobs to run (one entry per held-out test domain and backbone)
jobs:
  # name must be unique across the jobs
- name: photo_r18
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.001 --epochs 50 --model_name resnet18 --weight_decay 0.001 --os_env 1
  
- name: art_painting_r18
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.5 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.01 --epochs 50 --model_name resnet18 --os_env 1
  
- name: cartoon_r18
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.01 --epochs 50 --model_name resnet18 --os_env 1

- name: sketch_r18
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.01 --epochs 50 --model_name resnet18 --os_env 1

- name: photo_r50
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --train_domains art_painting cartoon sketch --test_domains photo --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.0005 --epochs 50 --model_name resnet50 --weight_decay 0.001 --os_env 1
  
- name: art_painting_r50
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.5 --train_domains photo cartoon sketch --test_domains art_painting --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.001 --epochs 50 --model_name resnet50 --os_env 1
  
- name: cartoon_r50
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.5 --train_domains photo art_painting sketch --test_domains cartoon --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.0005 --epochs 50 --model_name resnet50 --os_env 1

- name: sketch_r50
  # one gpu
  sku: G1
  command:
  - python train.py --dataset pacs --method_name erm_match --match_case 0.01 --penalty_ws 0.1 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --lr 0.01 --epochs 50 --model_name resnet50 --os_env 1


================================================
FILE: azure_scripts/pacs_sketch.yaml
================================================
description: Hyperparam sweep on PACS

target:
  service: amlk8s
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  name: itpeusp100cl
#   name:  itpeusp40cl

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements_new.txt

code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/

search:
  job_template:
    # you may use {random_string:s} to avoid job name collisions
    # {auto:3s} generates lr_0.00000_mom_0.5, .. etc
    # {auto:2s} generates lr_0.00000_mo_0.5, .. etc
    name: search_{experiment_name:s}_{auto:5s}
    sku: G1
    command:
    
    # ERM, RandMatch
#     - echo "--debug" && python train.py --dataset pacs --method_name {method} --match_case 0.0 --n_runs 3 --train_domains photo art_painting cartoon --test_domains sketch --out_classes 7 --perfect_match 0 --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model}     
    
    # MDG, Hybrid
    - echo "--debug" && python train.py --dataset pacs --method_name {method}  --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet50 --n_runs 3 --train_domains photo art_painting cartoon  --test_domains sketch --out_classes 7 --perfect_match 0  --img_c 3 --pre_trained 1 --epochs 50 --os_env 1  --lr {lr} --batch_size {batch_size} --weight_decay {weight_decay} --penalty_ws {penalty} --penalty_aug {penalty_aug} --model_name {model} 

    
  type: grid
  max_trials: 100
  params:
    - name: penalty
      spec: discrete
#       values: [0.1, 0.5, 1.0, 5.0]
      values: [0.01, 0.1]
    - name: penalty_aug
      spec: discrete
#       values: [0.1, 1.0, 5.0, 10.0]
      values: [0.1, 1.0]
    - name: lr
      spec: discrete
      values: [0.001]
#       values: [0.01, 0.001, 0.0005]
    - name: model
      spec: discrete
      values: [alexnet]
#       values: [alexnet, resnet18, resnet50]
    - name: method
      spec: discrete
#       values: [erm_match, matchdg_erm, hybrid]
      values: [hybrid]
    - name: batch_size
      spec: discrete
      values: [16]
    - name: weight_decay
      spec: discrete
      values: [0.0005]
      

================================================
FILE: azure_scripts/setup_data_mnist.yaml
================================================
description: MNIST Data Setup

target:
  service: amlk8s
  # which virtual cluster you belong to (msrlabs, etc.). Everyone has access to "msrlabs".
  vc: resrchvc 
  # physical cluster to use (cam, gcr, rr1, rr2) or Azure clusters (eu1, eu2, etc.)
  name: itpeusp100cl

environment:
  image: pytorch/pytorch:1.5-cuda10.1-cudnn7-devel
  setup:
    - pip install --user -r requirements_new.txt
    
code:
  # local directory of the code. this will be uploaded to the server.
  # $CONFIG_DIR is expanded to the directory of this config file
  local_dir: $CONFIG_DIR

data:
    local_dir: $CONFIG_DIR/data/datasets/
    remote_dir: data/datasets/
 

# list of jobs to run; only rot_mnist_resnet18 is currently enabled
jobs:
  # name must be unique across the jobs

- name: rot_mnist_resnet18
  # one gpu
  sku: G1
  command:
  - echo "--debug" &&  python data/data_gen.py rot_mnist resnet18

#- name: fashion_mnist_resnet18
#  # one gpu
#  sku: G1
#  command:
#  - echo "--debug" &&  python data/data_gen.py fashion_mnist resnet18

#- name: rot_mnist_lenet
#  # one gpu
#  sku: G1
#  command:
#  - echo "--debug" &&  python data/data_gen.py rot_mnist lenet



================================================
FILE: chestxray_download.txt
================================================
NIH Dataset:

curl -o nih.zip "https://storage.googleapis.com/kaggle-data-sets/5839%2F18613%2Fbundle%2Farchive.zip?GoogleAccessId=gcp-kaggle-com@kaggle-161607.iam.gserviceaccount.com&Expires=1600359450&Signature=MPds%2FPBnAPNGFXy1cnmRVhHaHsTRggstPA44ZCE0onI35vc4UMwdPSyQS%2Fypf5B%2FhmOsf6%2B6oxy0%2BKL8HCBh8BtFrwMyfY7dVczTmPkBEGPALf7roGbuWFB6oUVrXAVHFpJwCKEwMCSrxkpFIccLxXII%2B84aG4xrqwzu1LQq%2BRyE3W7Rg22ib1tiyX%2FsZjGk8%2BmHqlA7gg2Y9pr4s7xZgTpnpUv0NPiVjLcsWHgWznx2fuWZm8Ox%2Faj6CzZa6dbpYg%2FNWIHpCJ%2BzfPRCZQuGVaoSfKjoPZK9ei3W1FrZ2MDHBzPREQh1OCngT3v2%2Fn%2BXFj5tQ5b4OfvL1YB%2FMou5Cg%3D%3D"

CheXpert:

curl -o chexpert.zip http://download.cs.stanford.edu/deep/CheXpert-v1.0-small.zip

OpenI:

curl -o openi.tgz https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz

Kaggle:

curl -o kaggle.zip "https://storage.googleapis.com/kaggle-competitions-data/kaggle-v2/10338/862042/bundle/archive.zip?GoogleAccessId=web-data@kaggle161607.iam.gserviceaccount.com&Expires=1600347314&Signature=hilOASDiejHlo7KgvJR%2FqzaPg3eKcnBKauYVS%2FM6CIoVUl6mjgDdiDFwXJYOmeuK%2F1WfLO32JEjsc8XB6h7SQWhsMJ6Xs%2F1P7oMKNURjcYkZ2OQYXSV5gFDWVqZ%2Bna4t4B2y%2Bz6Gp9GpGt5HEjc4leOGlMizwLQEhQmlZWSpBqFzgTjLF9eVbNc2ekln5SCsLFWLz0YGFeAgkulq5qgh2Rfu%2BD5QafmPgTc3iMMJf%2BQcVJ0dgqHjcROmANWTnvdWcMjweZMBwXOgYHOomCHHRAgXnWvaXC5AxZsKXmmsbWe%2BsuCDJ4bIwAzm%2BC27XJwnIaeaOudn6BL%2FuLtf1lvv7A%3D%3D&response-content-disposition=attachment%3B+filename%3Drsna-pneumonia-detection-challenge.zip"



================================================
FILE: data/__init__.py
================================================


================================================
FILE: data/adult_loader.py
================================================
#Common imports
import os
import random
import copy
import numpy as np
import h5py
from PIL import Image

#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

#Base Class
from .data_loader import BaseDataLoader

class Adult(BaseDataLoader):
    def __init__(self, args, list_train_domains, root, transform=None, data_case='train', match_func=False):
        
        super().__init__(args, list_train_domains, root, transform, data_case, match_func) 
        self.train_data, self.train_labels, self.train_domain, self.train_indices, self.train_spur = self._get_data()

    def _get_data(self):
        
        data_dir= self.root
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        
        # Collect the per-domain subsets to include in this split
        training_list_img = []
        training_list_labels = []
        training_list_idx= []
        training_list_size= []
        training_list_spur= []
        training_out_classes=[]
       
        for domain in self.list_train_domains:
            
            domain_imgs = torch.load(data_dir + domain + '_' + self.data_case + '_data.pt').float()
            domain_labels = torch.load(data_dir + domain + '_' + self.data_case + '_label.pt').long()
            domain_spur= torch.load(data_dir + domain + '_' + self.data_case + '_spur.pt').long()
            domain_idx= list(range(len(domain_imgs)))
            print('Image: ', domain_imgs.shape, ' Labels: ', domain_labels.shape)
            print('Source Domain ', domain)
            training_list_img.append(domain_imgs)
            training_list_labels.append(domain_labels)
            training_list_idx.append( domain_idx )
            training_list_spur.append( domain_spur )
            training_list_size.append(len(domain_imgs))            
            training_out_classes.append( len(torch.unique(domain_labels)) )
        
        if self.match_func:
            print('Match Function Updates')
            num_classes= 2
            for y_c in range(num_classes):
                base_class_size=0
                base_class_idx=-1
                for d_idx, domain in enumerate( self.list_train_domains ):
                    class_idx= training_list_labels[d_idx] == y_c
                    curr_class_size= training_list_labels[d_idx][class_idx].shape[0]
                    if base_class_size < curr_class_size:
                        base_class_size= curr_class_size
                        base_class_idx= d_idx
                self.base_domain_size += base_class_size
                print('Max Class Size: ', base_class_size, base_class_idx, y_c )
        
                
        # Stack
        train_imgs = torch.cat(training_list_img)
        train_labels = torch.cat(training_list_labels)
        train_spur = torch.cat(training_list_spur)
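        # hstack flattens the per-domain index lists directly; wrapping them in
        # np.array first fails on recent NumPy when the domains differ in size
        train_indices = np.hstack(training_list_idx)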
        self.training_list_size = training_list_size
                
        print(train_imgs.shape, train_labels.shape, train_indices.shape, train_spur.shape)
        print(self.training_list_size)
        
        # Create domain labels
        train_domains = torch.zeros(train_labels.size())
        domain_start=0
        for idx in range(len(self.list_train_domains)):
            curr_domain_size= self.training_list_size[idx]
            train_domains[ domain_start: domain_start+ curr_domain_size ] += idx
            domain_start+= curr_domain_size
           
        # Shuffle all arrays with a single shared permutation
        inds = np.arange(train_labels.size()[0])
        np.random.shuffle(inds)
        train_imgs = train_imgs[inds]
        train_labels = train_labels[inds]
        train_spur= train_spur[inds]
        train_domains = train_domains[inds].long()
        train_indices = train_indices[inds]

        # Convert class labels to one-hot
        out_classes= training_out_classes[0]
        y = torch.eye(out_classes)
        train_labels = y[train_labels]

        # Convert domain labels to one-hot
        d = torch.eye(len(self.list_train_domains))
        train_domains = d[train_domains]
        
        print(train_imgs.shape, train_labels.shape, train_domains.shape, train_indices.shape)
        # If shape (B,H,W) change it to (B,C,H,W) with C=1
        if len(train_imgs.shape)==3:
            train_imgs= train_imgs.unsqueeze(1)
        return train_imgs, train_labels, train_domains, train_indices, train_spur
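

# --- Illustrative sketch (not part of the original loader) ---
# The match-function branch above picks, for every class, the domain holding
# the most samples of that class and adds that count to self.base_domain_size.
# The hypothetical helper below reproduces that bookkeeping on plain Python
# lists so it can be checked in isolation. Its name and signature are
# assumptions made for illustration only.
def sketch_base_domain_size(labels_per_domain, num_classes=2):
    """Return (total_size, per_class), where per_class holds one
    (max_class_size, domain_index) pair for each class label."""
    total_size = 0
    per_class = []
    for y_c in range(num_classes):
        base_class_size = 0
        base_class_idx = -1
        for d_idx, labels in enumerate(labels_per_domain):
            curr_class_size = sum(1 for y in labels if y == y_c)
            # Strict '<' keeps the first domain on ties, as in the loader.
            if base_class_size < curr_class_size:
                base_class_size = curr_class_size
                base_class_idx = d_idx
        total_size += base_class_size
        per_class.append((base_class_size, base_class_idx))
    return total_size, per_class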

================================================
FILE: data/chestxray_loader.py
================================================
#Common imports
import os
import random
import copy
import numpy as np
import h5py
from PIL import Image

#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

#Base Class
from .data_loader import BaseDataLoader

class ChestXRay(BaseDataLoader):
    def __init__(self, args, list_train_domains, root, transform=None, data_case='train', match_func=False):
        
        super().__init__(args, list_train_domains, root, transform, data_case, match_func) 
        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _get_data(self):
        
        data_dir= self.root
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        
        # Collect the per-domain subsets to include in this split
        list_img = []
        list_labels = []
        list_idx= []
        list_size= []
        list_classes=[]
       
        for domain in self.list_domains:
            
            domain_imgs = torch.load(data_dir + domain + '_' + self.data_case + '_image.pt')
            domain_imgs_org = torch.load(data_dir + domain + '_' + self.data_case + '_image_org.pt')
            domain_labels = torch.load(data_dir + domain + '_' + self.data_case + '_label.pt')
            domain_idx= list(range(len(domain_imgs)))
            print('Image: ', domain_imgs.shape, ' Labels: ', domain_labels.shape)
            print('Source Domain ', domain)
            
            # Use the augmented images only for the training split; other splits use the original images
            if self.data_case == 'train':
                list_img.append(domain_imgs)
            else:
                list_img.append(domain_imgs_org)                
            list_labels.append(domain_labels)
            list_idx.append( domain_idx )
            list_size.append(len(domain_imgs))            
            list_classes.append( len(torch.unique(domain_labels)) )
        
        if self.match_func:
            print('Match Function Updates')
            num_classes= 2
            for y_c in range(num_classes):
                base_class_size=0
                base_class_idx=-1
                for d_idx, domain in enumerate( self.list_domains ):
                    class_idx= list_labels[d_idx] == y_c
                    curr_class_size= list_labels[d_idx][class_idx].shape[0]
                    if base_class_size < curr_class_size:
                        base_class_size= curr_class_size
                        base_class_idx= d_idx
                self.base_domain_size += base_class_size
                print('Max Class Size: ', base_class_size, ' Base Domain Idx: ', base_class_idx, ' Class Label: ', y_c )        
                
        # Stack
        data_imgs = torch.cat(list_img)
        data_labels = torch.cat(list_labels)
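        # hstack flattens the per-domain index lists directly; wrapping them in
        # np.array first fails on recent NumPy when the domains differ in size
        data_indices = np.hstack(list_idx)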
        self.training_list_size = list_size                    
        
        # ChestXRay has no ground-truth objects; as a placeholder, set them equal to the data indices
        data_objects= copy.deepcopy(data_indices)
    
        # Create domain labels
        data_domains = torch.zeros(data_labels.size())
        domain_start=0
        for idx in range(len(self.list_domains)):
            curr_domain_size= self.training_list_size[idx]
            data_domains[ domain_start: domain_start+ curr_domain_size ] += idx
            domain_start+= curr_domain_size
           
        # Shuffle everything one more time
        inds = np.arange(data_labels.size()[0])
        np.random.shuffle(inds)
        data_imgs = data_imgs[inds]
        data_labels = data_labels[inds]
        data_domains = data_domains[inds].long()
        data_indices = data_indices[inds]
        data_objects = data_objects[inds]

        # Convert class labels to one-hot
        out_classes= list_classes[0]
        y = torch.eye(out_classes)
        data_labels = y[data_labels]

        # Convert domain labels to one-hot
        d = torch.eye(len(self.list_domains))
        data_domains = d[data_domains]
        
        # If shape (B,H,W) change it to (B,C,H,W) with C=1
        if len(data_imgs.shape)==3:
            data_imgs= data_imgs.unsqueeze(1)
            
        print('Shape: Data ', data_imgs.shape, ' Labels ', data_labels.shape, ' Domains ', data_domains.shape, ' Indices ', data_indices.shape, ' Objects ', data_objects.shape)
        return data_imgs, data_labels, data_domains, data_indices, data_objects
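

# Illustrative sketch (not part of the original loader): the "Create domain labels"
# and one-hot steps of _get_data above, redone with plain Python lists so the
# bookkeeping can be checked without torch. The loader itself works on torch
# tensors; both helper names below are hypothetical.
def domain_ids_from_sizes(sizes):
    """Return one domain id per sample, given the per-domain sample counts."""
    ids = []
    for domain_id, size in enumerate(sizes):
        # Samples of one domain are contiguous before the shuffle
        ids.extend([domain_id] * size)
    return ids


def one_hot(ids, num_values):
    """Pure-Python analogue of torch.eye(num_values)[ids]."""
    return [[1.0 if col == i else 0.0 for col in range(num_values)] for i in ids]


# Usage: two domains holding 2 and 3 samples give ids [0, 0, 1, 1, 1];
# one_hot(ids, 2) then yields one row per sample with a single 1.0 entry.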


================================================
FILE: data/chestxray_loader_aug.py
================================================
#Common imports
import os
import random
import copy
import numpy as np
import h5py
from PIL import Image

#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

#Base Class
from .data_loader import BaseDataLoader

class ChestXRayAug(BaseDataLoader):
    def __init__(self, args, list_domains, root, transform=None, data_case='train', match_func=False):
        
        super().__init__(args, list_domains, root, transform, data_case, match_func) 
        self.data, self.data_org, self.labels, self.domains, self.indices, self.objects = self._get_data()
        
    def __getitem__(self, index):
        x = self.data[index]
        x_org = self.data_org[index]
        y = self.labels[index]
        d = self.domains[index]
        idx = self.indices[index]
        obj =  self.objects[index]
            
        if self.transform is not None:
            x = self.transform(x)
        return x, x_org, y, d, idx, obj        

    def _get_data(self):
        
        data_dir= self.root
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        
        # Choose subsets that should be included into the training
        list_img = []
        list_img_org= []
        list_labels = []
        list_idx= []
        list_size= []
        list_classes=[]
       
        for domain in self.list_domains:
            
            domain_imgs = torch.load(data_dir + domain + '_' + self.data_case + '_image.pt')
            domain_imgs_org = torch.load(data_dir + domain + '_' + self.data_case + '_image_org.pt')
            domain_labels = torch.load(data_dir + domain + '_' + self.data_case + '_label.pt')
            domain_idx= list(range(len(domain_imgs)))
            print('Image: ', domain_imgs.shape, ' Labels: ', domain_labels.shape)
            print('Source Domain ', domain)
            list_img.append(domain_imgs)
            list_img_org.append(domain_imgs_org)
            list_labels.append(domain_labels)
            list_idx.append( domain_idx )
            list_size.append(len(domain_imgs))            
            list_classes.append( len(torch.unique(domain_labels)) )
        
        if self.match_func:
            print('Match Function Updates')
            num_classes= 2
            for y_c in range(num_classes):
                base_class_size=0
                base_class_idx=-1
                for d_idx, domain in enumerate( self.list_domains ):
                    class_idx= list_labels[d_idx] == y_c
                    curr_class_size= list_labels[d_idx][class_idx].shape[0]
                    if base_class_size < curr_class_size:
                        base_class_size= curr_class_size
                        base_class_idx= d_idx
                self.base_domain_size += base_class_size
                print('Max Class Size: ', base_class_size, ' Base Domain Idx: ', base_class_idx, ' Class Label: ', y_c )        
                
        # Stack
        data_imgs = torch.cat(list_img)
        data_imgs_org = torch.cat(list_img_org)
        data_labels = torch.cat(list_labels)
        # Flatten the per-domain index lists directly; np.array() on ragged lists fails on recent NumPy
        data_indices = np.hstack(list_idx)
        self.training_list_size = list_size            

        #No ground truth objects in ChestXRay; for reference, we set them to the data indices
        data_objects= copy.deepcopy(data_indices)
        
        # Create domain labels
        data_domains = torch.zeros(data_labels.size())
        domain_start=0
        for idx in range(len(self.list_domains)):
            curr_domain_size= self.training_list_size[idx]
            data_domains[ domain_start: domain_start+ curr_domain_size ] += idx
            domain_start+= curr_domain_size
           
        # Shuffle everything one more time
        inds = np.arange(data_labels.size()[0])
        np.random.shuffle(inds)
        data_imgs = data_imgs[inds]
        data_imgs_org = data_imgs_org[inds]
        data_labels = data_labels[inds]
        data_domains = data_domains[inds].long()
        data_indices = data_indices[inds]
        data_objects = data_objects[inds]

        # Convert class labels to one-hot
        out_classes= list_classes[0]
        y = torch.eye(out_classes)
        data_labels = y[data_labels]

        # Convert domain labels to one-hot
        d = torch.eye(len(self.list_domains))
        data_domains = d[data_domains]

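        # If shape (B,H,W) change it to (B,C,H,W) with C=1
        if len(data_imgs.shape)==3:
            data_imgs= data_imgs.unsqueeze(1)
        # Keep the non-augmented copy in the same layout as the augmented one
        if len(data_imgs_org.shape)==3:
            data_imgs_org= data_imgs_org.unsqueeze(1)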
            
        print('Shape: Data ', data_imgs.shape, ' Data w/o augmentation ', data_imgs_org.shape, ' Labels ', data_labels.shape, ' Domains ', data_domains.shape, ' Indices ', data_indices.shape, ' Objects ', data_objects.shape)
        return data_imgs, data_imgs_org, data_labels, data_domains, data_indices, data_objects            
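

# Illustrative sketch (not part of the original loader): the "Match Function
# Updates" bookkeeping in _get_data above, written with plain lists instead of
# tensors. For each class it takes the largest per-domain count, and the sum
# over classes is what the loader adds to self.base_domain_size (assumed here
# to start at zero in BaseDataLoader). The function name is hypothetical.
def base_domain_size_from_labels(labels_per_domain, num_classes=2):
    total = 0
    for y_c in range(num_classes):
        # Largest number of samples of class y_c found in any single domain
        total += max((labels.count(y_c) for labels in labels_per_domain), default=0)
    return total


# Usage: with domain labels [0, 0, 1] and [0, 1, 1, 1], class 0 peaks at 2 samples
# and class 1 at 3, so base_domain_size_from_labels(...) returns 5.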


================================================
FILE: data/chestxray_loader_match_eval.py
================================================
#Common imports
import os
import random
import copy
import numpy as np
import h5py
from PIL import Image

#Pytorch
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

#Base Class
from .data_loader import BaseDataLoader

class ChestXRayAugEval(BaseDataLoader):
    def __init__(self, args, list_domains, root, transform=None, data_case='train', match_func=False):
        
        super().__init__(args, list_domains, root, transform, data_case, match_func) 
        self.data, self.labels, self.domains, self.indices, self.objects = self._get_data()

    def _get_data(self):
        
        data_dir= self.root
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        
        # Choose subsets that should be included into the training
        list_img = {'aug':[], 'org':[] }
        list_labels = {'aug':[], 'org':[] }
        list_idx= {'aug':[], 'org':[] }
        list_size= {'aug':0, 'org':0 }
        list_classes={'aug':[], 'org':[] }
       
        image_counter=0
        for domain in self.list_domains:
            
            domain_imgs = torch.load(data_dir + domain + '_' + self.data_case + '_image.pt')
            domain_imgs_org = torch.load(data_dir + domain + '_' + self.data_case + '_image_org.pt')
            domain_labels = torch.load(data_dir + domain + '_' + self.data_case + '_label.pt')
            
            # Global indices: offset each domain by the number of images seen so far
            domain_idx= list(range(image_counter, image_counter + len(domain_imgs)))
            image_counter+= len(domain_imgs)
            
            print('Image: ', domain_imgs.shape, ' Labels: ', domain_labels.shape)
            print('Source Domain ', domain)
            
            list_img['aug'].append(domain_imgs)
            list_img['org'].append(domain_imgs_org)
            
            list_labels['aug'].append(domain_labels)
            list_labels['org'].append(domain_labels)
            
            list_idx['aug'].append( domain_idx )            
            list_idx['org'].append( domain_idx )            
            
            list_size['aug']+= len(domain_imgs)
            list_size['org']+= len(domain_imgs)            
            
            list_classes['aug'].append( len(torch.unique(domain_labels))  )
            list_classes['org'].append( len(torch.unique(domain_labels))  )

                
        if self.match_func:
            print('Match Function Updates')
            num_classes= 2
            for y_c in range(num_classes):
                # Track the largest class count across the 'aug' and 'org' groups
                base_class_size=0
                base_class_idx=-1
                for key_idx, key in enumerate(['aug', 'org']):
                    curr_class_size=0
                    for d_idx, domain in enumerate( self.list_domains ):
                        class_idx= list_labels[key][d_idx] == y_c
                        curr_class_size+= list_labels[key][d_idx][class_idx].shape[0]
                        
                    if base_class_size < curr_class_size:
                        base_class_size= curr_class_size
                        # 0 refers to the 'aug' group, 1 to the 'org' group
                        base_class_idx= key_idx
                    
                self.base_domain_size += base_class_size
                print('Max Class Size: ', base_class_size, ' Base Domain Idx: ', base_class_idx, ' Class Label: ', y_c )                
                
        # Stack
        data_imgs = torch.cat(list_img['aug'] + list_img['org'] )
        data_labels = torch.cat(list_labels['aug'] + list_labels['org'] )
        # Flatten the per-domain index lists directly; np.array() on ragged lists fails on recent NumPy
        data_indices = np.hstack(list_idx['aug'] + list_idx['org'])
        list_classes= list_classes['aug'] + list_classes['org']
        self.training_list_size = [ list_size['aug'], list_size['org'] ]           

        #No ground truth objects in ChestXRay; for reference, we set them to the data indices
        data_objects= copy.deepcopy(data_indices)                        
        
        # Create domain labels
        data_domains = torch.zeros(data_labels.size())
        domain_start=0
        for idx in range(len(self.training_list_size)):
            curr_domain_size= self.training_list_size[idx]
            data_domains[ domain_start: domain_start+ curr_domain_size ] += idx
            domain_start+= curr_domain_size
            
        # Shuffle everything one more time
        inds = np.arange(data_labels.size()[0])
        np.random.shuffle(inds)
        data_imgs = data_imgs[inds]
        data_labels = data_labels[inds]
        data_domains = data_domains[inds].long()
        data_indices = data_indices[inds]
        data_objects = data_objects[inds]

        # Convert to onehot
        out_classes= list
SYMBOL INDEX (337 symbols across 62 files)

FILE: algorithms/algo.py
  function get_noise_multiplier (line 25) | def get_noise_multiplier(
  class BaseAlgo (line 77) | class BaseAlgo():
    method __init__ (line 78) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_model (line 112) | def get_model(self):
    method save_model (line 173) | def save_model(self):
    method get_opt (line 181) | def get_opt(self):
    method get_match_function (line 194) | def get_match_function(self, inferred_match, phi):
    method get_match_function_batch (line 216) | def get_match_function_batch(self, batch_idx):
    method get_test_accuracy (line 239) | def get_test_accuracy(self, case):
    method get_dp_noise (line 279) | def get_dp_noise(self):

FILE: algorithms/csd.py
  class CSD (line 21) | class CSD(BaseAlgo):
    method __init__ (line 22) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method forward (line 47) | def forward(self, x, y, di, eval_case=0):
    method epoch_callback (line 76) | def epoch_callback(self, nepoch, final=False):
    method train (line 80) | def train(self):
    method get_test_accuracy (line 142) | def get_test_accuracy(self, case):
    method save_model (line 167) | def save_model(self):

FILE: algorithms/dann.py
  class DANN (line 22) | class DANN(BaseAlgo):
    method __init__ (line 23) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method train (line 53) | def train(self):

FILE: algorithms/erm.py
  class Erm (line 20) | class Erm(BaseAlgo):
    method __init__ (line 21) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method train (line 25) | def train(self):

FILE: algorithms/erm_match.py
  class ErmMatch (line 21) | class ErmMatch(BaseAlgo):
    method __init__ (line 22) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method train (line 27) | def train(self):

FILE: algorithms/hybrid.py
  class Hybrid (line 22) | class Hybrid(BaseAlgo):
    method __init__ (line 23) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method save_model_erm_phase (line 30) | def save_model_erm_phase(self, run):
    method init_erm_phase (line 38) | def init_erm_phase(self):
    method train (line 93) | def train(self):

FILE: algorithms/irm.py
  class Irm (line 20) | class Irm(BaseAlgo):
    method __init__ (line 21) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method train (line 25) | def train(self):

FILE: algorithms/match_dg.py
  class MatchDG (line 22) | class MatchDG(BaseAlgo):
    method __init__ (line 23) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method train (line 31) | def train(self):
    method save_model_ctr_phase (line 38) | def save_model_ctr_phase(self, epoch):
    method save_model_erm_phase (line 42) | def save_model_erm_phase(self, run):
    method init_erm_phase (line 50) | def init_erm_phase(self):
    method train_ctr_phase (line 111) | def train_ctr_phase(self):
    method train_erm_phase (line 271) | def train_erm_phase(self):

FILE: algorithms/mmd.py
  class MMD (line 21) | class MMD(BaseAlgo):
    method __init__ (line 22) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method my_cdist (line 39) | def my_cdist(self, x1, x2):
    method gaussian_kernel (line 47) | def gaussian_kernel(self, x, y, gamma=[0.001, 0.01, 0.1, 1, 10, 100,
    method mmd (line 57) | def mmd(self, x, y):
    method mmd_regularization (line 76) | def mmd_regularization(self, features, d, nmb):
    method train (line 85) | def train(self):

FILE: data/adult_loader.py
  class Adult (line 18) | class Adult(BaseDataLoader):
    method __init__ (line 19) | def __init__(self, args, list_train_domains, root, transform=None, dat...
    method _get_data (line 24) | def _get_data(self):

FILE: data/chestxray_loader.py
  class ChestXRay (line 18) | class ChestXRay(BaseDataLoader):
    method __init__ (line 19) | def __init__(self, args, list_train_domains, root, transform=None, dat...
    method _get_data (line 24) | def _get_data(self):

FILE: data/chestxray_loader_aug.py
  class ChestXRayAug (line 18) | class ChestXRayAug(BaseDataLoader):
    method __init__ (line 19) | def __init__(self, args, list_domains, root, transform=None, data_case...
    method __getitem__ (line 24) | def __getitem__(self, index):
    method _get_data (line 36) | def _get_data(self):

FILE: data/chestxray_loader_match_eval.py
  class ChestXRayAugEval (line 18) | class ChestXRayAugEval(BaseDataLoader):
    method __init__ (line 19) | def __init__(self, args, list_domains, root, transform=None, data_case...
    method _get_data (line 24) | def _get_data(self):

FILE: data/data_gen_domainbed.py
  function generate_rotated_domain_data (line 20) | def generate_rotated_domain_data(imgs, labels, data_case, dataset, indic...

FILE: data/data_gen_mnist.py
  function generate_rotated_domain_data (line 21) | def generate_rotated_domain_data(imgs, labels, data_case, dataset, indic...

FILE: data/data_loader.py
  class BaseDataLoader (line 8) | class BaseDataLoader(data_utils.Dataset):
    method __init__ (line 9) | def __init__(self, args, list_domains, root, transform=None, data_case...
    method __len__ (line 28) | def __len__(self):
    method __getitem__ (line 31) | def __getitem__(self, index):
    method get_size (line 42) | def get_size(self):
    method get_item_spur (line 45) | def get_item_spur(self, index):

FILE: data/mnist_loader.py
  class MnistRotated (line 15) | class MnistRotated(BaseDataLoader):
    method __init__ (line 16) | def __init__(self, args, list_domains, mnist_subset, root, transform=N...
    method _get_data (line 24) | def _get_data(self):

FILE: data/mnist_loader_match_eval.py
  class MnistRotatedAugEval (line 15) | class MnistRotatedAugEval(BaseDataLoader):
    method __init__ (line 16) | def __init__(self, args, list_domains, mnist_subset, root, transform=N...
    method _get_data (line 24) | def _get_data(self):

FILE: data/mnist_loader_match_eval_spur.py
  class MnistRotatedAugEval (line 15) | class MnistRotatedAugEval(BaseDataLoader):
    method __init__ (line 16) | def __init__(self, args, list_domains, mnist_subset, root, transform=N...
    method _get_data (line 24) | def _get_data(self):

FILE: data/mnist_loader_spur.py
  class MnistRotated (line 15) | class MnistRotated(BaseDataLoader):
    method __init__ (line 16) | def __init__(self, args, list_domains, mnist_subset, root, transform=N...
    method _get_data (line 24) | def _get_data(self):

FILE: data/pacs_loader.py
  class PACS (line 18) | class PACS(BaseDataLoader):
    method __init__ (line 19) | def __init__(self, args, list_domains, root, transform=None, data_case...
    method _get_data (line 24) | def _get_data(self):

FILE: data/pacs_loader_aug.py
  class PACSAug (line 18) | class PACSAug(BaseDataLoader):
    method __init__ (line 19) | def __init__(self, args, list_domains, root, transform=None, data_case...
    method __getitem__ (line 24) | def __getitem__(self, index):
    method _get_data (line 36) | def _get_data(self):

FILE: data/pacs_loader_match_eval.py
  class PACSAugEval (line 18) | class PACSAugEval(BaseDataLoader):
    method __init__ (line 19) | def __init__(self, args, list_domains, root, transform=None, data_case...
    method _get_data (line 24) | def _get_data(self):

FILE: data/slab_loader.py
  class SlabData (line 22) | class SlabData(BaseDataLoader):
    method __init__ (line 23) | def __init__(self, args, list_domains, root, transform=None, data_case...
    method _get_data (line 58) | def _get_data(self, domain_size, data_dim, total_slabs, spur_probs, sl...

FILE: data/slab_loader_spur.py
  class SlabData (line 25) | class SlabData(BaseDataLoader):
    method __init__ (line 26) | def __init__(self, args, list_train_domains, root, transform=None, dat...
    method _get_data (line 58) | def _get_data(self, domain_size, data_dim, total_slabs, spur_probs, sl...

FILE: evaluation/attribute_attack.py
  class SpurCorrDataLoader (line 35) | class SpurCorrDataLoader(data_utils.Dataset):
    method __init__ (line 36) | def __init__(self, dataloader):
    method __len__ (line 47) | def __len__(self):
    method __getitem__ (line 50) | def __getitem__(self, index):
  class AttributeAttack (line 60) | class AttributeAttack(BaseEval):
    method __init__ (line 62) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_spur_logits (line 67) | def get_spur_logits(self):
    method get_logits (line 132) | def get_logits(self):
    method get_metric_eval (line 189) | def get_metric_eval(self):

FILE: evaluation/base_eval.py
  class BaseEval (line 19) | class BaseEval():
    method __init__ (line 20) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_model (line 60) | def get_model(self, run_matchdg_erm=0):
    method load_model (line 124) | def load_model(self, run_matchdg_erm):
    method forward (line 152) | def forward(self, x_e):
    method get_logits (line 163) | def get_logits(self):
    method get_metric_eval (line 203) | def get_metric_eval(self):

FILE: evaluation/feat_eval.py
  class FeatEval (line 22) | class FeatEval(BaseEval):
    method __init__ (line 24) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_match_function_batch (line 27) | def get_match_function_batch(self, batch_idx):
    method get_metric_eval (line 49) | def get_metric_eval(self):

FILE: evaluation/logit_hist.py
  class LogitHist (line 39) | class LogitHist(BaseEval):
    method __init__ (line 41) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_loss (line 45) | def get_loss(self):
    method get_metric_eval (line 83) | def get_metric_eval(self):

FILE: evaluation/match_eval.py
  class MatchEval (line 20) | class MatchEval(BaseEval):
    method __init__ (line 22) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_metric_eval (line 25) | def get_metric_eval(self):

FILE: evaluation/per_domain_acc.py
  class PerDomainAcc (line 20) | class PerDomainAcc(BaseEval):
    method __init__ (line 22) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_metric_eval (line 25) | def get_metric_eval(self):

FILE: evaluation/privacy_attack.py
  class PrivacyAttack (line 37) | class PrivacyAttack(BaseEval):
    method __init__ (line 39) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_metric_eval (line 44) | def get_metric_eval(self):

FILE: evaluation/privacy_entropy.py
  class PrivacyEntropy (line 37) | class PrivacyEntropy(BaseEval):
    method __init__ (line 39) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_label_logits (line 44) | def get_label_logits(self):
    method create_attack_data (line 93) | def create_attack_data(self, train_data, test_data, sample_size, case=...
    method eval_entropy_attack (line 117) | def eval_entropy_attack(self, data, threshold_data, scale=1.0, case='t...
    method get_metric_eval (line 182) | def get_metric_eval(self):

FILE: evaluation/privacy_loss_attack.py
  class PrivacyLossAttack (line 37) | class PrivacyLossAttack(BaseEval):
    method __init__ (line 39) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_ce_loss (line 44) | def get_ce_loss(self):
    method create_attack_data (line 110) | def create_attack_data(self, train_data, test_data, sample_size, case=...
    method create_attack_data_true_obj (line 135) | def create_attack_data_true_obj(self, train_data, test_data, sample_si...
    method eval_entropy_attack (line 195) | def eval_entropy_attack(self, data, threshold_data, scale=1.0, case='t...
    method get_metric_eval (line 247) | def get_metric_eval(self):

FILE: evaluation/slab_feat_eval.py
  function sim_matrix (line 22) | def sim_matrix(a, b, eps=1e-8):
  class SlabFeatEval (line 32) | class SlabFeatEval(BaseEval):
    method __init__ (line 34) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_metric_eval (line 37) | def get_metric_eval(self):

FILE: evaluation/t_sne.py
  class TSNE (line 20) | class TSNE(BaseEval):
    method __init__ (line 22) | def __init__(self, args, train_dataset, val_dataset, test_dataset, bas...
    method get_metric_eval (line 25) | def get_metric_eval(self):

FILE: misc_scripts/logit_plot_slab.py
  function get_logits (line 33) | def get_logits(model, loader, device, label=1):

FILE: models/alexnet.py
  class Id (line 26) | class Id(nn.Module):
    method __init__ (line 27) | def __init__(self):
    method forward (line 30) | def forward(self, x):
  class AlexNet (line 34) | class AlexNet(nn.Module):
    method __init__ (line 35) | def __init__(self, num_classes=1000, dropout=True):
    method initialize_params (line 69) | def initialize_params(self):
    method forward (line 76) | def forward(self, x):
  function alexnet (line 83) | def alexnet(model_name, classes, fc_layer, num_ch, pre_trained, os_env):

FILE: models/densenet.py
  class Identity (line 9) | class Identity(nn.Module):
    method __init__ (line 10) | def __init__(self,n_inputs):
    method forward (line 14) | def forward(self, x):
  function get_densenet (line 18) | def get_densenet(model_name, classes, fc_layer, num_ch, pre_trained, os_...

FILE: models/domain_bed_mnist.py
  class DomainBed (line 12) | class DomainBed(torch.nn.Module):
    method __init__ (line 14) | def __init__(self, num_ch, fc_layer):
    method forward (line 34) | def forward(self, x):

FILE: models/fc.py
  class FC (line 11) | class FC(torch.nn.Module):
    method __init__ (line 13) | def __init__(self, classes, fc_layer):
    method forward (line 31) | def forward(self, x):

FILE: models/lenet.py
  class LeNet5 (line 11) | class LeNet5(torch.nn.Module):
    method __init__ (line 13) | def __init__(self):
    method forward (line 41) | def forward(self, x):

FILE: models/resnet.py
  class GroupNorm (line 8) | class GroupNorm(torch.nn.GroupNorm):
    method __init__ (line 9) | def __init__(self, num_channels, num_groups=32, **kwargs):
  class Identity (line 13) | class Identity(nn.Module):
    method __init__ (line 14) | def __init__(self,n_inputs):
    method forward (line 18) | def forward(self, x):
  function get_resnet (line 22) | def get_resnet(model_name, classes, fc_layer, num_ch, pre_trained, dp_no...

FILE: models/slab.py
  class SlabClf (line 10) | class SlabClf(nn.Module):
    method __init__ (line 11) | def __init__(self, inp_shape, out_shape, fc_layer):
    method forward (line 35) | def forward(self, x):

FILE: reproduce_scripts/cxray_plot.py
  function get_base_dir (line 7) | def get_base_dir(test_domain, dataset, metric):

FILE: reproduce_scripts/mnist_plot.py
  function get_base_dir (line 7) | def get_base_dir(train_case, test_case, dataset, metric):

FILE: test_slab.py
  function get_logits (line 34) | def get_logits(model, loader, device, label=1):

FILE: utils/attribute_attack.py
  function to_onehot (line 24) | def to_onehot(inp):
  function my_attack_model (line 30) | def my_attack_model(features, labels, mode, params):
  function mia (line 85) | def mia(X_att_train, y_att_train, X_att_test, y_att_test, my_feature_col...

FILE: utils/bnlearn_data.py
  function to_onehot (line 10) | def to_onehot(inp):
  function split_list (line 15) | def split_list(a_list):
  function load_dnn_prob (line 20) | def load_dnn_prob(train_prob, test_prob, features):
  function load_bnet_prob (line 104) | def load_bnet_prob(dataset_name, num_examples, output_name, dist):
  function load_data (line 151) | def load_data(dataset_name, num_examples, output_name, noise):
  function train_input_fn (line 180) | def train_input_fn(features, labels, batch_size):
  function eval_input_fn (line 194) | def eval_input_fn(features, labels, batch_size):

FILE: utils/helper.py
  function slab_batch_process (line 19) | def slab_batch_process(x, y, d, o):
  function t_sne_plot (line 34) | def t_sne_plot(X):
  function classifier (line 39) | def classifier(x_e, phi, w):
  function erm_loss (line 42) | def erm_loss(temp_logits, target_label):
  function compute_irm_penalty (line 46) | def compute_irm_penalty( logits, target_label, cuda):
  function cosine_similarity (line 55) | def cosine_similarity( x1, x2 ):
  function l1_dist (line 59) | def l1_dist(x1, x2):
  function l2_dist (line 78) | def l2_dist(x1, x2):
  function embedding_dist (line 97) | def embedding_dist(x1, x2, pos_metric, tau=0.05, xent=False):
  function get_dataloader (line 167) | def get_dataloader(args, run, domains, data_case, eval_case, kwargs):

FILE: utils/match_function.py
  function init_data_match_dict (line 22) | def init_data_match_dict(args, keys, vals, variation):
  function get_matched_pairs (line 43) | def get_matched_pairs(args, cuda, train_dataset, domain_size, total_doma...

FILE: utils/privacy_attack.py
  function to_onehot (line 24) | def to_onehot(inp):
  function my_attack_model (line 30) | def my_attack_model(features, labels, mode, params):
  function mia (line 85) | def mia(X_att_train, y_att_train, X_att_test, y_att_test, my_feature_col...

FILE: utils/scripts/data_utils.py
  function msd (line 33) | def msd(x, r=3):
  function _get_dataloaders (line 36) | def _get_dataloaders(trd, ted,  bs, pm=True, shuffle=True):
  function get_cifar10_models (line 41) | def get_cifar10_models(device=None, pretrained=True):
  function plot_decision_boundary (line 61) | def plot_decision_boundary(dl, model, c1, c2, ax=None, print_info=True):
  function get_binary_datasets (line 84) | def get_binary_datasets(X, Y, y1, y2, image_width=28, use_cnn=False):
  function get_binary_loader (line 95) | def get_binary_loader(dl, y1, y2):
  function get_mnist (line 100) | def get_mnist(fpath=DOWNLOAD_DIR, flatten=False, binarize=False, normali...
  function get_mnist_dl (line 125) | def get_mnist_dl(fpath=DOWNLOAD_DIR, to_np=False, bs=128, pm=False, shuf...
  function get_cifar (line 132) | def get_cifar(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=False,...
  function get_cifar_dl (line 180) | def get_cifar_dl(fpath=DOWNLOAD_DIR, use_cifar10=False, bs=128, shuffle=...
  function get_cifar_np (line 188) | def get_cifar_np(fpath=DOWNLOAD_DIR, use_cifar10=False, flatten_data=Fal...

FILE: utils/scripts/ensemble.py
  class Ensemble (line 17) | class Ensemble(nn.Module):
    method _get_dummy_classifier (line 19) | def _get_dummy_classifier(self):
    method __init__ (line 24) | def __init__(self, models, num_classes, use_softmax=False):
    method _forward (line 37) | def _forward(self, x):
    method forward (line 40) | def forward(self, x):
    method get_output_loader (line 44) | def get_output_loader(self, dl, device=gu.get_device(None), bs=None):
    method fit_classifier (line 56) | def fit_classifier(self, tr_dl, te_dl, lr=0.05, adam=False, wd=5e-5, d...
  class EnsembleLinear (line 74) | class EnsembleLinear(Ensemble):
    method _get_classifier (line 76) | def _get_classifier(self):
    method __init__ (line 85) | def __init__(self, models, num_classes=2, use_softmax=False, use_bias=...
    method _forward (line 90) | def _forward(self, x):
  class EnsembleMLP (line 97) | class EnsembleMLP(Ensemble):
    method _get_classifier (line 99) | def _get_classifier(self):
    method __init__ (line 104) | def __init__(self, models, num_classes=2, use_softmax=False, hdim=None...
    method _forward (line 110) | def _forward(self, x):
  class EnsembleAverage (line 117) | class EnsembleAverage(Ensemble):
    method __init__ (line 119) | def __init__(self, models, num_classes=2, use_softmax=False):
    method _forward (line 123) | def _forward(self, x):
    method fit_classifier (line 129) | def fit_classifier(self, *args, **kwargs):

FILE: utils/scripts/gendata.py
  function _prep_data (line 10) | def _prep_data(X, Y, N_tr, N_te, bs, nw, pm, w, orth_matrix=None):
  function _get_random_data (line 29) | def _get_random_data(N, dim, scale):
  function generate_linsep_data_v2 (line 35) | def generate_linsep_data_v2(N_tr, dim, eff_margin, width=10., bs=256, sc...
  function sample_from_unif_union_of_unifs (line 54) | def sample_from_unif_union_of_unifs(unifs, size):
  function generate_ub_linslab_data_diffmargin_v2 (line 63) | def generate_ub_linslab_data_diffmargin_v2(N_tr, dim, eff_lin_margins, e...
  function generate_ub_linslab_data_v2 (line 270) | def generate_ub_linslab_data_v2(N_tr, dim, eff_lin_margin, eff_slab_marg...
  function get_lms_data (line 288) | def get_lms_data(**kw):

FILE: utils/scripts/gpu_utils.py
  function get_gpu_info (line 7) | def get_gpu_info(print_info=True, get_specs=False):
  function get_device (line 27) | def get_device(device_id=None): # None -> cpu
  function get_gpu_name (line 32) | def get_gpu_name():
  function get_cuda_version (line 41) | def get_cuda_version():
  function get_cudnn_version (line 59) | def get_cudnn_version():

FILE: utils/scripts/lms_utils.py
  function parse_data (line 19) | def parse_data(exps=None, root='/', **funcs_kw):
  function parse_exp_stats (line 58) | def parse_exp_stats(data):
  function parse_exp_model (line 83) | def parse_exp_model(data):
  function parse_exp_depth1_model (line 92) | def parse_exp_depth1_model(data):
  function parse_exp_linear_model (line 107) | def parse_exp_linear_model(data):
  function parse_exp_data (line 116) | def parse_exp_data(data, load_X=False):
  function get_yhat (line 134) | def get_yhat(model, data):
  function get_acc (line 138) | def get_acc(y,yhat):
  function parse_and_get_df (line 142) | def parse_and_get_df(root, prefix, files=None, device_id=None, only_load...
  function viz (line 172) | def viz(d, c1, c2, k=80_000, info=True, plot_dm=True, plot_data=True, us...
  function visualize_boundary (line 209) | def visualize_boundary(model, data, c1, c2, dim, ax=None, is_binary=Fals...
  function get_randomized_loader (line 213) | def get_randomized_loader(dl, W, coordinates):
  function get_feature_deps (line 238) | def get_feature_deps(dl, model, W=None, dep_type='random', only_linear=F...
  function get_subset_feature_deps (line 305) | def get_subset_feature_deps(dl, model, coords_set, comb_size, W=None, de...

FILE: utils/scripts/mnistcifar_utils.py
  function get_binary_mnist (line 12) | def get_binary_mnist(y1=0, y2=1, apply_padding=True, repeat_channels=True):
  function get_binary_cifar (line 27) | def get_binary_cifar(y1=3, y2=5, c={0,1,2,3,4}, use_cifar10=True):
  function combine_datasets (line 37) | def combine_datasets(Xm, Ym, Xc, Yc, randomize_order=False, randomize_fi...
  function get_mnist_cifar (line 79) | def get_mnist_cifar(mnist_classes=(0,1), cifar_classes=None, c={0,1,2,3,4},
  function get_mnist_cifar_dl (line 92) | def get_mnist_cifar_dl(mnist_classes=(0,1), cifar_classes=None, c={0,1,2...

FILE: utils/scripts/ptb_utils.py
  function get_yhat (line 23) | def get_yhat(model, data): return torch.argmax(model(data), 1)
  function get_acc (line 24) | def get_acc(y,yhat): return (y==yhat).sum().item()/float(len(y))
  class PGD_Attack (line 26) | class PGD_Attack(object):
    method __init__ (line 28) | def __init__(self, eps, lr, num_iter, loss_type, rand_eps=1e-3,
    method evaluate_attack (line 45) | def evaluate_attack(self, dl, model):
    method perturb (line 79) | def perturb(self, xb, yb, model, cpu=False):
    method _perturb_once (line 105) | def _perturb_once(self, xb, yb, model, track_scores=False, stop_const=...
    method _init_delta (line 143) | def _init_delta(self, xb, yb):
    method _clamp_input (line 150) | def _clamp_input(self, xb, yb):
    method _get_loss (line 157) | def _get_loss(self, xb, yb, model, get_scores=False):
  class L2_PGD_Attack (line 189) | class L2_PGD_Attack(PGD_Attack):
    method get_norms (line 193) | def get_norms(self, X):
    method _update_delta (line 197) | def _update_delta(self, xb, yb, update_mask=None):
    method _init_delta (line 218) | def _init_delta(self, xb, yb):
  class Linf_PGD_Attack (line 228) | class Linf_PGD_Attack(PGD_Attack):
    method _update_delta (line 230) | def _update_delta(self, xb, yb, **kw):

FILE: utils/scripts/synth_models.py
  function kaiming_init (line 12) | def kaiming_init(m):
  class SequenceClassifier (line 17) | class SequenceClassifier(nn.Module):
    method __init__ (line 19) | def __init__(self, seq_model, idim, hdim, hl, input_size, num_classes=...
    method forward (line 35) | def forward(self, x):
  class GRUClassifier (line 46) | class GRUClassifier(SequenceClassifier):
    method __init__ (line 48) | def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_...
  class LSTMClassifier (line 51) | class LSTMClassifier(SequenceClassifier):
    method __init__ (line 53) | def __init__(self, idim, hdim, hl, input_size, num_classes=2, many_to_...
  class CNNClassifier (line 56) | class CNNClassifier(nn.Module):
    method __init__ (line 58) | def __init__(self, out_channels, hl, kernel_size, idim, num_classes=2,...
    method forward (line 92) | def forward(self, x):
  class CNN2DClassifier (line 104) | class CNN2DClassifier(nn.Module):
    method __init__ (line 106) | def __init__(self, num_filters, filter_size, num_layers, input_shape, ...
    method forward (line 134) | def forward(self, x):
  function get_linear (line 139) | def get_linear(input_dim, num_classes):
  function get_fcn (line 142) | def get_fcn(idim, hdim, odim, hl=1, init=False, activation=nn.ReLU, use_...

FILE: utils/scripts/utils.py
  function get_orthonormal_matrix (line 24) | def get_orthonormal_matrix(n):
  function get_dataloader (line 32) | def get_dataloader(X, Y, bs, **kw):
  function split_dataloader (line 35) | def split_dataloader(dl, frac=0.5):
  function _to_dl (line 47) | def _to_dl(X, Y, bs, shuffle=True):
  function extract_tensors_from_loader (line 50) | def extract_tensors_from_loader(dl, repeat=1, transform_fn=None):
  function extract_numpy_from_loader (line 62) | def extract_numpy_from_loader(dl, repeat=1, transform_fn=None):
  function _to_tensor_dl (line 66) | def _to_tensor_dl(dl, repeat=1, bs=None):
  function flatten_loader (line 71) | def flatten_loader(dl, bs=None):
  function merge_loaders (line 76) | def merge_loaders(dla, dlb):
  function transform_loader (line 82) | def transform_loader(dl, func, shuffle=True):
  function visualize_tensors (line 87) | def visualize_tensors(P, size=8, normalize=True, scale_each=False, permu...
  function visualize_loader (line 101) | def visualize_loader(dl, ax=None, size=8, normalize=True, scale_each=Fal...
  function visualize_loader_by_class (line 107) | def visualize_loader_by_class(dl, ax=None, size=8, normalize=True, scale...
  function visualize_perturbations (line 121) | def visualize_perturbations(P, transform_fn=None):
  function get_logits_given_tensor (line 132) | def get_logits_given_tensor(X, model, device=None, bs=250, softmax=False):
  function get_predictions_given_tensor (line 150) | def get_predictions_given_tensor(X, model, device=None, bs=250):
  function get_accuracy_given_tensor (line 154) | def get_accuracy_given_tensor(X, Y, model, device=None, bs=250):
  function compute_accuracy (line 160) | def compute_accuracy(X, Y, model):
  function compute_loss_and_accuracy_from_dl (line 167) | def compute_loss_and_accuracy_from_dl(dl, model, loss_fn, sample_pct=1.0...
  function count_parameters (line 214) | def count_parameters(model):
  function get_logits (line 217) | def get_logits(model, loader, device):
  function get_scores (line 229) | def get_scores(model, loader, device):
  function get_multiclass_logit_score (line 235) | def get_multiclass_logit_score(L, Y):
  function get_binary_auc (line 249) | def get_binary_auc(model, loader, device):
  function get_multiclass_auc (line 253) | def get_multiclass_auc(model, loader, device, one_vs_rest=True):
  function clip_gradient (line 260) | def clip_gradient(model, clip_value):
  function print_model_gradients (line 264) | def print_model_gradients(model, print_bias=True):
  function hinge_loss (line 271) | def hinge_loss(out, y):
  function pgd_adv_fit_model (line 275) | def pgd_adv_fit_model(model, opt, tr_dl, te_dl, attack, eval_attack=None...
  function fit_model (line 421) | def fit_model(model, loss, opt, train_dl, valid_dl, sch=None, epsilon=1e...
  function save_pickle (line 696) | def save_pickle(fname, d, mode='w'):
  function load_pickle (line 700) | def load_pickle(fname, mode='r'):
  function update_ax (line 704) | def update_ax(ax, title=None, xlabel=None, ylabel=None, legend_loc='best...

FILE: utils/slab_data.py
  function get_data (line 28) | def get_data(num_samples, spur_corr, slab_noise, total_slabs, data_case,...
Condensed preview: 125 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".github/workflows/python-package.yml",
    "chars": 1276,
    "preview": "# This workflow will install Python dependencies, run tests and lint with a variety of Python versions\n# For more inform"
  },
  {
    "path": ".gitignore",
    "chars": 1990,
    "preview": "#Data diri\ndata/datasets/\n\n#python environment\namt_envs/\nmatchdg-env/\n\n#Results dir\nresults/\n\n#extra-files\n#*.sh\n\n#phill"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 444,
    "preview": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://op"
  },
  {
    "path": "LICENSE",
    "chars": 1141,
    "preview": "    MIT License\n\n    Copyright (c) Microsoft Corporation.\n\n    Permission is hereby granted, free of charge, to any pers"
  },
  {
    "path": "README.rst",
    "chars": 4554,
    "preview": "Toolkit for Building Robust ML models that generalize to unseen domains (RobustDG)\n====================================="
  },
  {
    "path": "SECURITY.md",
    "chars": 2780,
    "preview": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products an"
  },
  {
    "path": "algorithms/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "algorithms/algo.py",
    "chars": 12780,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport os\nfrom more_itertools import"
  },
  {
    "path": "algorithms/csd.py",
    "chars": 6967,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\n\nimport torch\nfrom torch.autograd im"
  },
  {
    "path": "algorithms/dann.py",
    "chars": 5759,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport time\n\nimport torch\nfrom torch"
  },
  {
    "path": "algorithms/erm.py",
    "chars": 3503,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\n\nimport torch\nfrom torch.autograd im"
  },
  {
    "path": "algorithms/erm_match.py",
    "chars": 8169,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport time\n\nimport torch\nfrom torch"
  },
  {
    "path": "algorithms/hybrid.py",
    "chars": 14539,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport os\n\nimport torch\nfrom torch.a"
  },
  {
    "path": "algorithms/irm.py",
    "chars": 5631,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\n\nimport torch\nfrom torch.autograd im"
  },
  {
    "path": "algorithms/match_dg.py",
    "chars": 20419,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport os\n\nimport torch\nfrom torch.a"
  },
  {
    "path": "algorithms/mmd.py",
    "chars": 6425,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport time\n\nimport torch\nfrom torch"
  },
  {
    "path": "azure_scripts/chest.yaml",
    "chars": 7829,
    "preview": "description: ChestXray Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). Every"
  },
  {
    "path": "azure_scripts/chest_ctr.yaml",
    "chars": 1834,
    "preview": "description: ChestXray Dataset Constrastive Learning\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to "
  },
  {
    "path": "azure_scripts/chest_ctr_spur.yaml",
    "chars": 1856,
    "preview": "description: ChestXray Dataset Constrastive Learning\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to "
  },
  {
    "path": "azure_scripts/chest_matchdg.yaml",
    "chars": 2047,
    "preview": "description: Hyperparam sweep on ChestXray Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (m"
  },
  {
    "path": "azure_scripts/chest_matchdg_spur.yaml",
    "chars": 2080,
    "preview": "description: Hyperparam sweep on ChestXray Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (m"
  },
  {
    "path": "azure_scripts/chest_spur.yaml",
    "chars": 7869,
    "preview": "description: ChestXray Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). Every"
  },
  {
    "path": "azure_scripts/fmnist.yaml",
    "chars": 2656,
    "preview": "description: Fashion MNIST Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). E"
  },
  {
    "path": "azure_scripts/irm_fashion.yaml",
    "chars": 1272,
    "preview": "description: Hyperparam sweep on IRM MNIST\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, "
  },
  {
    "path": "azure_scripts/irm_mnist.yaml",
    "chars": 1269,
    "preview": "description: Hyperparam sweep on IRM MNIST\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, "
  },
  {
    "path": "azure_scripts/mnist.yaml",
    "chars": 2743,
    "preview": "description: Fashion MNIST Dataset\n\ntarget:\n  service: amlk8s\n  # which virtual cluster you belong to (msrlabs, etc.). E"
  },
  {
    "path": "azure_scripts/mnist_ctr.yaml",
    "chars": 2150,
    "preview": "description: MNIST Dataset Constrastive Learning\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msr"
  },
  {
    "path": "azure_scripts/mnist_ctr_spur.yaml",
    "chars": 1242,
    "preview": "description: MNIST Dataset Constrastive Learning\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msr"
  },
  {
    "path": "azure_scripts/mnist_spur.yaml",
    "chars": 2760,
    "preview": "description: Fashion MNIST Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). E"
  },
  {
    "path": "azure_scripts/pacs.yaml",
    "chars": 3220,
    "preview": "description: PACS Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). Everyone h"
  },
  {
    "path": "azure_scripts/pacs_art_painting.yaml",
    "chars": 2586,
    "preview": "description: Hyperparam sweep on PACS\n\ntarget:\n  service: amlk8s\n  # which virtual cluster you belong to (msrlabs, etc.)"
  },
  {
    "path": "azure_scripts/pacs_cartoon.yaml",
    "chars": 2586,
    "preview": "description: Hyperparam sweep on PACS\n\ntarget:\n  service: amlk8s\n  # which virtual cluster you belong to (msrlabs, etc.)"
  },
  {
    "path": "azure_scripts/pacs_ctr.yaml",
    "chars": 3746,
    "preview": "description: PACS Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). Everyone h"
  },
  {
    "path": "azure_scripts/pacs_erm.yaml",
    "chars": 2181,
    "preview": "description: PACS ERM Dataset\n\ntarget:\n  service: amlk8s\n  # which virtual cluster you belong to (msrlabs, etc.). Everyo"
  },
  {
    "path": "azure_scripts/pacs_hybrid.yaml",
    "chars": 5846,
    "preview": "description: PACS MatchDG Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). Ev"
  },
  {
    "path": "azure_scripts/pacs_matchdg.yaml",
    "chars": 4816,
    "preview": "description: PACS MatchDG Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). Ev"
  },
  {
    "path": "azure_scripts/pacs_perfect.yaml",
    "chars": 4756,
    "preview": "description: PACS MatchDG Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc.). Ev"
  },
  {
    "path": "azure_scripts/pacs_photo.yaml",
    "chars": 2606,
    "preview": "description: Hyperparam sweep on PACS\n\ntarget:\n  service: amlk8s\n  # which virtual cluster you belong to (msrlabs, etc.)"
  },
  {
    "path": "azure_scripts/pacs_random.yaml",
    "chars": 3429,
    "preview": "description: PACS Random Match Dataset\n\ntarget:\n  service: philly\n  # which virtual cluster you belong to (msrlabs, etc."
  },
  {
    "path": "azure_scripts/pacs_sketch.yaml",
    "chars": 2586,
    "preview": "description: Hyperparam sweep on PACS\n\ntarget:\n  service: amlk8s\n  # which virtual cluster you belong to (msrlabs, etc.)"
  },
  {
    "path": "azure_scripts/setup_data_mnist.yaml",
    "chars": 1138,
    "preview": "description: MNIST Dat Setup\n\ntarget:\n  service: amlk8s\n  # which virtual cluster you belong to (msrlabs, etc.). Everyon"
  },
  {
    "path": "chestxray_download.txt",
    "chars": 1459,
    "preview": "NIH Dataset:\n\ncurl -o nih.zip \"https://storage.googleapis.com/kaggle-data-sets/5839%2F18613%2Fbundle%2Farchive.zip?Googl"
  },
  {
    "path": "data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "data/adult_loader.py",
    "chars": 4527,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\nimpor"
  },
  {
    "path": "data/chestxray_loader.py",
    "chars": 4521,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\nimpor"
  },
  {
    "path": "data/chestxray_loader_aug.py",
    "chars": 4939,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\nimpor"
  },
  {
    "path": "data/chestxray_loader_match_eval.py",
    "chars": 5503,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\nimpor"
  },
  {
    "path": "data/data_gen_domainbed.py",
    "chars": 7297,
    "preview": "#Common imports\nimport numpy as np\nimport sys\nimport os\nimport random\nimport copy\nimport os \n\n#Sklearn\nfrom scipy.stats "
  },
  {
    "path": "data/data_gen_mnist.py",
    "chars": 11306,
    "preview": "#Common imports\nimport numpy as np\nimport sys\nimport os\nimport argparse\nimport random\nimport copy\nimport os \n\n#Sklearn\nf"
  },
  {
    "path": "data/data_loader.py",
    "chars": 1596,
    "preview": "import os\nimport copy\nimport numpy as np\nimport torch\nimport torch.utils.data as data_utils\nfrom torchvision import data"
  },
  {
    "path": "data/mnist_loader.py",
    "chars": 4210,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\n\n#Pytorch\nimport torch\nimport torch.utils.data as"
  },
  {
    "path": "data/mnist_loader_match_eval.py",
    "chars": 5224,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\n\n#Pytorch\nimport torch\nimport torch.utils.data as"
  },
  {
    "path": "data/mnist_loader_match_eval_spur.py",
    "chars": 5570,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\n\n#Pytorch\nimport torch\nimport torch.utils.data as"
  },
  {
    "path": "data/mnist_loader_spur.py",
    "chars": 4538,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\n\n#Pytorch\nimport torch\nimport torch.utils.data as"
  },
  {
    "path": "data/pacs_loader.py",
    "chars": 5804,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\nimpor"
  },
  {
    "path": "data/pacs_loader_aug.py",
    "chars": 6931,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\nimpor"
  },
  {
    "path": "data/pacs_loader_match_eval.py",
    "chars": 7062,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\nimpor"
  },
  {
    "path": "data/slab_loader.py",
    "chars": 5361,
    "preview": "    #Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Pytorch\ni"
  },
  {
    "path": "data/slab_loader_spur.py",
    "chars": 5893,
    "preview": "#Common imports\nimport os\nimport random\nimport copy\nimport numpy as np\nimport h5py\nfrom PIL import Image\n\n#Sklearn\nfrom "
  },
  {
    "path": "data_gen_syn.py",
    "chars": 4142,
    "preview": "import sys\nimport random\nimport os, copy, pickle, time\nimport argparse\nimport itertools\nfrom collections import defaultd"
  },
  {
    "path": "docs/_config.yml",
    "chars": 27,
    "preview": "theme: jekyll-theme-minimal"
  },
  {
    "path": "docs/notebooks/ChestXRay_Translate.ipynb",
    "chars": 21884,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "docs/notebooks/Preprocess.ipynb",
    "chars": 4207,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Download Data\"\n   ]\n  },\n  {\n   "
  },
  {
    "path": "docs/notebooks/Spur_Rotated_MNIST.ipynb",
    "chars": 6294,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "docs/notebooks/beta/HParam_Plots.ipynb",
    "chars": 31185,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "docs/notebooks/beta/adult_dataset.ipynb",
    "chars": 39300,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "docs/notebooks/beta/mnist_results.ipynb",
    "chars": 2334585,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  },
  {
    "path": "docs/notebooks/helper_plots.ipynb",
    "chars": 3073,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Visualizing Rotated MINST sample"
  },
  {
    "path": "docs/notebooks/privacy_plots.ipynb",
    "chars": 210615,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  },
  {
    "path": "docs/notebooks/reproduce_results.ipynb",
    "chars": 22266,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Reproducing results\\n\",\n    \"\\n\","
  },
  {
    "path": "docs/notebooks/robustdg_getting_started.ipynb",
    "chars": 9150,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Getting started with RobustDG: Ge"
  },
  {
    "path": "evaluation/attribute_attack.py",
    "chars": 10525,
    "preview": "#General Imports\nimport sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json"
  },
  {
    "path": "evaluation/base_eval.py",
    "chars": 9974,
    "preview": "import sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json\nimport pickle\n\ni"
  },
  {
    "path": "evaluation/feat_eval.py",
    "chars": 6212,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nfrom more_itertools import chunked\n\n"
  },
  {
    "path": "evaluation/logit_hist.py",
    "chars": 4072,
    "preview": "#General Imports\nimport sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json"
  },
  {
    "path": "evaluation/match_eval.py",
    "chars": 6042,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\n\nimport torch\nfrom torch.autograd im"
  },
  {
    "path": "evaluation/per_domain_acc.py",
    "chars": 3317,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\n\nimport torch\nfrom torch.autograd im"
  },
  {
    "path": "evaluation/privacy_attack.py",
    "chars": 4924,
    "preview": "#General Imports\nimport sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json"
  },
  {
    "path": "evaluation/privacy_entropy.py",
    "chars": 9502,
    "preview": "#General Imports\nimport sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json"
  },
  {
    "path": "evaluation/privacy_loss_attack.py",
    "chars": 13017,
    "preview": "#General Imports\nimport sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json"
  },
  {
    "path": "evaluation/slab_feat_eval.py",
    "chars": 4570,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\n\nimport torch\nfrom torch.autograd im"
  },
  {
    "path": "evaluation/t_sne.py",
    "chars": 3746,
    "preview": "import sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\n\nimport torch\nfrom torch.autograd im"
  },
  {
    "path": "misc_scripts/adult.txt",
    "chars": 363,
    "preview": "python3 train.py --dataset adult --model fc --out_classes 2 --train_domains male female --test_domains male female --pen"
  },
  {
    "path": "misc_scripts/logit_plot_slab.py",
    "chars": 13258,
    "preview": "#Common imports\nimport os\nimport sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport pic"
  },
  {
    "path": "models/alexnet.py",
    "chars": 3512,
    "preview": "#Module taken from the repository G2DM repository for AlexNet architecture specific to PACS: https://github.com/belaalb/"
  },
  {
    "path": "models/densenet.py",
    "chars": 1101,
    "preview": "from torch import nn\nfrom torch.utils import model_zoo\nimport torchvision\nfrom torchvision.models.resnet import BasicBlo"
  },
  {
    "path": "models/domain_bed_mnist.py",
    "chars": 1541,
    "preview": "import torch\nimport torch.utils.data\nfrom torch import nn, optim\nfrom torch.nn import functional as F\nfrom torchvision i"
  },
  {
    "path": "models/fc.py",
    "chars": 1073,
    "preview": "import torch\nimport torch.utils.data\nfrom torch import nn, optim\nfrom torch.nn import functional as F\nfrom torchvision i"
  },
  {
    "path": "models/lenet.py",
    "chars": 2002,
    "preview": "import torch\nimport torch.utils.data\nfrom torch import nn, optim\nfrom torch.nn import functional as F\nfrom torchvision i"
  },
  {
    "path": "models/resnet.py",
    "chars": 2514,
    "preview": "import torch\nfrom torch import nn\nfrom torch.utils import model_zoo\nimport torchvision\nfrom torchvision.models.resnet im"
  },
  {
    "path": "models/slab.py",
    "chars": 1241,
    "preview": "import torch\nimport torch.utils.data\nfrom torch import nn, optim\nfrom torch.nn import functional as F\nfrom torchvision i"
  },
  {
    "path": "reproduce_scripts/cxray_plot.py",
    "chars": 5300,
    "preview": "import os\nimport sys\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\n\ndef get_base_dir(test_domain,"
  },
  {
    "path": "reproduce_scripts/cxray_run.py",
    "chars": 6244,
    "preview": "import os\nimport sys\nimport argparse\n\n# Input Parsing\nparser = argparse.ArgumentParser()\nparser.add_argument('--metric',"
  },
  {
    "path": "reproduce_scripts/mnist_mdg_ctr_run.py",
    "chars": 863,
    "preview": "import os\nimport argparse\n\n# Input Parsing\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str,"
  },
  {
    "path": "reproduce_scripts/mnist_plot.py",
    "chars": 6414,
    "preview": "import os\nimport sys\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport numpy as np\n\ndef get_base_dir(train_case, "
  },
  {
    "path": "reproduce_scripts/mnist_run.py",
    "chars": 8461,
    "preview": "import os\nimport sys\nimport argparse\n\n# Input Parsing\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset'"
  },
  {
    "path": "reproduce_scripts/pacs_run.py",
    "chars": 8492,
    "preview": "import os\nimport sys\n\n# Input Parsing\nparser = argparse.ArgumentParser()\nparser.add_argument('--method', type=str, defau"
  },
  {
    "path": "reproduce_scripts/reproduce_rmnist_domainbed.py",
    "chars": 1978,
    "preview": "import os\nimport argparse\n\n# Input Parsing\nparser = argparse.ArgumentParser()\nparser.add_argument('--methods', nargs='+'"
  },
  {
    "path": "reproduce_scripts/reproduce_rmnist_lenet.py",
    "chars": 2294,
    "preview": "import os\nimport argparse\n\n# Input Parsing\nparser = argparse.ArgumentParser()\nparser.add_argument('--methods', nargs='+'"
  },
  {
    "path": "reproduce_scripts/reproduce_slab.py",
    "chars": 2150,
    "preview": "import os\nimport sys\n\n'''\nargv1: Allowed Values (train, evaluate)\n'''\n\ncase= sys.argv[1]\nmethods=['erm', 'mmd', 'coral',"
  },
  {
    "path": "reproduce_scripts/slab-plot.py",
    "chars": 4052,
    "preview": "import matplotlib\nimport matplotlib.pyplot as plt\nimport sys\nimport os\nimport numpy as np\n\nslab_noise= float(sys.argv[1]"
  },
  {
    "path": "reproduce_scripts/slab-run.py",
    "chars": 5116,
    "preview": "import os\nimport sys\nimport argparse\n\n# Input Parsing\nparser = argparse.ArgumentParser()\nparser.add_argument('--case', t"
  },
  {
    "path": "reproduce_scripts/slab-tune.py",
    "chars": 3847,
    "preview": "import os\nimport sys\n\n'''\nargv1: Method for which HParams need to be tuned (erm, rand, perf, mmd, coral, c-mmd, c-coral,"
  },
  {
    "path": "requirements.txt",
    "chars": 102,
    "preview": "numpy\npandas\nh5py\nscikit-learn\ntorch\ntorchvision\ntensorflow-gpu==1.15.2\nmia\nadvertorch\ntorchxrayvision"
  },
  {
    "path": "requirements_new.txt",
    "chars": 1293,
    "preview": "absl-py==0.12.0\nadvertorch==0.2.3\nastor==0.8.1\nbackcall==0.2.0\ncached-property==1.5.2\ncertifi==2020.12.5\nchardet==4.0.0\n"
  },
  {
    "path": "test.py",
    "chars": 19383,
    "preview": "#Common imports\nimport os\nimport sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport pic"
  },
  {
    "path": "test_slab.py",
    "chars": 16473,
    "preview": "#Common imports\nimport os\nimport sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport pic"
  },
  {
    "path": "train.py",
    "chars": 15392,
    "preview": "#Common imports\nimport os\nimport sys\nimport numpy as np\nimport argparse\nimport copy\nimport random\nimport json\nimport pic"
  },
  {
    "path": "utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/attribute_attack.py",
    "chars": 4343,
    "preview": "#General Imports\nimport sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json"
  },
  {
    "path": "utils/bnlearn_data.py",
    "chars": 7246,
    "preview": "import numpy as np\r\nimport pandas as pd\r\nimport tensorflow as tf\r\n\r\nimport urllib.request\r\nimport zipfile\r\nimport csv\r\ni"
  },
  {
    "path": "utils/helper.py",
    "chars": 11939,
    "preview": "import torch\nimport torch.utils.data as data_utils\n\n#Sklearn\nfrom sklearn.manifold import TSNE\n\n#Pytorch\nimport torch\nfr"
  },
  {
    "path": "utils/match_function.py",
    "chars": 15771,
    "preview": "import numpy as np\nimport torch\nimport time\n\n#Sklearn\nfrom scipy.stats import bernoulli\n\n# ##TODO: Update required for t"
  },
  {
    "path": "utils/privacy_attack.py",
    "chars": 4321,
    "preview": "#General Imports\nimport sys\nimport numpy as np\nimport pandas as pd\nimport argparse\nimport copy\nimport random\nimport json"
  },
  {
    "path": "utils/scripts/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/scripts/data_utils.py",
    "chars": 8730,
    "preview": "import sys\n\nimport random, os, copy, pickle, time, random, argparse, itertools\nfrom collections import defaultdict, Coun"
  },
  {
    "path": "utils/scripts/ensemble.py",
    "chars": 4527,
    "preview": "import os, copy, pickle, time\nimport random, itertools\nfrom collections import defaultdict, Counter, OrderedDict\nimport "
  },
  {
    "path": "utils/scripts/gendata.py",
    "chars": 14594,
    "preview": "import numpy as np\nimport scipy.stats as scs\nimport random\nfrom collections import Counter\nimport torch\nfrom torch.utils"
  },
  {
    "path": "utils/scripts/gpu_utils.py",
    "chars": 3721,
    "preview": "try: import pycuda.driver as cuda\nexcept: print (\"pycuda not available\")\n\nimport torch\nimport sys, os, glob, subprocess\n"
  },
  {
    "path": "utils/scripts/lms_utils.py",
    "chars": 11515,
    "preview": "import seaborn as sns\nimport utils.scripts.gpu_utils as gu\nimport utils.scripts.data_utils as du\nimport utils.scripts.ut"
  },
  {
    "path": "utils/scripts/mnistcifar_utils.py",
    "chars": 4052,
    "preview": "import random\nimport os, copy, pickle, time\nimport itertools\nfrom collections import defaultdict, Counter, OrderedDict\ni"
  },
  {
    "path": "utils/scripts/ptb_utils.py",
    "chars": 8154,
    "preview": "import seaborn as sns\nimport utils\nimport random\nimport os, copy, pickle, time\nimport itertools\nfrom collections import "
  },
  {
    "path": "utils/scripts/synth_models.py",
    "chars": 7023,
    "preview": "import sys, copy\nimport torch, torchvision\nfrom torch import optim, nn\nimport torch.nn.functional as F\nfrom torch.utils."
  },
  {
    "path": "utils/scripts/utils.py",
    "chars": 28016,
    "preview": "import torch\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport seaborn as sns\nimport pickle\nimport copy\nfrom col"
  },
  {
    "path": "utils/slab_data.py",
    "chars": 4305,
    "preview": "import sys\nimport random\nimport os, copy, pickle, time\nimport argparse\nimport itertools\nfrom collections import defaultd"
  }
]

About this extraction

This page contains the full source code of the microsoft/robustdg GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 125 files (3.1 MB), approximately 830.6k tokens, and a symbol index with 337 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.