Full Code of microsoft/msrflute for AI

main 8bfe0854ab29 cached
151 files
775.6 KB
202.7k tokens
773 symbols
1 requests
Download .txt
Showing preview only (822K chars total). Download the full file or copy to clipboard to get everything.
Repository: microsoft/msrflute
Branch: main
Commit: 8bfe0854ab29
Files: 151
Total size: 775.6 KB

Directory structure:
gitextract_qg20kqyy/

├── .flake8
├── .github/
│   └── workflows/
│       ├── build_docs.yml
│       └── codeql.yml
├── .gitignore
├── .gitmodules
├── CHANGELOG.md
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.TXT
├── NOTICE.txt
├── README.md
├── SECURITY.md
├── azure-pipelines.yml
├── configs/
│   ├── hello_world_mlm_bert_json.yaml
│   └── hello_world_nlg_gru_json.yaml
├── core/
│   ├── __init__.py
│   ├── client.py
│   ├── config.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── evaluation.py
│   ├── federated.py
│   ├── metrics.py
│   ├── model.py
│   ├── schema.py
│   ├── server.py
│   ├── strategies/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── dga.py
│   │   ├── fedavg.py
│   │   ├── fedlabels.py
│   │   └── utils.py
│   └── trainer.py
├── doc/
│   └── sphinx/
│       ├── Makefile
│       ├── advanced.rst
│       ├── class_reference.rst
│       ├── conf.py
│       ├── index.rst
│       ├── launch.rst
│       ├── make.bat
│       ├── overview.rst
│       ├── reference.rst
│       ├── requirements.txt
│       └── scenarios.rst
├── e2e_trainer.py
├── experiments/
│   ├── __init__.py
│   ├── classif_cnn/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── cifar_dataset.py
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   └── utils/
│   │       ├── centralized_training.py
│   │       └── download_and_convert_data.py
│   ├── cv/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── data.py
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   ├── model_vgg.py
│   │   └── server.py
│   ├── cv_cnn_femnist/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocess.py
│   │   └── model.py
│   ├── cv_lr_mnist/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocessing.py
│   │   └── model.py
│   ├── cv_resnet_fedcifar100/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocessing.py
│   │   ├── group_normalization.py
│   │   └── model.py
│   ├── ecg_cnn/
│   │   ├── .gitignore
│   │   ├── centralized_model.ipynb
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   ├── readme.md
│   │   └── utils/
│   │       └── preprocess.py
│   ├── fednewsrec/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocess_mind.py
│   │   ├── fednewsrec_model.py
│   │   ├── model.py
│   │   └── utils.py
│   ├── mlm_bert/
│   │   ├── README.md
│   │   ├── config.py
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   └── utils/
│   │       ├── trainer_pt_utils.py
│   │       └── trainer_utils.py
│   ├── nlg_gru/
│   │   ├── README.md
│   │   ├── config.py
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   └── utils/
│   │       └── utility.py
│   ├── nlp_rnn_fedshakespeare/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocessing.py
│   │   └── model.py
│   └── semisupervision/
│       ├── README.md
│       ├── config.yaml
│       ├── dataloaders/
│       │   ├── RandAugment.py
│       │   ├── cifar_dataset.py
│       │   ├── dataloader.py
│       │   └── dataset.py
│       └── model.py
├── extensions/
│   ├── RL/
│   │   └── RL.py
│   ├── __init__.py
│   ├── privacy/
│   │   ├── __init__.py
│   │   ├── analysis.py
│   │   ├── dp_kmeans.py
│   │   └── metrics.py
│   └── quantization/
│       └── quant.py
├── requirements.txt
├── testing/
│   ├── README.md
│   ├── build_vocab.py
│   ├── create_data.py
│   ├── hello_world_classif_cnn.yaml
│   ├── hello_world_ecg_cnn.yaml
│   ├── hello_world_mlm_bert.yaml
│   ├── hello_world_nlg_gru.yaml
│   └── test_e2e_trainer.py
└── utils/
    ├── __init__.py
    ├── data_utils.py
    ├── dataloaders_utils.py
    ├── optimizers/
    │   ├── adamW.py
    │   ├── lamb.py
    │   └── lars.py
    ├── preprocessing/
    │   ├── create-hdf5.py
    │   ├── create-json.py
    │   └── from_json_to_hdf5.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .flake8
================================================
[flake8]
ignore = E501

================================================
FILE: .github/workflows/build_docs.yml
================================================
name: Build docs

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

  workflow_dispatch:

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      
      - name: Sphinx build
        uses: ammaraskar/sphinx-action@0.4
        with:
          docs-folder: doc/sphinx/

      - name: Commit documentation changes
        run: |
          git clone https://github.com/microsoft/msrflute --branch gh-pages --single-branch gh-pages
          cp -r doc/sphinx/_build/html/* gh-pages/
          cd gh-pages
          git config --local user.email "action@github.com"
          git config --local user.name "GitHub Action"
          git add .
          git commit -m "Update documentation" -a || true
    
      - name: Push changes
        uses: ad-m/github-push-action@master
        with:
          branch: gh-pages
          directory: gh-pages
          github_token: ${{ secrets.GITHUB_TOKEN }}


================================================
FILE: .github/workflows/codeql.yml
================================================
# This is based on the standard CodeQL workflow provided by Github
name: "CodeQL"

on:
  push:
    branches: [ "main" ]
  pull_request:
    # The branches below must be a subset of the branches above
    branches: [ "main" ]
  schedule:
    - cron: '35 2 * * 3'

jobs:
  analyze:
    name: Analyze
    runs-on: ubuntu-latest
    permissions:
      actions: read
      contents: read
      security-events: write

    strategy:
      fail-fast: false
      matrix:
        language: [ 'python' ]

    steps:
    - name: Checkout repository
      uses: actions/checkout@v3

    - name: Set-up MPI
      uses: mpi4py/setup-mpi@v1

    # Initializes the CodeQL tools for scanning.
    - name: Initialize CodeQL
      uses: github/codeql-action/init@v2
      with:
        languages: ${{ matrix.language }}
        
    # Autobuild attempts to build any compiled languages  (C/C++, C#, or Java).
    # If this step fails, then you should remove it and run the build manually (see below)
    - name: Autobuild
      uses: github/codeql-action/autobuild@v2

    - name: Perform CodeQL Analysis
      uses: github/codeql-action/analyze@v2


================================================
FILE: .gitignore
================================================
__pycache__/
.vscode/
doc/sphinx/_build
testing/logs.txt
testing/outputs
testing/mockup

================================================
FILE: .gitmodules
================================================
[submodule "utils/dp-accountant"]
	path = utils/dp-accountant
	url = https://github.com/microsoft/prv_accountant


================================================
FILE: CHANGELOG.md
================================================
# Changelog

All notable changes to this project will be documented in this file.

## [0.1.0] - 2021-11-22

We're super excited to announce FLUTE: Federated Learning Utilities for Testing and Experimentation, a platform for conducting high-performance federated learning simulations!

This first release fully focuses on implementing fast prototyping to validate different CL scenarios 
in an Federated environment.

### Features

- large scale simulation (millions of clients, sampling tens of thousands per round).
- multi-GPU and multi-node orchestration backed up by MPI.
- local or global differential privacy.
- model quantization.
- a variety of standard optimizers and aggregation methods.
- most model types including CNNs, RNNs, and Huggingface Transformers.
- extensibility, enabling new models, dataloaders, optimizers, and aggregators.
- local or cloud-based job staging using AzureML.


## [1.0.0] - 2022-08-29

This release contain major changes in the communication backbone , in order
to run previous experiments you have already integrated in FLUTE, please make sure
to use `torch.distributed` instead of `MPI `to launch the jobs. For more documentation
about the new command, please refer to the [README](README.md).


### New features

- 🏎 Better performance: Support for NCCL and Gloo as backend communication protocols. 
  - Improvements in GPU utilization and overall communication speed (on the order of minutes!) for projects with huge models and datasets.
- 🌟 Remove file type dependency on client.py, now FLUTE can receive any kind of dataset and even download the data on-the-fly. The data intantiation is completely under control of each task dataset.
  - In older versions FLUTE only allowed `json` and `hdf5` files, so the client could recognize it.
- 🌟 Abstract classes for new models/dataloaders.
- 🌟 Allows Federated Learning with Personalization. 
  - Personalization allows you to leverage each client local data to obtain models that are better adjusted to their own data distribution. You can run the `cv` task in order to try out this feature.


## [1.0.1] - 2023-07-29

🔋 This release removes the restriction of the minimum number of GPUs available in FLUTE, 
allowing users to run experiments using a single-GPU worker by instantiating both: Server
and clients on the same device. For more documentation about how to run an experiments
using a single GPU, please refer to the [README](README.md).


### New features

- 🌟 Include FedProx aggregation method



================================================
FILE: CITATION.cff
================================================
cff-version: 1.2.0
message: "To cite Microsoft FLUTE in academic papers, please cite it as below."
authors:
  - name: "Microsoft Research"
title: "FLUTE: Federated Learning Utilities for Testing and Experimentation"
version: 1.0.0
date-released: "2021-22-11"
url: "https://github.com/microsoft/msrflute"
license:
 - MIT
keywords:
  - FLUTE
  - federated learning


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing

This project welcomes contributions and suggestions. Most contributions require you to
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
and actually do, grant us the rights to use your contribution. For details, visit
https://cla.microsoft.com.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

### Pull Requests

Submit pull requests to **branch contribution**. PR's in any other branch will not be accepted.

When you submit a pull request, a CLA-bot will automatically determine whether you need
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.



================================================
FILE: LICENSE.TXT
================================================
Copyright (c) Microsoft Corporation.

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: NOTICE.txt
================================================
THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
Do Not Translate or Localize

This software incorporates components from the projects listed below. The original copyright notices
and the licenses under which Microsoft received such components are set forth below and are provided for 
informational purposes only. Microsoft reserves all rights not expressly granted herein, whether by 
implication, estoppel or otherwise.

This software includes parts of the Huggingface/Transformers Library (https://github.com/huggingface/transformers). 
State-of-the-art of  Natural Language Processing for Jax, PyTorch and TensorFlow. Huggingface/Transformers library is 
licensed under Apache License 2.0, you can find a copy of this license at https://github.com/huggingface/transformers/blob/master/LICENSE

This software includes parts of the Tensorflow/Privacy Library (https://github.com/tensorflow/privacy). 
A library that includes implementations of TensorFlow optimizers for training machine learning models with
differential privacy. The Tensorflow/Privacy library is licensed under Apache License 2.0, 
you can find a copy of this license at https://github.com/tensorflow/privacy/blob/master/LICENSE

This software includes parts of LEAF Library (https://github.com/TalwalkarLab/leaf).
A Benchmark for Federated Settings. LEAF library is licensed under BSD 2-Clause License, you can find a copy
of this license at https://github.com/TalwalkarLab/leaf/blob/master/LICENSE.md

This software includes parts of ECG Classification from Kaggle Competition 
(https://www.kaggle.com/polomarco/ecg-classification-cnn-lstm-attention-mechanism). 
An example for ECG Classification | CNN LSTM Attention Mechanism. This example is 
licensed under Apache License 2.0, you can find a copy of this license at 
https://www.apache.org/licenses/LICENSE-2.0 

This software includes parts of Torchvision Library (https://github.com/pytorch/vision.git). A package of
popular datasets, model architectures, and common image transformations for computer vision. This example
is licenced under BSD 3-Clause License, you can find a copy of this licence at 
https://github.com/pytorch/vision/blob/main/LICENSE

This software includes parts of FedML Library (https://github.com/FedML-AI/FedML).The Community 
Building Open and Collaborative AI Anywhere at Any Scale. FedML library is licensed under Apache License 2.0, 
you can find a copy of this license at https://github.com/FedML-AI/FedML/blob/master/LICENSE

This software includes parts of FedNewsRec-EMNLP-Findings-2020 repository (https://github.com/taoqi98/FedNewsRec).  
Code from the paper "Privacy-Preserving News Recommendation Model Learning". This example is licenced 
under MIT License, you can find a copy of this licence at https://github.com/taoqi98/FedNewsRec/blob/master/LICENSE

This software includes parts of Fast AutoAugment repository (https://github.com/kakaobrain/fast-autoaugment).  
Code from the paper "Fast AutoAugment" (Accepted at NeurIPS 2019). This example is licenced 
under MIT License, you can find a copy of this licence at https://github.com/kakaobrain/fast-autoaugment/blob/master/LICENSE

This software includes parts of NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench).  
Code from the paper "Federated Learning on Non-IID Data Silos: An Experimental Study". This example is 
licenced under MIT License, you can find a copy of this licence at https://github.com/Xtra-Computing/NIID-Bench/blob/main/LICENSE


================================================
FILE: README.md
================================================
# FLUTE

Welcome to FLUTE (Federated Learning Utilities for Testing and Experimentation), a platform for conducting high-performance federated learning simulations.

## Features

FLUTE is a pytorch-based orchestration environment enabling GPU or CPU-based FL simulations.  The primary goal of FLUTE is to enable researchers to rapidly prototype and validate their ideas.  Features include:

- large scale simulation (millions of clients, sampling tens of thousands per round)
- single/multi GPU and multi-node orchestration
- local or global differential privacy
- model quantization
- a variety of standard optimizers and aggregation methods
- most model types including CNNs, RNNs, and Huggingface Transformers.
- extensibility, enabling new models, dataloaders, optimizers, and aggregators.
- local or cloud-based job staging using AzureML

## Benchmarking

The following common tasks were used to evaluate the performance in speed/memory utilization of FLUTE compared with the most representative simulation platforms based on their number of starts on GitHub: FedML 0.7.303 and Flower 1.0.0. 

|Task|Data Set|Model|Algorithm|# Clients|Clients per round|Batch Size|Client Optimizer|lr|Epochs|# Rounds|Test Freq|
|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|
|CV|MNIST|LR|FedAvg|1000|10|10|SGD|0.03|1|100|20|
|CV|Federated EMNIST|CNN (2 Conv + 2 FC)|FedAvg|3400|10|20|SGD|0.1|1|1500|50|
|CV|FED_CIFAR-100|ResNet-18+group normalization|FedAvg|500|10|20|SGD|0.1|1|4000|50|
|NLP|Shakespeare|RNN (2 LSTM + 1 FC)|FedAvg|715|10|4|SGD|0.8|1|1200|50|

### FedML Comparison

This comparison was carried out using Parrot (Simulator) on version 0.7.303 at commit ID [8f7f261f](https://github.com/FedML-AI/FedML/tree/8f7f261f44e58d0cb5a416b0d6fa270b42a91049). Showing that in some cases FLUTE can outperform 43x faster.

```
 _____________________________________________________________________________
|                    |   FedML (MPI) - Fastest   |   FLUTE (NCCL)  - Fastest  |
| Task               | Acc | Time     | GPU Mem  | Acc | Time     | GPU Mem   |
|--------------------|-----|----------|----------|-----|----------|-----------|
| LR_MNIST           | ~81 | 00:03:09 | ~3060 MB | ~81 | 00:01:35 | ~1060 MB  |
| CNN_FEMNIST        | ~83 | 05:49:52 | ~5180 MB | ~83 | 00:08:22 | ~1770 MB  |
| RESNET_FEDCIFAR100 | ~34 | 15:55:36 | ~5530 MB | ~33 | 01:42:01 | ~1900 MB  |
| RNN_FEDSHAKESPEARE | ~57 | 06:46:21 | ~3690 MB | ~57 | 00:21:50 | ~1270 MB  |
 -----------------------------------------------------------------------------
```

You can find the examples above in [experiments](experiments).

### Flower Comparison

This comparison was carried out using Flower (Simulator) on version 1.0.0 at commit ID [4e7fad9](https://github.com/adap/flower/tree/4e7fad99389a5ee511730841b61f279e3359cb16) with the [lr_mnist](experiments/cv_lr_mnist/) task. Showing that in some cases FLUTE can outperform 53x faster.

```
 ________________________________________________
|        |    Flower (Ray)   | FLUTE (NCCL/Gloo) |
|        | Acc |    Time     | Acc |    Time     |
|--------|-----|-------------|-----|-------------|
| CPU    | ~80 |   00:30:14  | ~80 |   00:03:20  |
| GPU 2x | ~80 |   01:21:44  | ~80 |   00:01:31  |
| GPU 4x | ~79 |   00:56:45  | ~81 |   00:01:26  |
 ------------------------------------------------
```

You can find the example above in the [cv_lr_mnist](experiments/cv_lr_mnist/) folder.

## Quick Start

Install the requirements stated inside of `requirements.txt`. Ideally this sould be done inside of a virtual environment, for instance, using Anaconda.

```
conda create -n FLUTE python==3.7
pip install -r requirements.txt
```

FLUTE uses torch.distributed API as its main communication backbone, supporting three built-in backends. For more information please refer to [Distributed Communication Package](https://pytorch.org/docs/stable/distributed.html). Therefore, we highly suggest to use NCCL backend for distributed GPU training and Gloo for distributed CPU training. There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.

After this initial setup, you can use the data created for the integration test for a first local run. Note that this data needs to be download manually inside the `testing` folder, for more instructions please look at [the README file inside `testing`](testing/README.md).

For single-GPU runs:

```
python -m torch.distributed.run --nproc_per_node=1 e2e_trainer.py -dataPath ./testing -outputPath scratch -config testing/hello_world_nlg_gru.yaml -task nlg_gru -backend nccl
```

For multi-GPU runs (3 GPUs):

```
python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath ./testing -outputPath scratch -config testing/hello_world_nlg_gru.yaml -task nlg_gru -backend nccl
```

The config file `testing/hello_world_nlg_gru.yaml` has some comments explaining the major sections and some important details; essentially, it consists in a very short experiment where a couple of iterations are done for just a few clients. A `scratch` folder will be created containing detailed logs.

## Documentation

Online documentation is available at https://microsoft.github.io/msrflute/

Locally, the documentation is inside the `doc/sphinx` folder. To build the docs on Linux:

```
$ pip install sphinx
$ cd doc/sphinx
$ make html
```

On Windows, you can use the `make.bat` script.  It may be necessary to `export PYTHONPATH=../../` for sphinx to find the code.

## Architecture

The core client/server training code is inside the `core` folder.

- Server-side federation and global DP application takes place in `server.py`, more specifically in the `OptimizationServer.train()` method.
- Client-side training updates take place in the static method `Client.process_round()`, inside `client.py`.

General FL orchestration code is in `federated.py`, but for most hub and spoke federation scenarios you won't need to touch this (unless you want to invest in optimizing server-client calls, which would be great!). Note that FLUTE does not implement secure aggregation since this is primarily a security feature for production scenarios; contributors are invited to add it for experimentation purposes.

The primary entry point for an experiment is in the script `e2e_trainer.py`. Primary config scripts for experiments are in `configs`. For instance, a basic training scenario for a next-word prediction task is set up in `hello_world_nlg_gru_json.yaml`.

Privacy accounting is expensive so the main parameters are logged and the actual accounting can be done offline. RDP privacy accounting is in `extensions/privacy/analysis.py`. A better accounting method is in the `dp-accountant` submodule.

## Customization

See `experiments` folder for illustrations of how dataloaders and models are customized. In order to in include a new experiment, the new scenario must be added following the same folder structure as `nlg_gru` and `mlm_bert`, naming the folder with the task.

## Experiments

Experiments are defined by YAML files, examples are provided in the `configs` folder. These can be run either locally or on AzureML.

For running experiments on AzureML, the CLI can help. You should first [install the CLI](https://docs.microsoft.com/en-us/azure/machine-learning/reference-azure-machine-learning-cli) (make sure you have v2) and [create a resource group and workspace](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-manage-workspace-cli?tabs=createnewresources%2Cvnetpleconfigurationsv1cli). You can then create a compute cluster, type `az ml compute create -h` for more info. Afterwards, you should write an YAML file with instructions for the job; we provide a simple example below

```yaml
experiment_name: basic_example
description: Basic example of AML config for submitting FLUTE jobs
code:
  local_path: .
compute: azureml:Test
environment:
  image: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel
inputs:
  data:
    folder: azureml://datastores/data/paths/cifar
    mode: rw_mount
command: >
  apt -y update &&
  apt -y install openmpi-bin libopenmpi-dev openssh-client &&
  python3 -m pip install --upgrade pip &&
  python3 -m pip install -r requirements.txt &&
  python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py
  -outputPath=./outputs
  -dataPath={inputs.data}
  -task=classif_cnn
  -config=./experiments/classif_cnn/config.yaml
  -backend=nccl
```

You should replace `compute` with the name of the one you created before, and adjust the path of the datastore containing the data -- in the example above, we created a datastore called `data` and added to it a folder called `cifar`, which contained the two HDF5 files. The command passed above will install dependencies and then launch a distributed job with 4 threads, for the experiment defined in `experiments/classif_cnn`. Details on how to run a job using the AzureML CLI are given [in its documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli), but typically it suffices to set up the environment and type `az ml job create -f <name-of-the-yaml-file>`. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.

Note that the `local_path` above is relative to the location of the YAML file, so setting it to `.` assumes it is in the same folder as `e2e_trainer.py`. All files on this folder will be uploaded to Azure, including hidden folders such as `.git`, so make sure to temporarily get rid of large files and folders that are not needed.

After launching the experiment, you can follow it on AzureML Studio, which prints logs, plots metrics and makes the output easily available after the experiment is finished.

## Privacy Accounting

Accounting is expensive, so we log all the privacy parameters so that accounting can be run offline. Best run on a Linux box with a GPU.
In particular, we use a DP accountant from another Microsoft repository, which is included in ours as a submodule. For using this accountant, just follow the instructions below:

```
$ git submodule update --init --recursive
$ cd utils
$ cd dp-accountant
$ python setup.py install
$ ./bin/compute-dp-epsilon --help
usage: compute-dp-epsilon [-h] -p SAMPLING_PROBABILITY -s NOISE_MULTIPLIER -i ITERATIONS -d DELTA
```
## Third Party Notice

This software includes the files listed below from the Huggingface/Transformers Library (https://github.com/huggingface/transformers) as part of task performance and preprocessing pretrained models.

    experiments/mlm_bert
    └── utils
        ├── trainer_pt_utils.py
        └── trainer_utils.py

This software includes the file extensions/privacy/analysis.py from the Tensorflow/Privacy Library (https://github.com/tensorflow/privacy) as part of Renyi Differential Privacy implementation.

This software includes the script testing/build_vocab.py from LEAF Library (https://github.com/TalwalkarLab/leaf) to create the vocabulary needed to run a testing job. 

This software includes the model implementation of the example ECG Classification | CNN LSTM Attention Mechanism from Kaggle Competition (https://www.kaggle.com/polomarco/ecg-classification-cnn-lstm-attention-mechanism) to reproduce the [ecg_cnn](experiments/ecg_cnn/model.py) experiment.

This software includes the model implementation of the FedNewsRec repository (https://github.com/taoqi98/FedNewsRec)| Code from the paper "Privacy-Preserving News Recommendation Model Learning" (https://arxiv.org/abs/2003.09592) ported to PyTorch framework to reproduce the [fednewsrec](experiments/fednewsrec/model.py) experiment.
For more information about third-party OSS licence, please refer to [NOTICE.txt](NOTICE.txt).

This software includes the Data Augmentation scripts of the Fast AutoAugment repository (https://github.com/kakaobrain/fast-autoaugment) to preprocess the data used in the [semisupervision](experiments/semisupervision/dataloaders/cifar_dataset.py) experiment.

This software included the FedProx logic implementation of the NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench/tree/main) as Federated aggregation method used in the [trainer](core/trainer.py) object.
## Support

You are welcome to open issues on this repository related to bug reports and feature requests.

## Contributing

Contributions are welcomed and encouraged. For details on how to contribute, please see [CONTRIBUTING.md](CONTRIBUTING.md).




================================================
FILE: SECURITY.md
================================================
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

  * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
  * Full paths of source file(s) related to the manifestation of the issue
  * The location of the affected source code (tag/branch/commit or direct URL)
  * Any special configuration required to reproduce the issue
  * Step-by-step instructions to reproduce the issue
  * Proof-of-concept or exploit code (if possible)
  * Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->


================================================
FILE: azure-pipelines.yml
================================================
trigger:
- main

pool:
  vmImage: 'windows-latest'

steps:
- task: CredScan@2
  inputs:
    toolMajorVersion: 'V2'

- task: Semmle@1
  env: 
    SYSTEM_ACCESSTOKEN: $(System.AccessToken)
  inputs:
    sourceCodeDirectory: '$(Build.SourcesDirectory)'
    language: 'python'
    querySuite: 'Recommended'
    timeout: '1800'
    ram: '16384'
    addProjectDirToScanningExclusionList: true

- task: ComponentGovernanceComponentDetection@0
  inputs:
    scanType: 'Register'
    verbosity: 'Verbose'
    alertWarningLevel: 'High'

- task: PublishSecurityAnalysisLogs@2
  inputs:
    ArtifactName: 'CodeAnalysisLogs'
    ArtifactType: 'Container'
    AllTools: true
    ToolLogsNotFoundAction: 'Standard'

================================================
FILE: configs/hello_world_mlm_bert_json.yaml
================================================
# Basic configuration file for running mlm_bert example using json files.
# Parameters needed to initialize the model
model_config:
    model_type: BERT 
    model_folder: experiments/mlm_bert/model.py
    BERT:
        loader_type: text
        model:
            model_name: roberta-large
            cache_dir: ./cache_dir
            use_fast_tokenizer: False
            mask_token: <mask>
            task: mlm
            past_index: -1
            prediction_loss_only: false
            process_line_by_line: false
        training:
            seed: 12345
            label_smoothing_factor: 0  
            batch_size: 64
            max_seq_length: 256            

# Configuration for differential privacy
dp_config:
    enable_local_dp: false  # If enabled, the rest of parameters is needed. 
    enable_global_dp: false # Local dp clips and adds noise on the client and centrally accumulates the privacy budget
    eps: 100                # epsilon
    global_sigma: 0.35      # Used when global dp es enabled, specifies the global Gaussian noise
    weight_scaler: 0.0001   # indicates how the aggregation weights scaled before noise addition, and unscaled afterwards.
    max_grad: 0.008         # max gradient
    max_weight: 0.5         # The max_weight and min_weight should be already scaled by weight_scaler
    min_weight: 0.0000001   # Because we scale down the weight using weight_scalar -> clip -> add noise -> scale back up.

# Additional privacy metrics
privacy_metrics_config:
    apply_metrics: false    # If enabled, the rest of parameters is needed. 

# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: DGA

# Determines all the server-side settings for training and evaluation rounds
server_config:
    resume_from_checkpoint: true                    # Resumes from latest checkpoint iteration if available 
    do_profiling: false                             # Capture profiling information during server updates.
    fast_aggregation: true                          
    wantRL: false                                   # Enable/Disable Reinforcement learning
    RL:                                             # Reinforcement Learning parameters
        RL_path_global: false
        marginal_update_RL: true
        RL_path: ./RL_models
        model_descriptor_RL: marginalUpdate
        network_params: 300,128,128,128,64,100
        initial_epsilon: 0.5
        final_epsilon: 0.0001
        epsilon_gamma: 0.90
        max_replay_memory_size: 1000
        minibatch_size: 16
        gamma: 0.99
        optimizer_config:
            lr: 0.0003
            type: adam
            amsgrad: true
        annealing_config:
            type: step_lr
            step_interval: epoch
            step_size: 1
            gamma: 0.95
    optimizer_config:                               # Configuration for server-side optimizer
        lr: 0.00001                                 
        weight_decay: 0.01
        type: adamW
    annealing_config:                               # This section configures how the learning rate decays
        type: step_lr
        step_interval: epoch
        gamma: 1.0
        step_size: 1000
    val_freq: 4                                     # Frequency for validation rounds
    rec_freq: 16                                    # Frequency for testing rounds
    initial_val : true                              # Enable initial validation round at itr=0
    initial_rec: false                              # Enable initial testing round at itr=0
    max_iteration: 10000                            # Total number of rounds for FL
    num_clients_per_iteration: 200                  # Number of clients sampled per round
    data_config:                                    # Server-side data configuration
        val:                                        # Validation data
            val_data: <add path to data here>
            task: mlm
            mlm_probability: 0.25
            tokenizer_type_fast: False
            batch_size: 128
            max_seq_length: 256
            min_words_per_utt: 5
            max_samples_per_user: 5000
            mask_token: <mask>
            num_workers: 0
            prepend_datapath: false
            cache_dir: ./cache_dir
        # Note this is NOT the main training data configuration, which is configured in the 
        # client config.  This section is ignored unless you are running replay data.
        # If you want to run replay data- set a path name for train_data_server.
        # train:
        #     loader_type: text
        #     train_data: null
        #     train_data_server: null
        #     desired_max_samples: null
        test:                                       # Test data configuration
            test_data: <add path to data here>
            task: mlm
            mlm_probability: 0.25
            tokenizer_type_fast: False
            batch_size: 128
            max_seq_length: 256
            max_samples_per_user: 5000
            mask_token: <mask>
            num_workers: 0
            prepend_datapath: false
            cache_dir: ./cache_dir
    type: model_optimization                        # Server type
    aggregate_median: softmax                       # FL aggregation method
    weight_train_loss: mag_mean_loss                # Determines how each client's weight is computed (e.g. grad_mean_loss, train_loss)
    softmax_beta: 1.00                              
    initial_lr_client: 0.00001
    lr_decay_factor: 1.0
    best_model_criterion: loss                      # Determine the best model based on minimal loss, for checkpointing
    fall_back_to_best_model: false                  # If a model degrades, use the previous best model
    # server_replay_config:                           # This is only applies if the server-side training data is fully configured and loaded
    #     server_iterations: 50
    #     optimizer_config:
    #         lr: 0.00002
    #         amsgrad: true
    #         type: adam

# Dictates the learning parameters for client-side model updates. Train data is defined inside this config.
client_config:
    meta_learning: basic
    stats_on_smooth_grad: true
    ignore_subtask: false
    copying_train_data: false
    do_profiling: false                             # Enables client-side training profiling
    data_config:
        train:                                      # This is the main training data configuration
            list_of_train_data: <add path to data here>
            task: mlm
            mlm_probability: 0.25
            tokenizer_type_fast: False
            batch_size: 24
            max_seq_length: 256
            min_words_per_utt: 5
            desired_max_samples: 5000
            mask_token: <mask>
            num_workers: 0
            num_frames: 0
            max_grad_norm: 15.0
            prepend_datapath: false
            cache_dir: ./cache_dir
            pin_memory: true
    type: optimization
    meta_optimizer_config:
        lr: 0.01
        type: adam
    optimizer_config:
        type: adamW
        weight_decay: 0.01
        amsgrad: true
    annealing_config:
        type: step_lr
        step_interval: epoch
        step_size: 2
        gamma: 1.0

================================================
FILE: configs/hello_world_nlg_gru_json.yaml
================================================
# Basic configuration file for running nlg_gru example using json files.
# Parameters needed to initialize the model
model_config: 
    model_type: GRU
    model_folder: experiments/nlg_gru/model.py
    pretrained_model_path: <add path to pretrained weights here>
    embed_dim: 160
    vocab_size: 10000
    hidden_dim: 512
    OOV_correct: false

# Configuration for differential privacy
dp_config:
    enable_local_dp: false      # If enabled, the rest of parameters is needed. 
    # enable_local_dp: true     # Local dp clips and adds noise on the client and centrally accumulates the privacy budget
    # eps: 100                  # epsilon
    # max_grad: 0.008           # max gradient
    # weight_scaler: 0.0001     # indicates how the aggregation weights scaled before noise addition, and unscaled afterwards.
    # max_weight: 0.0001        # The max_weight and min_weight should be already scaled by weight_scaler
    # min_weight: 0.00009       # Because we scale down the weight using weight_scalar -> clip -> add noise -> scale back up.

# Additional privacy metrics
privacy_metrics_config:
    apply_metrics: false             # If enabled, the rest of parameters is needed. 
    # apply_indices_extraction: true   # If we extract word indices we want to consider the rank of the words extracted.
    # allowed_word_rank: 9000          # Any word that rank above this value is considered privacy risk
    # apply_leakage_metric: true
    # max_leakage: 30
    # max_allowed_leakage: 3
    # adaptive_leakage_threshold: 0.95 # Takes the 95th percentile of the leakage for the next round.
    # is_leakage_weighted: true
    # attacker_optimizer_config:
    #     lr: 0.03
    #     type: adamax
    #     amsgrad: false

# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
strategy: FedProx

# Determines all the server-side settings for training and evaluation rounds
server_config:   
    wantRL: false                   # Enable/Disable Reinforcement learning
    resume_from_checkpoint: true    # Resumes from latest checkpoint iteration if available 
    do_profiling: false             # Capture profiling information during server updates.
    optimizer_config:               # Configuration for server-side optimizer
        type: lamb
        lr: 0.1
        weight_decay: 0.005
    annealing_config:               # This section configures how the learning rate decays
        type: step_lr
        step_interval: epoch
        gamma: 1.0
        step_size: 100
    val_freq: 2                     # Frequency for validation rounds
    rec_freq: 4                     # Frequency for testing rounds
    initial_val : true              # Enable initial validation round at itr=0
    initial_rec: false             # Enable initial testing round at itr=0
    max_iteration: 11               # Total number of rounds for FL
    num_clients_per_iteration: 10   # Number of clients sampled per round
    data_config:                    # Server-side data configuration
        val:                        # Validation data
            batch_size: 2048
            tokenizer_type: not_applicable
            prepend_datapath: false
            val_data: <add path to data here>       # Path for validation data
            vocab_dict: <add path to vocab here>    # Path for vocabulary
            pin_memory: true
            num_workers: 0                          # Indicates how many workers are used for creating batches
            num_frames: 2400                        
            max_batch_size: 2048
            max_num_words:  25
            unsorted_batch: true
        # Note this is NOT the main training data configuration, which is configured in the 
        # client config.  This section is ignored unless you are running replay data.
        # If you want to run replay data- set a path name for train_data_server.
        # train:                                      
        #     batch_size: 128
        #     loader_type: text
        #     tokenizer_type: not_applicable
        #     prepend_datapath: false
        #     train_data: null
        #     train_data_server: null
        #     vocab_dict: <add path to vocab here>
        #     pin_memory: true
        #     num_workers: 0
        #     num_frames: 2400
        #     desired_max_samples: 500
        #     max_grad_norm: 10.0
        #     max_batch_size: 128
        #     max_num_words:  25
        #     unsorted_batch: true
        test:                                       # Test data configuration
            batch_size: 2048
            tokenizer_type: not_applicable
            prepend_datapath: false
            train_data: null
            train_data_server: null
            test_data: <add path to data here>      # Path for validation data
            vocab_dict: <add path to vocab here>    # Path for vocabulary
            pin_memory: true
            num_workers: 0                          # Indicates how many workers are used for creating batches
            max_batch_size: 2048
            max_num_words:  25
            unsorted_batch: true
    type: model_optimization
    aggregate_median: softmax                       # FL aggregation method
    weight_train_loss: train_loss                   # Determines how each client's weight is computed (e.g. grad_mean_loss, train_loss)
    softmax_beta: 20.0
    initial_lr_client: 1.0
    lr_decay_factor: 1.0
    best_model_criterion: loss                      # Determine the best model based on minimal loss, for checkpointing
    fall_back_to_best_model: false                  # If a model degrades, use the previous best model
    # server_replay_config:                           # This is only applies if the server-side training data is fully configured and loaded
    #     server_iterations: 50
    #     optimizer_config:
    #         type: adam
    #         lr: 0.00002
    #         amsgrad: true
    
# Dictates the learning parameters for client-side model updates. Train data is defined inside this config.
client_config:
    mu: 0.001                                           # Used only for FedProx aggregation method
    meta_learning: basic
    stats_on_smooth_grad: true
    ignore_subtask: false
    num_skips_threshold: 10
    copying_train_data: false
    do_profiling: false                                 # Enables client-side training profiling
    data_config:
        train:                                          # This is the main training data configuration
            batch_size: 64
            tokenizer_type: not_applicable
            prepend_datapath: false
            list_of_train_data: <add path to data here> # Path to training data
            vocab_dict: <add path to vocab here>        # Path to vocabulary
            pin_memory: true
            num_workers: 0
            desired_max_samples: 50000
            max_grad_norm: 20.0
            max_batch_size: 128
            max_num_words:  25
            unsorted_batch: true
    type: optimization
    meta_optimizer_config:
        lr: 1.0
        type: sgd
    optimizer_config:
        type: sgd
    annealing_config:
        type: step_lr
        step_interval: epoch
        step_size: 1
        gamma: 1.0

================================================
FILE: core/__init__.py
================================================


================================================
FILE: core/client.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
The Client object is short-lived, instantiated inside workers 1 to N for 
processing a given client's data. It's main method is the `process_round` 
function, used to update the model given a client's data.
'''

import copy
import logging
import os
import time

from easydict import EasyDict as edict
from importlib.machinery import SourceFileLoader
import numpy as np
import torch

# Internal imports
import core.federated as federated
from .strategies import select_strategy
from .trainer import (
    Trainer,
    run_validation_generic,
    set_component_wise_lr,
)
from utils import (
    ScheduledSamplingScheduler,
    make_optimizer,
    print_rank,
    to_device,
    convex_inference,
    alpha_update,
)
from utils.dataloaders_utils import (
    make_train_dataloader,
    make_val_dataloader,
    make_test_dataloader,
    get_dataset,
)
import extensions.privacy
from extensions.privacy import metrics as privacy_metrics
from experiments import make_model

global train_dataset
global trainset_unlab
global trainset_unlab_rand

class Client:
    # It's unclear why, but sphinx refuses to generate method docs
    # if there is no docstring for this class.
    """Client class for specifying individual client training tasks"""

    def __init__(self, client_id, config, send_gradients):
        '''
        Client side processing: computing gradients, update the model and send them back to the server

        Args:
            client_id (int): identifier for grabbing that client's data.
            config (dict): dictionary with parameters loaded from config file.
            send_gradients (bool): if True, model gradients are sent back;
                otherwise, model weights are sent back.
        '''
        super().__init__()
        
        self.client_id = client_id
        self.config = copy.deepcopy(config)
        self.send_gradients = send_gradients

    def get_client_data(self, dataset=None):
        '''"Getter" method that returns all object's attributes at once.'''

        client_data = self.get_data(self.client_id, dataset)
        return self.client_id, client_data, self.config, self.send_gradients

    @staticmethod
    def get_train_dataset(data_path, client_train_config, task):
        '''This function will obtain the dataset for all training
        users.

        Args:
            data_path (str): path to file containing taining data.
            client_train_config (dict): trainig data config.
            task (str): task name.
        '''
        global train_dataset
        global trainset_unlab
        global trainset_unlab_rand

        train_dataset = get_dataset(data_path, client_train_config, task, mode="train")

        if task == 'semisupervision':
            trainset_unlab = get_dataset(data_path, client_train_config, task, mode="train", user_idx = -2)
            trainset_unlab_rand = get_dataset(data_path, client_train_config, task, mode="train", user_idx = -3)
        else:
            trainset_unlab = None
            trainset_unlab_rand = None

        return len(train_dataset.user_list)

    @staticmethod
    def get_data(clients, dataset):
        ''' Create training dictionary'''

        if dataset == None: # Training case
            datasets = [train_dataset, trainset_unlab, trainset_unlab_rand] if trainset_unlab != None else [train_dataset]
        else: # Evaluation case
            datasets = [dataset]

        data_with_labels = hasattr(datasets[0],"user_data_label")
        
        strcts = [] # Returning list length will always be 1 except when the task is semisupervision
        for dataset in datasets:
            input_strct = {'users': [], 'num_samples': [],'user_data': dict(), 'user_data_label': dict()} if data_with_labels else {'users': [], 'num_samples': [],'user_data': dict()}
            for client in clients:
                user = dataset.user_list[client]
                input_strct['users'].append(user)
                input_strct['num_samples'].append(dataset.num_samples[client])
                input_strct['user_data'][user]= dataset.user_data[user]
                if data_with_labels: 
                    input_strct['user_data_label'][user] = dataset.user_data_label[user]
            strcts.append(edict(input_strct))
        
        return strcts 

    @staticmethod
    def run_testvalidate(client_data, server_data, mode, model):
        '''Called by worker to run test/validation sample on a client.

        This functions assumes set_model_for_round has already been called to
        push the model to the client (see federated.py).

        Args:
            client_data (tuple): client data and config. It is a tuple with 3
                components; importantly, the second component is a dict
                containing the data, and the third component is a dict with the
                config parsed from the YAML file.
            server_data (tuple): server data (model parameters mostly). It is
                a tuple with 2 components; importantly, the second component
                consists of the current model parameters.
            mode (str): whether to `test` or `validate`.
            model (torch.nn.Module): actual model without parameters.
        '''

        # Process inputs and initialize variables
        _, data_strcts, config, _ = client_data
        _, model_parameters, iteration = server_data
        config = copy.deepcopy(config)
        model_path = config["model_path"]

        begin = time.time()  

        # Use the server's data config since we're distributing test/validate from the server
        data_strct = data_strcts[0]
        data_config = config['server_config']['data_config'][mode]
        want_logits = data_config.get('wantLogits', False)
        send_dicts = config['server_config'].get('send_dicts', False)

        # Create dataloader 
        dataloader = None
        print_rank('making dataloader with task {}'.format(config['server_config']['task']), loglevel=logging.DEBUG)
        if mode == 'test':
            dataloader = make_test_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)
        elif mode == 'val':
            dataloader = make_val_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)

        # Set model parameters
        n_layers, n_params = len([f for f in model.parameters()]), len(model_parameters)
        print_rank(f'Copying model parameters... {n_layers}/{n_params}', loglevel=logging.DEBUG)
        
        model = to_device(model)
        
        if send_dicts: # Send model state dictionary
            tmp = {}
            for param_key, param_dict in zip (model.state_dict(), model_parameters):
                tmp[param_key] = param_dict
            model.load_state_dict(tmp)
        else: # Send parameters
            for p, data in zip(model.parameters(), model_parameters):
                p.data = data.detach().clone().cuda() if torch.cuda.is_available() else data.detach().clone()

        print_rank(f'Model setup complete. {time.time() - begin}s elapsed.', loglevel=logging.DEBUG)

        # Compute output and metrics on the test or validation data
        num_instances = sum(data_strct['num_samples'])
        print_rank(f'Validating {num_instances}', loglevel=logging.DEBUG)
        output, metrics = run_validation_generic(model, dataloader)
        
        # Load local model if necessary
        if config['server_config']['type']=='personalization':

            local_model = make_model(config['model_config'])
            user = data_strct['users'][0]

            local_model_name = os.path.join(model_path, user + '_model.tar')

            if os.path.exists(local_model_name):
                print_rank('Loading Local Model .. {}'.format(local_model_name))
                checkpoint = torch.load(local_model_name)
                local_model.load_state_dict(checkpoint["model_state_dict"])

                local_alpha_name = os.path.join(model_path, user + '_alpha')
                if os.path.exists(local_alpha_name):
                    alpha = torch.load(local_alpha_name)
                    print_rank('Loading Alpha Weight from {}: Value={}'.format(local_model_name, alpha))

                    # Run inference and get logits back
                    if mode == 'test':
                        dataloader = make_test_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)
                    elif mode == 'val':
                        dataloader = make_val_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)

                    output_local, local_metrics = run_validation_generic(local_model, dataloader)
                    loss_local = local_metrics['loss']['value']
                    cer = local_metrics['acc']['value']
                    # Combine logits
                    cer =convex_inference(output, output_local, alpha=alpha)
                    metrics['loss']['value'] = (metrics['loss']['value'] + loss_local) / 2 
                    metrics['acc']['value'] = cer
        output = None if not want_logits else output

        return output, metrics, num_instances



    @staticmethod
    def process_round(client_data, server_data, model, data_path, eps=1e-7):
        '''Compute gradients given client's data and update model.

        Args:
            client_data (tuple): client data and config. It is a tuple
                consisting of 4 components: an int indicating the client's id, a
                dict containing that client's data, a dict with the config
                parsed from the YAML file, and a bool indicating whether or not
                gradients should be sent.
            server_data (tuple): server data (model parameters mostly). It is
                a tuple consisting of 2 components; importantly, the first is
                a float giving the client's learning rate, and the second a list
                of torch.Tensor's with current model parameters. 
            model (torch.nn.Module): actual model without parameters.
            data_path (str): where to get data from.
            eps (float): lower bound for aggregation weights.
        '''

        # Ensure the client is assigned to the correct GPU
        if torch.cuda.is_available() and torch.cuda.device_count() == federated.size():
            torch.cuda.set_device(federated.local_rank())

        # Process inputs and initialize variables
        client_id, data_strcts, config, send_gradients = client_data
        initial_lr, model_parameters, iteration = server_data
        config = copy.deepcopy(config)

        model_config = config['model_config']
        client_config = config['client_config']
        data_config = client_config['data_config']['train']
        semisupervision_config = client_config.get('semisupervision',None)
        task = client_config.get('task', {})
        trainer_config = client_config.get('trainer_config', {})
        privacy_metrics_config = config.get('privacy_metrics_config', None)
        model_path = config["model_path"]

        strategy_algo = config['strategy']
        StrategyClass = select_strategy(strategy_algo)
        strategy = StrategyClass('client', config)
        print_rank(f'Client successfully instantiated strategy {strategy}', loglevel=logging.DEBUG)
        send_dicts = config['server_config'].get('send_dicts', False)

        begin = time.time()  
        client_stats = {}  

        data_strct = data_strcts[0]
        user = data_strct['users'][0]
        print_rank('Loading : {}-th client with name: {}, {} samples, {}s elapsed'.format(
            client_id[0], user, data_strct['num_samples'][0], time.time() - begin), loglevel=logging.INFO)

        # Get dataloaders
        train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=data_strct)

        # Instantiate the model object
        if model is None:
            model = make_model(
                model_config,
                dataloader_type=train_dataloader.__class__.__name__,
                input_dim=data_config['input_dim'],
                vocab_size=train_dataloader.vocab_size,
            )
        
        # Set model parameters
        n_layers, n_params = len([f for f in model.parameters()]), len(model_parameters)
        print_rank(f'Copying model parameters... {n_layers}/{n_params}', loglevel=logging.DEBUG)
        model = to_device(model)

        if send_dicts: # Send model state dictionary
            tmp = {}
            for param_key, param_dict in zip (model.state_dict(), model_parameters):
                tmp[param_key] = param_dict
            model.load_state_dict(tmp)
        else: # Send parameters
            for p, data in zip(model.parameters(), model_parameters):
                p.data = data.detach().clone().cuda() if torch.cuda.is_available() else data.detach().clone()
        print_rank(f'Model setup complete. {time.time() - begin}s elapsed.', loglevel=logging.DEBUG)


        # Fix parameters of layers
        if 'updatable_names' in trainer_config:
            set_component_wise_lr(model, client_config['optimizer_config'], trainer_config['updatable_names'])

        # Create the optimizer on the workers
        # NOTE: the server dictates the learning rate for the clients
        client_config['optimizer_config']['lr'] = initial_lr
        optimizer = make_optimizer(client_config['optimizer_config'], model)

        # Make the scheduled sampling scheduler
        ss_scheduler = None
        if 'ss_config' in client_config and client_config['ss_config'] is not None:
            ss_scheduler = ScheduledSamplingScheduler(model=model, **client_config['ss_config'])

        # Make the trainer
        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            ss_scheduler=ss_scheduler,
            train_dataloader=train_dataloader,
            server_replay_config =client_config,
            max_grad_norm=client_config['data_config']['train'].get('max_grad_norm', None),
            anneal_config=client_config['annealing_config'] if 'annealing_config' in client_config else None,
            num_skips_threshold=client_config['num_skips_threshold'] if 'num_skips_threshold' in client_config else -1,
            ignore_subtask=client_config['ignore_subtask']
        )

        if trainer.optimizer is not None:
            initial_optimizer_state = copy.deepcopy(trainer.optimizer.state_dict())

        annealing_config = client_config['annealing_config'] if 'annealing_config' in client_config else None

        assert 'desired_max_samples' in client_config['data_config']['train'], 'Missing \'desired_max_samples\' entry in data config parameter'
        desired_max_samples = client_config['data_config']['train']['desired_max_samples']

        if trainer.optimizer is not None:  # reset the optimizer state
            if initial_lr > 0:
                trainer.optimizer.param_groups[0].update({'lr': initial_lr})
            initial_optimizer_state = copy.deepcopy(trainer.optimizer.state_dict())
            trainer.reset_optimizer(initial_optimizer_state, annealing_config)

        # Mark the end of setup
        end = time.time()
        client_stats['setup'] = end - begin
        print_rank(f'Client setup cost {client_stats["setup"]}s', loglevel=logging.DEBUG)               
        begin_training = end
        
        # Training begins here
        trainer.model.train()
        trainer.model.zero_grad()

        # Save the client batches if we want to evaluate the privacy metrics
        apply_privacy_metrics = (False if privacy_metrics_config is None else privacy_metrics_config['apply_metrics'])

        # This is where training actually happens
        algo_payload = None

        if strategy_algo == 'FedLabels':
            datasets =[get_dataset(data_path, config, task, mode="train", test_only=False, data_strct=data_strcts[i], user_idx=0) for i in range(3)]
            algo_payload = {'strategy':'FedLabels', 'data': datasets, 'iter': iteration, 'config': semisupervision_config}
        elif strategy_algo == 'FedProx':
            algo_payload = {'strategy':'FedProx', 'mu': client_config.get('mu',0.001)}
        train_loss, num_samples, algo_computation = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics, algo_payload = algo_payload)
        print_rank('client={}: training loss={}'.format(client_id[0], train_loss), loglevel=logging.DEBUG)

        # Estimate gradient magnitude mean/var
        # Now computed when the sufficient stats are updated.
        assert 'sum' in trainer.sufficient_stats
        assert 'mean' in trainer.sufficient_stats
        
        trainer.train_loss = train_loss
        trainer.num_samples = num_samples
        trainer.algo_computation = algo_computation

        # Compute pseudo-gradient
        if not send_dicts:
            for p, data in zip(trainer.model.parameters(), model_parameters):
                data = to_device(data)
                p.grad = data - p.data

        payload = strategy.generate_client_payload(trainer) if send_gradients else None

        if config['server_config']['type'] == 'personalization':
            # Initialize convex weight alpha
            alpha= config['client_config'].get('convex_model_interp', 0.75)
            local_model = make_model(config['model_config'])
            train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=data_strct)
            local_optimizer = make_optimizer(client_config['optimizer_config'], local_model)

            # Make the trainer
            local_trainer = Trainer(
                model=local_model,
                optimizer=local_optimizer,
                ss_scheduler=ss_scheduler,
                train_dataloader=train_dataloader,
                server_replay_config=client_config,
                max_grad_norm=client_config['data_config']['train'].get('max_grad_norm', None),
                anneal_config=client_config['annealing_config'] if 'annealing_config' in client_config else None,
                num_skips_threshold=client_config[
                    'num_skips_threshold'] if 'num_skips_threshold' in client_config else -1,
                ignore_subtask=client_config['ignore_subtask']
            )

            local_model_name = os.path.join(model_path, user + '_model.tar')
            local_alpha_name = os.path.join(model_path, user + '_alpha')

            if os.path.exists(local_model_name):
                print_rank('Loading Local Model .. {}'.format(local_model_name))
                local_trainer.load(local_model_name, update_lr_scheduler=False, update_ss_scheduler=False)

            if os.path.exists(local_alpha_name):
                print_rank('Loading Alpha Weight .. {}'.format(local_model_name), loglevel=logging.INFO)
                alpha = torch.load(local_alpha_name)

            # Copy original model
            original_local_model = local_trainer.get_model()

            # Training begins here
            local_trainer.model.train()
            local_trainer.model.zero_grad()

            # Run Local Processing
            train_loss, num_samples = local_trainer.train_desired_samples(desired_max_samples=desired_max_samples,
                                                                          apply_privacy_metrics=False)
            print_rank('client={}, user:{}: LOCAL training loss={}'.format(client_id[0], user, train_loss), loglevel=logging.INFO)

            local_trainer.save(
                model_path=model_path,
                config=config,
                token=user)

            # Estimate the pseudo-gradient for local model
            for p, orig_param in zip(local_trainer.model.parameters(), original_local_model.parameters()):
                orig_param = orig_param.cuda() if torch.cuda.is_available() else orig_param
                p.grad = orig_param.data - p.data

            alpha= alpha_update(local_trainer.model, trainer.model, alpha, initial_lr)
            torch.save(alpha, local_alpha_name)
            local_trainer.model.zero_grad()


        # Mark that training (including post-processing) is finished
        end = time.time()
        client_stats['training'] = end - begin_training
        client_stats['full cost'] = end - begin
        print_rank(f'Client training cost {end - begin_training}s', loglevel=logging.DEBUG)      
        print_rank(f'Client full cost {end - begin}s', loglevel=logging.DEBUG)

        # Create dictionary that is sent back to server
        client_output = {
            'cs': client_stats, 
            'tl': train_loss, 
            'mg': trainer.sufficient_stats['mag'],
            'vg': trainer.sufficient_stats['var'],
            'ng': trainer.sufficient_stats['mean'],
            'rg': trainer.sufficient_stats['norm'],
            'ns': num_samples,
            'pl': payload,
        }
       
        # Apply privacy metrics
        if privacy_metrics_config and privacy_metrics_config['apply_metrics']:
            print_rank('Applying privacy metrics', loglevel=logging.DEBUG)

            privacy_stats = {'Dropped clients': 0}
            batches = trainer.cached_batches
            trainer.cached_batches = []
            gradients = extensions.privacy.unroll_network(model.named_parameters(), select_grad=True)[0]

            if privacy_metrics_config['apply_indices_extraction']:
                allowed_word_rank = privacy_metrics_config.get('allowed_word_rank', 9000)
                embed_dim, vocab_size = model_config['embed_dim'], model_config['vocab_size']
                overlap, indices = privacy_metrics.extract_indices_from_embeddings(gradients, batches, embed_dim, vocab_size)

                max_overlap =  privacy_metrics_config.get('max_allowed_overlap', None)
                if max_overlap is not None and overlap > max_overlap:
                    print_rank('Removing this client because we extracted {}% words and the maximum allowed is {}%'.format(overlap * 100, max_overlap * 100))
                    client_output['wt'] = 0.0
                    privacy_stats['Dropped clients'] = 1

                privacy_stats['Extracted indices percentage'] = overlap
                privacy_stats['Words percentage above ' + str(allowed_word_rank) + ' word rank'] = (indices > allowed_word_rank).mean() if len(indices)>0 else 0
          
            if privacy_metrics_config['apply_leakage_metric']:
                print_rank('Applying leakage metric', loglevel=logging.DEBUG)

                orig_params = {n: p for (n, _), p in zip(trainer.model.named_parameters(), model_parameters)}
                max_ratio = np.exp(privacy_metrics_config['max_leakage'])
                optim_config = privacy_metrics_config['attacker_optimizer_config']
                is_leakage_weighted = privacy_metrics_config['is_leakage_weighted']

                leakage = privacy_metrics.practical_epsilon_leakage(orig_params,
                    trainer.model, batches, is_leakage_weighted, max_ratio, optim_config)                
                print_rank('privacy leakage: {}'.format(leakage), loglevel=logging.DEBUG)

                max_leakage =  privacy_metrics_config.get('max_allowed_leakage', None)
                if max_leakage is not None and leakage > max_leakage:
                    print_rank('Removing this client because the information leakage/practical epsilon is {} and the maximum allowed is {}'.format(leakage, max_leakage))
                    client_output['wt'] = 0.0
                    privacy_stats['Dropped clients'] = 1

                privacy_stats['Practical epsilon (Max leakage)'] = leakage
            
            client_output['ps'] = privacy_stats

        client_output['ts'] = time.time()
        return client_output


================================================
FILE: core/config.py
================================================
# Note this import requires python 3.7+
# Do we want to commit to this?
from __future__ import annotations
from dataclasses import dataclass
from collections.abc import MutableMapping
from cerberus import Validator
from importlib.machinery import SourceFileLoader
from utils.utils import print_rank
from importlib.machinery import SourceFileLoader
import os


# TODO everywhere: choose reasonable defaults.
# TODO: decide where task should live as a setting, maybe its own TaskConfig
# TODO: docstrings everywhere

# TODO: Make ModelConfig a base class that different models inherit from
# We could specify the modelconfig class in the config file,
# like we do for model.py.  The current implementation mixes NLG and BERT

# TODO: DatasetConfig needs to be teased apart.
# The main issue is we have *_data, list_of_train_data, train_data_server.
# They all essentially perform the same function in different contexts.
# also some no-longer-used parameters are still present.

# TODO: it's not clear what MutableMapping methods need overrides- we
# could probably just use the default implementation.

# TODO: not all pytorch optimizers can handle amsgrad - we should
# have distinct subclasses for the different optimizers

def from_dict(cls, config):
    """
    Helper function to convert a dict to a class
    """
    return cls(**config)


class Config(MutableMapping):
    """Base class for configuration classes."""
    def get(self, k: str, default=None):
        result = getattr(self, k, default)
        if result is None:
            return default
        return result

    def lookup(self, s: str, default=None):
        toks = s.split('.')
        child = getattr(self, toks[0], default)
        if len(toks) == 1:
            return child if child is not None else default
        elif isinstance(child, Config):
            return child.lookup('.'.join(toks[1:]), default)
        else:
            return default

    def __getitem__(self, k):
        return getattr(self, k)

    def __setitem__(self, k, v):
        setattr(self, k, v)

    def __delitem__(self, k):
        delattr(self, k)

    def __iter__(self):
        return iter(self.__dict__)

    def __len__(self):
        return len(self.__dict__)

    def __contains__(self, k):
        return getattr(self, k, None) is not None

    def pop(self, k, default=None):
        result = self.get(k, default)
        if k in self:
            delattr(self, k)
        return result


@dataclass
class ModelConfig(Config):
    """Base class for Model configurations

The model configuration specifies model architecture, parameters, and initialization settings.

Attributes:
    model_type (str): The class name of the model to instantiate. eg GRU.

    model_folder (str): The relative path to the model.py file where model_type is defined. eg experiments/nlg_gru/model.py

    pretrained_model_path (str): The path to the pretrained model.  If None, the model will be randomly initialized using the method defined in weight_init.

"""
    model_type: str = None
    model_folder: str = None
    pretrained_model_path: str = None

    @staticmethod
    def from_dict(config) -> ModelConfig:
        """Searches the model folder for config.py and if it is found the model config 
        is initialized from the class [model_type]Config"""
        cfg_path = os.path.dirname("./" + str(config['model_folder'])) + '/config.py'
        if os.path.exists(cfg_path):
            loader = SourceFileLoader('config', cfg_path).load_module()
            config_class = config['model_type'] + 'Config'
            try:
                config_type = getattr(loader, config_class)
                return from_dict(config_type, config)
            except AttributeError:
                print_rank(f"Config class {config_class} not found in {cfg_path}")
                raise
        else:
            print_rank(f"Warning: couldn't find {cfg_path}, falling back to dictionary.")
            return config
            

@dataclass
class BERTModelConfig(Config):
    """BERT model configuration

The BERT configuration specifies huggingface-specific BERT model settings.

Attributes:
    model_name (str): The name of the BERT model.  eg bert-base-uncased.

    cache_dir (str): Tokenizer cache directory, will be created if it doesn't exist.

    use_fast_tokenizer (bool): Whether to use the fast tokenizer.

    mask_token (str): special token to use for masking.

    task (str): The task to use for BERT.  eg mlm.

    past_index (int): The index of the past state in the BERT model's state dict.

    prediction_loss_only (bool): if False, also produce metrics for predictions and labels.

    process_line_by_line (bool): if True, process the input line-by-line.

ToDo:
    * check how cache_dir is used- there's a risk of multiple processes reading/writing at the same time.
    * verify the meaning of past_index (thanks copilot)
    * document the difference when process_line_by_line is True vs False

    """
    model_name: str = None
    cache_dir: str = None
    use_fast_tokenizer: bool = False
    mask_token: str = '<mask>'
    task: str = 'mlm'
    past_index: int | None = -2
    prediction_loss_only: bool = False
    process_line_by_line: bool = False

    @staticmethod
    def from_dict(config) -> BERTModelConfig:
        return from_dict(BERTModelConfig, config)


@dataclass
class BERTTrainingConfig(Config):
    """BERT training configuration

    Configuration settings for BERT training.

    Attributes:
        seed (int): random seed for reproducibility.

        label_smoothing_factor (float): label smoothing factor.  Applied label smoothing when the factor is non-zero.

        batch_size (int): batch size.

        max_seq_length (int): maximum input sequence length.
    """
    seed: int | None = None
    label_smoothing_factor: float | None = None
    batch_size: int | None = None
    max_seq_length: int | None = None

    @staticmethod
    def from_dict(config) -> BERTTrainingConfig:
        return from_dict(BERTTrainingConfig, config)


@dataclass
class BERTConfig(Config):
    """BERT configuration
    Specifies the model and training configuration for huggingface modeling scenarios.

    Attributes:
        loader_type (str): loader type hint. eg 'text'

        model (BERTModelConfig): BERT model configuration.

        training (BERTTrainingConfig): BERT training configuration.
    """
    loader_type: str = None
    model: BERTModelConfig = None
    training: BERTTrainingConfig = None

    @staticmethod
    def from_dict(config) -> BERTConfig:
        result = BERTConfig()
        for k in config:
            if k == 'model':
                result.model = BERTModelConfig.from_dict(config[k])
            elif k == 'training':
                result.training = BERTTrainingConfig.from_dict(config[k])
            else:
                setattr(result, k, config[k])
        return result


@dataclass
class PrivacyConfig(Config):
    """Privacy configuration

    The privacy configuration specified differential privacy settings for the model.
    The user can choose between local or global DP.  When local DP is enabled, a global
    epsilon can be computed by applying the RDP accountant (see extensions/privacy).
    The `eps` parameter is used to specify the privacy budget for local DP.  Conversely, when
    global DP is enabled, `eps` is ignored and `global_sigma` directly specifies the global
    Gaussian noise.   `max_grad` specifies the clipping parameter for local or global DP,
    `max_weight` specifies the clipping parameter for the local gradient aggregation weight
    (applies to softmax aggregation), and `weight_scaler` indicates how the aggregation weight
    is scaled before noise addition, and unscaled afterward. This enables a single eps/sigma
    parameter for both the gradient and its weight.

    Example:
       This example applies local DP with eps=1000. The global epsilon will be computing using Renyi DP accounting.

       .. code-block:: yaml

            dp_config:
                # Local dp clips and adds noise on the client and centrally accumulates the privacy budget.
                enable_local_dp: true
                eps: 100 # epsilon
                max_grad: 0.008  # max gradient
                # The max_weight and min_weight should be already scaled by weight_scaler
                # Because we scale down the weight using weight_scalar -> clip -> add noise -> scale back up.
                max_weight: 0.0001
                weight_scaler: 0.0001
                min_weight: 0.00009


    Attributes:
        enable_local_dp (bool): whether to enable local DP.

        enable_global_dp (bool): whether to enable global DP.

        eps (float): the privacy budget for local DP.

        delta (float): the privacy delta parameter for local DP.

        global_sigma (float): the global Gaussian noise for global DP.

        max_grad (float): the gradient clipping parameter.

        max_weight (float): the aggregation weight clipping parameter.

        weight_scaler (float): the aggregation weight scaling parameter.

        min_weight (float): the minimum per-gradient aggregation weight.

    """
    enable_local_dp: bool = False
    enable_global_dp: bool = False
    eps: float | None = None
    delta: float | None = None
    global_sigma: float | None = None
    max_grad: float | None = None
    max_weight: float | None = None
    weight_scaler: float | None = None
    min_weight: float | None = None

    @staticmethod
    def from_dict(config) -> PrivacyConfig:
        return from_dict(PrivacyConfig, config)


@dataclass
class PrivacyMetricsConfig(Config):
    """Privacy metrics configuration

    This optional feature computes local privacy metrics for computed gradients,
    and optionally filters gradients based on estimated privacy loss.

    Attributes:
        apply_metrics (bool): whether to compute privacy metrics.

        apply_indices_extraction (bool): whether to attempt local data reconstruction.

        allowed_word_rank (int): threshold for successful reconstruction.

        apply_leakage_metric (bool): whether to compute a privacy leakage metric based on the ratio of perplexities before and after local training.

        max_leakage (float): the maximum allowed privacy leakage before filtering

        adaptive_leakage_threshold (float): if non-zero, compute an adaptive leakage threshold based on the previous round of training.  For example at 0.95, the max_leakage will be adjusted to reject 5% of gradients, based on the previous round of training.

        is_leakage_weighted (bool): scales the leakage by the maximum likelihood of the pre- and post- likelihood tensors. ie the worst-case leakage is weighted by the worst-case likelihood that we might encounter it.

        attacker_optimizer_config (OptimizerConfig): the optimizer configuration for the reconstruction attack.
    """
    apply_metrics: bool = False
    apply_indices_extraction: bool = False
    allowed_word_rank: int | None = None
    apply_leakage_metric: bool = False
    max_leakage: float | None = None
    max_allowed_leakage: float | None = None
    adaptive_leakage_threshold: float | None = None
    is_leakage_weighted: bool = False
    attacker_optimizer_config: OptimizerConfig = None

    @staticmethod
    def from_dict(config) -> PrivacyMetricsConfig:
        result = PrivacyMetricsConfig()
        for k in config:
            if k == 'attacker_optimizer_config':
                result.attacker_optimizer_config = \
                    OptimizerConfig.from_dict(config[k])
            else:
                setattr(result, k, config[k])
        return result


@dataclass
class OptimizerConfig(Config):
    """Optimizer configuration

    Pass any pytorch-supported optimizer configuration. The object should include
    a `type` field which indicates the pytorch optimizer type that should be invoked.
    This will be stripped from the object before being passed to the Optimizer's init.
    """
    type: str = None
    # Leave this open for any keyword arguments, so we don't break torch constructors
    # In the future we can limit keywords to torch-specific ones.
    # lr: float = 0.0
    # weight_decay: float = 0.0
    # amsgrad: bool = False

    @staticmethod
    def from_dict(config) -> OptimizerConfig:
        # needs its own from_dict so we can accomodate any fields
        result = OptimizerConfig()
        assert 'type' in config
        for k in config:
            setattr(result, k, config[k])
        return result


@dataclass
class AnnealingConfig(Config):
    """Learning rate annealing configuration


    Attributes:
        type (str): the type of annealing. Supported methods: :code:`step_lr`, :code:`multi_step_lr`, :code:`rampup-keep-expdecay-keep`, :code:`val_loss`.

        step_interval (str): the interval at which to step the learning rate. Supported intevals: :code:`epoch`, :code:`batch`.

        gamma (float): the learning rate decay factor.

        step_size (int): the interval between annealing operations.
    """
    type: str = None
    step_interval: str = None
    gamma: float | None = None
    step_size: int | None = None

    @staticmethod
    def from_dict(config) -> AnnealingConfig:
        return from_dict(AnnealingConfig, config)


@dataclass
class DatasetConfig(Config):
    # Common to all text (NLG, MLM) dataloaders
    batch_size: int | None = None
    loader_type: str = None
    prepend_datapath: bool = False
    num_workers: int | None = None
    desired_max_samples: int | None = None

    # Common to all client.train dataloaders
    list_of_train_data: str = None
    max_grad_norm: float | None = None  # propose moving max_grad_norm to client config

    # Common to all server.train dataloaders. What is the difference?
    train_data: str = None
    train_data_server: str = None

    # Common to server.test dataloaders
    test_data: str = None

    # Common to server.val dataloaders
    val_data: str = None

    # Specific to NLG dataloaders
    tokenizer_type: str = None  # Note tokenizer_type appears in NLG configs, but always set to 'not applicable'
    vocab_dict: str = None
    pin_memory: bool = False
    num_frames: int | None = None  # num_frames is missing from NLG server.test dataloader
    max_batch_size: int | None = None
    max_num_words: int | None = None
    unsorted_batch: int | None = None
    utterance_mvn: bool = False  # only present on NLG client.train dataloader

    # Specific to MLM dataloaders
    task: str = None
    mlm_probability: float | None = None
    tokenizer_type_fast: bool = False
    max_seq_length: int | None = None
    min_words_per_utt: int | None = None
    max_samples_per_user: int | None = None
    mask_token: str = None
    cache_dir: str = None

    @staticmethod
    def from_dict(config) -> DatasetConfig:
        return from_dict(DatasetConfig, config)


@dataclass
class DataConfig(Config):
    """Data configurations

    Client and server configs may each contain a data config, consisting of train, test, and validate datasets.
    A typical configuration will define test and validate in the server data config, while the training data is defined in the client config.
    Optionally, the server can have a training config which defines server-side training data.


    Attributes:
        train (DatasetConfig): the training dataset configuration.

        val (DatasetConfig): the validation dataset configuration.

        test (DatasetConfig): the test dataset configuration.
    """
    train: DatasetConfig = None
    val: DatasetConfig = None
    test: DatasetConfig = None

    @staticmethod
    def from_dict(config) -> DataConfig:
        train = DatasetConfig.from_dict(config['train']) if 'train' in config else None
        val = DatasetConfig.from_dict(config['val']) if 'val' in config else None
        test = DatasetConfig.from_dict(config['test']) if 'test' in config else None

        return DataConfig(
            train, val, test
        )


@dataclass
class ServerReplayConfig(Config):
    """Server replay configuration

    When server-side training data is defined, this config defines how it is applied after each client training round.

    Attributes:
        server_iterations (int): the number of iterations to run over server-side training data for.

        optimizer_config (OptimizerConfig): the optimizer configuration to use for the server.
    """
    server_iterations: int
    ignore_subtask: bool
    optimizer_config: OptimizerConfig

    @staticmethod
    def from_dict(config) -> ServerReplayConfig:
        return ServerReplayConfig(
            config['server_iterations'],
            config['ignore_subtask'],
            OptimizerConfig.from_dict(config['optimizer_config'])
        )


@dataclass
class RLConfig(Config):
    """Reinforcement learning configuration

    RL can be applied during dynamic gradient aggregation to speed up convergence. This configuration defines the settings for server-side RL to train the model for DGA.

    Attributes:
        marginal_update_RL (bool): whether to update the RL model when the loss is small.

        RL_path (str): the path to the RL model to train.

        RL_path_global (bool): whether the global training output path should be prepended to RL_path.

        model_descriptor_RL (str): string to append to the model filename.

        network_params (list): List of layer widths in the RL network. eg: 300,128,128,128,64,100

        initial_epsilon (float): the initial epsilon value for the epsilon-greedy policy.

        final_epsilon (float): the final epsilon value for the epsilon-greedy policy.

        epsilon_gamma (float): the decay rate for the epsilon-greedy policy.

        max_replay_memorize_size (int): the maximum number of samples to store in the replay memory.

        minibatch_size (int): the size of the minibatch to use for training.

        gamma (float): the discount factor for the RL model.

        optimizer_config (OptimizerConfig): the optimizer configuration to use for the RL model.

        annealing_config (AnnealingConfig): the annealing configuration to use for the RL model.


    """
    marginal_update_RL: bool = False
    RL_path: str = None
    RL_path_global: bool = False
    model_descriptor_RL: str = None
    network_params: list = None
    initial_epsilon: float | None = None
    final_epsilon: float | None = None
    epsilon_gamma: float | None = None
    max_replay_memory_size: int | None = None
    minibatch_size: int | None = None
    gamma: float | None = None
    optimizer_config: OptimizerConfig = None
    annealing_config: AnnealingConfig = None

    @staticmethod
    def from_dict(config) -> RLConfig:
        result = RLConfig()
        for k in config:
            if k == 'optimizer_config':
                result.optimizer_config = OptimizerConfig.from_dict(config[k])
            elif k == 'annealing_config':
                result.annealing_config = AnnealingConfig.from_dict(config[k])
            else:
                setattr(result, k, config[k])
        return result


@dataclass
class ServerConfig(Config):
    """Server configuration

    The server configuration defines the server-side settings.

    Attributes:
        resume_from_checkpoint (bool): whether to resume training from a checkpoint.

        max_iterations (int): the maximum number of iterations (federated training rounds) to run.

        num_clients (int): the number of clients to use per training round.

        optimizer_config (OptimizerConfig): the optimizer configuration to use server-side.

        annealing_config (AnnealingConfig): the learning rate annealing configuration to use server-side.

        val_freq (int): the number of iterations between validation evaluation runs.

        rec_freq (int): the number of iterations between test evaluation runs.

        initial_val (bool): whether to run validation before initiating training.

        initial_rec (bool): whether to run test before initiating training.

        wantRL (bool): whether to train the RL model.

        RL (RLConfig): the RL configuration to use if wantRL is True.

        data_config (DataConfig): the data configuration to use server-side.

        type (str): the type of server. Currently this parameter is ignored and OptimizationServer is always used. However there is some validation code that checks for one of the following values:

            - model_averaging
            - optimization
            - model_optimization
            - cluster_finetuning
            - cluster_parallel

        aggregate_median (str): the aggregation method to use (DGA softmax, or mean). Note that this only applies when the global aggregation strategy is DGA.

        weight_train_loss (str): when softmax DGA is enabled, what metric to use for weighting. One of

            - train_loss
            - mag_var_loss
            - mag_mean_loss

        softmax_beta (float): the beta value to use for the softmax DGA.

        max_weight (float): the maximum allowed client weight.

        initial_lr_client (float): the initial learning rate for each client.

        lr_decay_factor (float): the client learning rate decay factor.

        best_model_criterion (str): The metric to choose when resetting to the best model so far.

        server_replay_config (ServerReplayConfig): the server replay configuration to use for any server-side training.

    """
    resume_from_checkpoint: bool = False
    max_iteration: int | None = None
    num_clients_per_iteration: int | None = None
    optimizer_config: OptimizerConfig = None
    annealing_config: AnnealingConfig = None
    val_freq: int | None = None
    rec_freq: int | None = None
    initial_val: bool = True
    initial_rec: bool = True
    wantRL: bool = False
    RL: RLConfig = None
    data_config: DataConfig = None
    type: str = None
    aggregate_median: str = None
    weight_train_loss: str = None
    softmax_beta: float | None = None
    max_weight: float | None = None
    initial_lr_client: float | None = None
    lr_delay_factor: float | None = None
    best_model_criterion: str = 'loss'
    server_replay_config: ServerReplayConfig = None

    @staticmethod
    def from_dict(config) -> ServerConfig:
        result = ServerConfig()

        for k in config:
            if k == 'optimizer_config':
                result.optimizer_config = \
                    OptimizerConfig.from_dict(config[k])
            elif k == 'annealing_config':
                result.annealing_config = \
                    AnnealingConfig.from_dict(config[k])
            elif k == 'data_config':
                result.data_config = \
                    DataConfig.from_dict(config[k])
            elif k == 'server_replay_config':
                result.server_replay_config = \
                    ServerReplayConfig.from_dict(config[k])
            elif k == 'RL':
                result.RL = \
                    RLConfig.from_dict(config[k])
            else:
                setattr(result, k, config[k])
        return result


@dataclass
class ClientConfig(Config):
    """
    Client configuration

    The client configuration defines the client-side settings.

    Attributes:
        meta_learning (str): Set to 'basic'.  Currently ignored.

        stats_on_smooth_grad (bool): When true, gradient statistics are reset each round. Currently, it appears these statistics aren't used.

        ignore_subtask (bool): Used to determine which model loss to use. In most cases just set to False.

        num_skips_threshold (int): previously used to skip users, deprecated.

        copying_train_data (bool): has no effect.

        do_profiling (bool): whether to enable client-side profiling.

        data_config (DataConfig): the data configuration to use client-side.

        type (str): the type of client. Currently this parameter is ignored?

        meta_optimizer_config (OptimizerConfig): the optimizer configuration to use for meta-learning.

        optimizer_config (OptimizerConfig): the optimizer configuration to use for client-side training.

        annealing_config (AnnealingConfig): the learning rate annealing configuration to use client-side.
    """
    meta_learning: str = None
    stats_on_smooth_grad: bool = False
    ignore_subtask: bool = False
    num_skips_threshold: int | None = None
    copying_train_data: bool = False
    do_profiling: bool = False
    data_config: DataConfig = None
    type: str = None
    meta_optimizer_config: OptimizerConfig = None
    optimizer_config: OptimizerConfig = None
    annealing_config: AnnealingConfig = None

    @staticmethod
    def from_dict(config) -> ClientConfig:
        result = ClientConfig()
        for k in config:
            if k == 'data_config':
                result.data_config = DataConfig.from_dict(config[k])
            elif k == 'meta_optimizer_config':
                result.meta_optimizer_config = \
                    OptimizerConfig.from_dict(config[k])
            elif k == 'optimizer_config':
                result.optimizer_config = \
                    OptimizerConfig.from_dict(config[k])
            elif k == 'annealing_config':
                result.annealing_config = \
                    AnnealingConfig.from_dict(config[k])
            else:
                setattr(result, k, config[k])
        return result


@dataclass
class FLUTEConfig(Config):
    """
    FLUTEConfig represents the global configuration for a training job.

    Attributes:
        model_config (ModelConfig): the model configuration to use.

        dp_config (PrivacyConfig): differential privacy configuration.

        strategy (str): Aggregation strategy, eg DGA or FedAvg.

        server_config (ServerConfig): the server configuration to use.

        client_config (ClientConfig): the client configuration to use.

    """
    model_config: ModelConfig = None
    dp_config: PrivacyConfig = None
    privacy_metrics_config: PrivacyMetricsConfig = None
    strategy: str = None
    server_config: ServerConfig = None
    client_config: ClientConfig = None

    def validate(config):

        # Join paths in config file
        if config["server_config"]["wantRL"]:
            rl_path = config["server_config"]["RL"]["RL_path"]
            rl_path = os.path.join(config["output_path"],rl_path) if config["server_config"]["RL"].get("RL_path_global", True) \
                                                            else os.path.join(config["output_path"], config["experiment_name"],rl_path)

        if "pretrained_model_path" in config["model_config"]:
            config["model_config"]["pretrained_model_path"] = os.path.join(config["data_path"], config["model_config"]["pretrained_model_path"])

        for section in ["server_config", "client_config"]:
            for mode in ['test','val','train']:
                if mode in config[section]["data_config"] and "vocab_dict" in config[section]["data_config"][mode]:
                    config[section]["data_config"][mode]["vocab_dict"] = os.path.join(config['data_path'], config[section]["data_config"][mode]["vocab_dict"])
                
                # TODO: Remove BERT specific parameters
                if 'BERT' in config['model_config']:
                    if mode!= 'train':
                        config['server_config']['data_config'][mode]['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
                        config['server_config']['data_config'][mode]['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
                    else:
                        config['client_config']['data_config'][mode]['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
                        config['client_config']['data_config'][mode]['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
        return config

    @staticmethod
    def from_dict(config) -> FLUTEConfig:

        # Validate schema in config file
        schema = eval(open('./core/schema.py', 'r').read())
        v = Validator(schema)
        if not v.validate(config,schema):
            raise ValueError('Missing {} argumment in config file '.format(v.errors))
        
        # Normalize default values
        original_config = config
        config = v.normalized(config)

        for section in ['server_config', 'client_config']:
            for mode in config[section]['data_config'].keys():
                diff = config[section]['data_config'][mode].keys() - original_config[section]['data_config'][mode].keys()
                if len(diff) > 0:
                    print_rank("Assigning default values for: {} in [{}][{}][data_config]".format(diff, section, mode))
        
        dp_config = \
            PrivacyConfig.from_dict(config['dp_config']) \
            if 'dp_config' in config else None

        priv_metrics_config = \
            PrivacyMetricsConfig.from_dict(config['privacy_metrics_config']) \
            if 'privacy_metrics_config' in config else None

        strategy = config.get('strategy', 'DGA')

        return FLUTEConfig(
            ModelConfig.from_dict(config['model_config']),
            dp_config, priv_metrics_config, strategy,
            ServerConfig.from_dict(config['server_config']),
            ClientConfig.from_dict(config['client_config'])
        )


================================================
FILE: core/dataloader.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from torch.utils.data import DataLoader as PyTorchDataLoader
from abc import ABC

class BaseDataLoader(ABC, PyTorchDataLoader):
    '''This is a wrapper class for PyTorch dataloaders.'''

    def create_loader(self):
        '''Returns the dataloader'''
        return self

        
    


================================================
FILE: core/dataset.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from torch.utils.data import Dataset as PyTorchDataset
from abc import ABC, abstractmethod

class BaseDataset(ABC, PyTorchDataset):
    '''This is a wrapper class for PyTorch datasets.'''

    @abstractmethod
    def __init__(self,**kwargs):
        super(BaseDataset, self).__init__()
        
    @abstractmethod
    def __getitem__(self, idx, **kwargs):
        '''Fetches a data sample for a given key'''
        pass
    
    @abstractmethod
    def __len__(self):
        '''Returns the size of the dataset'''
        pass
    
    @abstractmethod
    def load_data(self,**kwargs):
        '''Wrapper method to read/instantiate the dataset'''
        pass


================================================
FILE: core/evaluation.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
In this file we define the functions for running
test and validation tasks inside the Server.
'''

import logging
import torch
import numpy as np

# Internal imports
import core.federated as federated
from core.client import Client
from utils import print_rank

# AzureML-related libs
from azureml.core import Run
run = Run.get_context()

class Evaluation():

    def __init__(self, config, model_path, process_testvalidate, idx_val_clients, idx_test_clients, single_worker):

        self.config = config
        self.model_path = model_path
        self.process_testvalidate = process_testvalidate
        self.server_type = config['server_config']['type']
        self.idx_val_clients = idx_val_clients
        self.idx_test_clients = idx_test_clients
        self.send_dicts = config['server_config'].get('send_dicts', False)
        self.single_worker = single_worker
        super().__init__()
    
    def run(self, eval_list, req, metric_logger=None):
        '''Run test/validation taks depending on the modes
        received in the eval_list.
        
        Args:
            eval_list (arr): Contains the tasks to run.
            req (dict): information for test/val tasks
            metric_logger (callback, optional): callback used for logging.
                Defaults to None, in which case AML logger is used.
        '''      
        
        self.worker_trainer = req['worker_trainer']
        if self.send_dicts:
            global_model_values = [self.worker_trainer.model.state_dict()[param_key].to(torch.device('cpu')) for param_key in self.worker_trainer.model.state_dict()]
        else:
            global_model_values = [p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()]

        if 'tmp_unsup' in req:
            unsup_values = req['tmp_unsup'].values()
            sup_values = req['tmp_sup'].values()
            semisupervision_inference = True
        else:
            semisupervision_inference = False

        save_model = False 
        
        if metric_logger is None:
            metric_logger = run.log

        for mode in eval_list:

            # Skipping validation round when RL is enabled
            if 'wantRL' in self.config['server_config'] and self.config['server_config']['wantRL'] and mode == "val":
                continue
            
            # Compute avg_loss and avg_acc
            self.metrics = self.run_distributed_inference(mode, global_model_values)
            req = self.initialize_req(req) if len(req) == 1 else req

            # Only if for semisupervision
            if semisupervision_inference:
                unsup_metrics = self.run_distributed_inference(mode, unsup_values)
                sup_metrics = self.run_distributed_inference(mode, sup_values)

                for key, value in unsup_metrics.items():
                    metric_logger(str("Unsup" +mode + " " + key).capitalize(), value['value'])
                    print_rank('LOG UNSUP: {}_{}={}'.format(mode, key, value['value']))
                
                for key, value in sup_metrics.items():
                    metric_logger(str("Sup" + mode + " " + key).capitalize(), value['value'])
                    print_rank('LOG SUP: {}_{}={}'.format(mode, key, value['value']))

            # Log metrics
            for key, value in self.metrics.items():
                metric_logger(str(mode + " " + key).capitalize(), value['value'])
                print_rank('LOG: {}_{}={}: best_{}_{}={}'.format(mode, key, value['value'], mode, key, req[str("best_"+ mode + "_" + key)]))

            for key,value in self.metrics.items():
                attr = str("best_"+ mode + "_" + key)
                if value['higher_is_better']:
                    if self.metrics[key]['value'] > req[attr]: 
                        req[attr] = self.metrics[key]['value']
                        save_model = True
                else:
                    if self.metrics[key]['value'] < req[attr]:
                        req[attr] = self.metrics[key]['value']
                        save_model = True
                
                if save_model and mode == 'val':
                    self.worker_trainer.save(
                        model_path=self.model_path,
                        token=str('best_'+ mode +'_'+key),
                        config=self.config['server_config']
                    )
                    save_model = False
        
        return req
    
    def initialize_req(self, req):
        '''Update the keys, to have the same as metrics dictionary. This 
        function is only used during itr=0 for initializing the req 
        dictionary. 

        Args:
            req (dict): Best results for all the metrics (e.g. best_val_acc).
        '''
        for mode in ['test','val']:
            for key in self.metrics.keys():
                attr = "best_"+ mode + "_" + key 
                req[attr] = -1.0 if self.metrics[key]['higher_is_better'] else float('inf')

        return req

    def run_distributed_inference(self, mode, model):
        '''Call `run_distributed_evaluation` specifically for test or validation.
        
        This is just a helper function that fetches the clients depending on
        the mode and calls `run_distributed_evaluation` using that list.

        Args:
            mode (str): `test` or `val`.
        '''
        if mode == 'val':
            clients = self.idx_val_clients
        elif mode == 'test':
            clients = self.idx_test_clients
        else:
            raise NotImplementedError('Unsupported mode: {}'.format(mode))

        return self.run_distributed_evaluation(mode, clients, model)

    def run_distributed_evaluation(self, mode, clients, model):
        '''Perform evaluation using available workers.

        See also `process_test_validate` on federated.py.

        Args:
            mode (str): `test` or `val`.
            clients (list): clients for test/val round.
        '''

        total = 0
        self.logits = {'predictions': [], 'probabilities': [], 'labels': []}
        server_data = (0.0, model, 0)
        for result in self.process_testvalidate(clients, server_data, mode, self.single_worker):
            output, metrics, count = result
            val_metrics =  {key: {'value':0, 'higher_is_better': False} for key in metrics.keys()} if total == 0 else val_metrics
 
            for key in val_metrics:
                val_metrics[key]['value'] += metrics[key]['value']* count
                val_metrics[key]['higher_is_better'] = metrics[key]['higher_is_better']
            total+= count
            
            if output is not None:
                self.logits['predictions'].append(output['predictions'])
                self.logits['probabilities'].append(output['probabilities'])
                self.logits['labels'].append(output['labels'])

        if  self.logits['probabilities'] and self.logits['predictions'] and self.logits['labels']:
            self.logits['predictions'] = np.concatenate(self.logits['predictions'])
            self.logits['probabilities'] = np.concatenate(self.logits['probabilities'])
            self.logits['labels'] = np.concatenate(self.logits['labels'])

        
        for key in val_metrics:
                val_metrics[key]['value'] = val_metrics[key]['value']/total
            
        self.losses = [val_metrics['loss']['value'], val_metrics['acc']['value']] # For compatibility with Server
        return val_metrics

def make_eval_clients(dataset, config):
    '''Generator that yields clients for evaluation, continuously.

    Args:
        dataset (torch.utils.data.Dataset): used to get client's data
        config (dict): used for the client's constructor
    '''

    total = sum(dataset.num_samples)
    clients = federated.size() - 1 if federated.size()>1 else federated.size()
    delta = total / clients + 1
    threshold = delta
    current_users_idxs = list()
    current_total = 0

    if config["server_config"]["type"] == "personalization":  
        for i in range(len(dataset.user_list)):
            yield Client([i], config, False)
    else:
        for i in range(len(dataset.user_list)):
            current_users_idxs.append(i)
            count = dataset.num_samples[i]
            current_total += count
            if current_total > threshold:
                print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
                yield Client(current_users_idxs, config, False)
                current_users_idxs = list()
                current_total = 0

        if len(current_users_idxs) != 0:
            print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
            yield Client(current_users_idxs, config, False)


================================================
FILE: core/federated.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import cProfile
import logging
import threading 

import torch
import torch.distributed as dist
import numpy as np

from core.client import Client
from utils import (
    print_rank,
    print_profiler,
    to_device,
)

COMMAND_UPDATE = 0
COMMAND_TRAIN = 1
COMMAND_TERMINATE = 10
COMMAND_TESTVAL = 11
COMMAND_SYNC_NODES = 9
GLOBAL_MESSAGE = None

def encode_string(word, string_to_int = True):
    """ Encodes/Decodes the dictionary keys into an array of integers to be sent 
    as tensors of the same shape during NCCL/Gloo P2P communication.
    
    Args:
            word (string/array): key to be encoded/decoded.
            string_to_int (bool): flag that indicates which action to perform.
    """

    if string_to_int: # encode
        word = word.ljust(8, ' ') if len(word) < 8 else word # padding -- 8 is max length, all tensors must have the same size during communication
        word_encoded = [letter for letter in word.encode()]
        return word_encoded
    else: #decode
        cleanup_array = [letter for letter in word if letter!= 32] # Remove padding
        word_decoded = bytes(cleanup_array).decode()
        return word_decoded

def rank():
    """ Return rank of node. """
    return int(os.environ['RANK'])

def local_rank():
    """ Return local rank of node. """
    return int(os.environ['LOCAL_RANK'])

def size():
    """ Returns number of nodes in the distributed group, including server. """
    return int(os.environ['WORLD_SIZE'])

def _recv(x, src=0):
    """ Receives tensors with a single element or a list of tensors 
    with the same shape during distributed communication. """

    x = torch.tensor(x) if torch.is_tensor(x) == False else x
    x = to_device(x)
    dist.recv(tensor=x, src=src)
    x.to('cpu')
    
    try:
        return x.item() # single element
    except:
        return x.tolist() # list of tensors

def _recv_gradients(src):
    """ Receives a list of tensors with different shape during 
    distributed communication. """

    n, n_dimensions, grads = 0, 0, [] # tensors intialization -- required by torch.
    n = _recv(n,src)
    for i in range(n):
        n_dimensions = _recv(n_dimensions,src)
        dimensions = [0 for i in range(n_dimensions)]
        dimensions = _recv(dimensions, src)
        print_rank(f"Received dimensions {dimensions}", loglevel=logging.DEBUG)
        param = to_device(torch.zeros(dimensions))
        print_rank(f"Shape assigned {param.shape}", loglevel=logging.DEBUG)
        dist.recv(param,src)
        grads.append(param.detach().cpu())
    torch.cuda.empty_cache() 
    return grads

def _send(x, dst=0):
    """ Send tensors with a single element or a list of tensors 
    with the same shape during distributed communication. """
    x = torch.tensor(x)
    x = to_device(x)
    dist.send(x, dst)
    del x 
    torch.cuda.empty_cache()

def _send_metrics(output):
    """ Organize the keys and values from the resulting dictionary 
    from test/val rounds into arrays that are sent as independent 
    tensors during distributed communication. """

    keys = [encode_string(key) for key in output.keys()]
    values = [float(output[key]['value']) for key in output.keys()]
    higher_is_better = [int(output[key]['higher_is_better']) for key in output.keys()] # send the boolean as int

    _send(len(keys),0) 
    _send(keys)
    _send(values)
    _send(higher_is_better)

def _send_gradients(gradients, dst):
    """ Send a list of tensors with different shape during 
    distributed communication. """

    _send(len(gradients), dst)
    for i in gradients:
        dimensions = [int(d) for d in i.shape]
        _send(len(dimensions),dst)
        _send(dimensions,dst)
        param = to_device(i)
        dist.send(param,dst)
        del param 
        torch.cuda.empty_cache()

def _send_train_output(output):
    """ Organize the keys and values from the the returning ´client_output´ 
    dictionary in ´Client.proces_round()´ function during training rounds,
    into arrays that are sent as independent tensors during distributed 
    communication. """

    cs_values = [float(cs_v) for cs_v in output['cs'].values()] # cs dict -- values are flatten in 1d array
    pl_values = [float(output['pl']['weight'])] # pl dict
    gradients = output['pl']['gradients'] # gradients are sent independently

    if len(output.keys()) > 9: # DP metrics
        ps_values = [float(ps_v) for ps_v in output['ps'].values()]
        values = cs_values + [float(output[key]) for key in output.keys() if key not in ['cs','pl','ps']] + pl_values + ps_values # reorganizing values in the order expected by the Server
    else:
        values = cs_values + [float(output[key]) for key in output.keys() if key not in ['cs','pl']] + pl_values # reorganizing values in the order expected by the Server
    
    # Send data
    _send(int(len(output.keys())),0) # Warn for number of keys
    _send(values, 0)
    _send_gradients(gradients, 0)

def build_grads_dict(node):
    """ Reconstruct the dictionary ´client_output´ returned by 
    ´Client.proces_round()´ function on the Server side during 
    distributed communication. """

    # Initialize tensors
    n_keys = 0
    n_keys = _recv(n_keys,node)
    print(n_keys)

    if n_keys == 9:
        keys = ['cs','tl','mg','vg','ng','rg','ns','ts','pl']
        values = [0.0 for i in range(11)] # initializing tensor shape -- 11 is fixed number of keys expected
    elif n_keys == 10:
        keys = ['cs','tl','mg','vg','ng','rg','ns','ts','pl','ps']
        values = [0.0 for i in range(15)] # When the privacy metrics are enabled
    elif n_keys == 11:
        keys = ['cs','tl','mg','vg','ng','rg','ns','wt','ts','pl','ps']
        values = [0.0 for i in range(16)] # When the privacy metrics are enabled
    
    # Read data
    values = _recv(values,node)
    grads = _recv_gradients(node)
    
    cs_values = [{key: values.pop(0) for key in ['setup','training','full cost']}] # recreating cs dict
    # Rebuilding original dictionary
    if n_keys == 9:
        pl_values = [{'weight':values.pop(), 'gradients': grads}] # recreating pl dict
        values_list = cs_values + [values.pop(0) for i in range(7)] + pl_values # 7 is fixed length for remaining items
    else:
        ps_values = [{key: values.pop() for key in ['Practical epsilon (Max leakage)','Words percentage above 9000 word rank','Extracted indices percentage','Dropped clients']}]
        pl_values = [{'weight':values.pop(), 'gradients': grads}] # recreating pl dict
        values_list = cs_values + [values.pop(0) for i in range(len(values))] + pl_values + ps_values

    result = dict(zip(keys,values_list))

    # Cast values to original type
    for key in ['mg','vg','ng','rg']:
        result[key] = np.float32(result[key])
    result['ns'] = int(result['ns'] )
                
    return result

def build_metrics_dict(node):
    """ Reconstruct the dictionary returned during test/val rounds
    on the Server side during distributed communication. """

    # Initialize tensors
    n = 0
    n = _recv(n,node)
    keys = [[0 for j in range(8)] for i in range(n)] # max_seq_len for metric name is 8
    values = [0.0 for i in range(n)]
    higher_is_better = [0 for i in range(n)]

    # Read data
    keys = _recv(keys,node)
    values = _recv(values,node)
    higher_is_better = _recv(higher_is_better,node)

    # Reorganize output + decode dict keys
    orig_keys = [encode_string(key, string_to_int=False) for key in keys]
    values_dict = [{'value': float(v), 'higher_is_better': bool(higher_is_better[i])} for i, v in enumerate(values)]
    metrics = dict(zip(orig_keys,values_dict))
    num_instances = int(metrics.pop('num')['value'])

    result = None, metrics, num_instances
            
    return result

def receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes):
    """ Receives the clients output on the Server side in async/sync mode. 
    Asynchronous mode is only enabled when using NCCL backend given that Gloo 
    does not provide native non-blocking implementation to check if the operation 
    has been completed during distributed training"""

    if dist.get_backend() == "nccl": # Async
        for node, req in node_request_map:
            if req.is_completed():
                result = build_metrics_dict(node) if command == COMMAND_TESTVAL else build_grads_dict(node)
                results_list.append(result)
                free_nodes.append(node)
                node_request_map.remove((node,req))
                print_rank(f"Finished releasing the nodes {free_nodes}", loglevel=logging.DEBUG)
    else: # Sync
        print_rank(f"Waiting for a workers", loglevel=logging.DEBUG)
        gather_objects = [(None,None,None) for i in range(size())]
        output = [None for _ in gather_objects]
        dist.all_gather_object(output, gather_objects[rank()])
        print_rank(f" All workers have finished ... taking the remaining clients {len(output)}", loglevel=logging.DEBUG)
        output = [e for i,e in enumerate(output) if i not in idle_nodes ] # Cleanup for idle workers
        results_list = results_list + output[1:]
        free_nodes = list(range(1, size()))
    
    return node_request_map, results_list, free_nodes

def append_async_requests(node_request_map, node):
    """ Appends the asynchronous request sent to each worker during 
    asynchronous training. """

    ack = to_device(torch.tensor(1))
    req = dist.irecv(tensor=ack, src=node)
    node_request_map.append((node,req))
    return node_request_map

def sync_idle_nodes(client_queue, free_nodes):
    """ Request dummy outputs to the odd (idle) nodes during synchronous training
    to prevent them to get trapped in the state of the previous iterations """

    idle_nodes = []
    if len(client_queue) == 0:
        print_rank(f"Free idle nodes {len(free_nodes)}", loglevel=logging.DEBUG)
        while len(free_nodes) > 0:
            node = free_nodes.pop()
            idle_nodes.append(node)
            _send(COMMAND_SYNC_NODES, node)
    return idle_nodes

class Server:
    """Server object responsible for orchestration and aggregation.

    The Server is one of the two objects that may exist inside of a thread, all
    throughout its execution (the other being the Worker). At every round, the
    Server samples clients ids and send their data for an available Worker to process.
    The Workers then each produce a new model, and all models are sent to the Server
    for aggregation.

    The methods defined here are related to orchestration only, the aggregation
    will be done by a different object which inherits from this one.

    Notes:
        This class has no :code`__init__` method, and all its methods are static.
        It thus only serves the purpose of grouping the methods, but nothing
        is actually stored inside of the object.
    """
    @staticmethod
    def dispatch_clients(clients, server_data, command, mode=None, do_profiling=False, single_worker=None):
        """Perform the orchestration between Clients and Workers.

        This function does the following:
            1. It sends the server_data to all workers
            2. For each available Worker:
                2a. It sends the index of the client to instantiate
                2c. It triggers the execution of the command on the
                    Client.
            3. Collect and return all client outputs.

        Notes:
            This function yields the gradients of different clients
            as they are received. Therefore, the order of the results generally
            does not correspond to the order of the clients.

            All commands used during Server-Worker communication must be 
            float/integers given that torch.distributed only allows to
            send/recv tensors.

        Args:
            clients (list): list of clients to be processed.
            server_data (dict): server data sent to the workers and passed to
                clients, typically includes the global model at that step.
            command (int): instruction for worker to execute on the Client.
            mode (int): test/val only provided during evaluation rounds.
            do_profiling (bool): enables profiler during comunication.
        
        Returns:
            Generator of results.
        """
        # Single GPU flag
        single_gpu = True if size()==1 else False
        print_rank(f"Single GPU flag Server: {single_gpu}", loglevel=logging.DEBUG)

        # Some cleanup
        torch.cuda.empty_cache()
        torch.cuda.synchronize() if torch.cuda.is_available() else None

        # Initialize communication profiler
        profiler = None
        if do_profiling:
            profiler = cProfile.Profile()
            profiler.enable()

        # Update lr + model parameters each round for all workers
        lr, model_params, nround = server_data
        if not single_gpu:
            for worker_rank in range(1, size()):
                _send(COMMAND_UPDATE, worker_rank)
                _send(lr,worker_rank)
                _send_gradients(model_params, worker_rank)
                _send(float(nround),worker_rank)
                print_rank(f"Finished sending lr {lr} and n_params {len(model_params)} to worker {worker_rank} - round {nround}", loglevel=logging.DEBUG)
                print_rank(f"Finished sending server_data to workers", loglevel=logging.DEBUG)
        
            client_queue = clients.copy()
            print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
            free_nodes = list(range(1, size()))
            results_list, node_request_map = [], []

            # Initiate computation for all clients
            while client_queue:
                print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
                assert len(free_nodes) > 0
                node = free_nodes.pop()
                index = len(client_queue)-1
                client_to_process = client_queue.pop(index) 
                print_rank(f"Sending client {index} to worker {node}", loglevel=logging.DEBUG)
                _send(command, node) # The command should indicate the worker which function to run on the client

                if command == COMMAND_TESTVAL:
                    _send(mode,node) # Only for test/val has a value
                    _send(index, node) # Worker receives the index of the client to pop
                elif command == COMMAND_TRAIN:
                    _send(client_to_process, node)
                print_rank(f"Finished assigning worker {node}, free nodes {free_nodes}", loglevel=logging.DEBUG)

                if dist.get_backend() == "nccl":
                    append_async_requests(node_request_map, node)
                    idle_nodes = None
                else:
                    idle_nodes = sync_idle_nodes(client_queue, free_nodes)
    
                # Waits until receive the output from all ranks
                if not free_nodes:
                    print_rank(f"Waiting for a workers, free nodes {free_nodes}, reqs_lst {node_request_map}", loglevel=logging.DEBUG)
                    while len(free_nodes) == 0:
                        node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)
                        for output in results_list:
                            yield output
                        results_list = []

            # Wait for all workers to finish
            while (len(node_request_map)) != 0:
                node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)

                for output in results_list:
                    yield output
                results_list = []
        else:
            # For a single-GPU execution, there is no P2P communication in the same GPU. Using threats to coordinate.
            
            global GLOBAL_MESSAGE
            GLOBAL_MESSAGE = server_data

            if command == COMMAND_TESTVAL:
                t1 = threading.Thread(target=single_worker.trigger_evaluate)
                t1.start()
                t1.join()
                yield GLOBAL_MESSAGE
            elif command == COMMAND_TRAIN:
                total_clients = clients.copy()
                
                for client_id in total_clients:
                    GLOBAL_MESSAGE = lr, model_params, nround, client_id
                    t1 = threading.Thread(target=single_worker.trigger_train)
                    t1.start()
                    t1.join()
                    result = GLOBAL_MESSAGE
                    yield result

        if do_profiling:
            profiler.disable()
            print_profiler(profiler)

        # Some cleanup
        torch.cuda.empty_cache()
        torch.cuda.synchronize() if torch.cuda.is_available() else None

    @staticmethod
    def process_clients(clients, server_data, single_worker):
        """Ask workers to perform training on Clients.

        Args:
            clients (list): list of clients indexes sampled by ´Server.py´ 
                            object per iteration.
            server_data (dict): dictionary containing model.

        Returns:
            Generator of results.
        """
        return Server.dispatch_clients(clients, server_data, COMMAND_TRAIN, single_worker=single_worker)

    @staticmethod
    def process_testvalidate(clients, server_data, mode, single_worker):
        """Ask workers to perform test/val on Clients.

        Args:
            clients (list): list of clients indexes for test/val rounds.
            server_data (dict): dictionary containing model.
            mode (str): test/val.

        Returns:
            Generator of results.
        """

        mode = [-2] if mode == "test" else [2]
        return Server.dispatch_clients(clients, server_data, COMMAND_TESTVAL, mode, single_worker=single_worker)

    @staticmethod
    def terminate_workers(terminate=True):
        """Terminate the execution of the workers."""

        if terminate:
            print_rank("Terminating worker processes")
            for worker_rank in range(1, size()):
                _send(COMMAND_TERMINATE, worker_rank)

class Worker:
    """Worker object responsible for instantiate Clients based on incoming data
    from the Server and perform train/eval functions on it.

    Each worker lives on a different NCCL/Gloo thread and is assigned to a different
    GPU. Via the :code:`dispatch_clients` function, the Server passes to the
    Worker specific instructions to process clients' data, typically in order
    to generate a new model or to compute metrics.

    Attributes:
        model (torch.nn.Module): model being trained.
        data_path (str): path where all clients' data is located.
        do_profiling (bool): if True, analyzes execution in depth.
        val_clients (list): clients list for validation rounds.
        test_clients (list): clients list for testing rounds.
        config (dict): clients configuration.
        val_dataset (torch.utils.data.Dataset): validation dataset.
        test_dataset (torch.utils.data.Dataset): testing dataset.
    """
    def __init__(self, model=None, data_path=None, do_profiling=False, val_clients= None, \
                test_clients=None, config=None, val_dataset = None, test_dataset = None):

        self.model = model
        self.data_path = data_path
        self.do_profiling = do_profiling
        self.config = config
        self.val_clients = val_clients
        self.test_clients = test_clients
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

    def run(self):
        """Main loop executed by worker nodes.
        
        This method handles the NCCL/Gloo communication between the worker and
        the server. It keeps listening for commands from the Server,
        and performs different actions on the Client assigned depending on 
        the command received.
        """
        # Single GPU flag
        single_gpu = True if size()==1 else False
        print_rank(f"Single GPU flag Client: {single_gpu}", loglevel=logging.DEBUG)
    
        if not single_gpu:
            while True:  # keeps listening for incoming server calls

                # Initialize tensors -- required by torch.distributed
                command, client_idx, mode = 0, 0, 0  # int
                lr, nround = torch.zeros(1), torch.zeros(1) # float

                # Read command
                command = _recv(command)
                print_rank(f"Command received {command} on worker {rank()}", loglevel=logging.DEBUG)

                # Receive server data -- lr, model_params
                if command == COMMAND_UPDATE:
                    print_rank(f"COMMMAND_UPDATE received {rank()}", loglevel=logging.DEBUG)                
                    lr = _recv(lr, 0)
                    model_params = _recv_gradients(0)
                    nround = _recv(nround, 0)
                    server_data = (lr, model_params, int(nround))
                    print_rank(f"Received lr: {lr} and n_params: {len(model_params)} - round {nround}", loglevel=logging.DEBUG)
                    
                elif command == COMMAND_TRAIN:
                    print_rank(f"COMMMAND_TRAIN received {rank()}", loglevel=logging.DEBUG)
                    
                    # Init profiler in training worker
                    profiler = None
                    if self.do_profiling:
                        profiler = cProfile.Profile()
                        profiler.enable()
                                    
                    # Receive client id from Server
                    client_idx = _recv(client_idx)
                    print_rank(f"Cliend idx received from Server: {client_idx}", loglevel=logging.DEBUG)

                    # Instantiate client
                    client_to_process = Client(
                            [client_idx],
                            self.config,
                            self.config['client_config']['type'] == 'optimization') 
                    
                    # Execute Client.get_data()
                    client_data = client_to_process.get_client_data()

                    # Execute Client.process_round()
                    output = client_to_process.process_round(client_data, server_data, self.model, self.data_path)

                    # Send output back to Server
                    if dist.get_backend() == "nccl":
                        # ASYNC mode -- enabled only for nccl backend
                        ack = to_device(torch.tensor(1))
                        dist.isend(tensor=ack, dst=0)
                        _send_train_output(output)
                    else:
                        # SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
                        gather_objects = [output for i in range(size())]
                        output = [None for _ in gather_objects]
                        dist.all_gather_object(output, gather_objects[rank()])

                    # Some cleanup
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize() if torch.cuda.is_available() else None

                    if self.do_profiling:
                        profiler.disable()
                        print_profiler(profiler)

                elif command == COMMAND_TESTVAL:
                    print_rank(f"COMMMAND_TESTVAL received {rank()}", loglevel=logging.DEBUG)

                    # Init profiler in validation worker
                    profiler = None
                    if self.do_profiling:
                        profiler = cProfile.Profile()
                        profiler.enable()
                    
                    # Receive mode and client id from Server
                    mode = _recv(mode)
                    mode = "test" if mode == -2 else "val"
                    client_idx = _recv(client_idx)
                    print_rank(f"Client idx received from Server: {client_idx}, {mode}", loglevel=logging.DEBUG)
                    
                    # Get client and dataset
                    clients = self.val_clients if mode == "val" else self.test_clients
                    dataset = self.val_dataset if mode == "val" else self.test_dataset
                    clients_queue = clients.copy()
                    assert 0 <= client_idx < len(clients_queue)
                    client_to_process = clients_queue.pop(client_idx)

                    # Execute Client.get_data()
                    client_data = client_to_process.get_client_data(dataset)
    
                    # Execute Client.run_testvalidate()
                    output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model)

                    # Send output back to Server
                    if dist.get_backend() == "nccl":
                        # ASYNC mode -- enabled only for nccl backend
                        _, metrics, num_instances = output
                        metrics['num']= {'value': float(num_instances), 'higher_is_better': False}
                        output = metrics
                        print_rank(f"Worker {rank()} output {output}", loglevel=logging.DEBUG)
                        ack = to_device(torch.tensor(1))
                        dist.isend(tensor=ack, dst=0)
                        _send_metrics(output)
                    else:
                        # SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
                        gather_objects = [output for i in range(size())]
                        output = [None for _ in gather_objects]
                        dist.all_gather_object(output, gather_objects[rank()])
                        print_rank(f"Worker {rank()} sent output back", loglevel=logging.DEBUG)

                    # Some cleanup
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize() if torch.cuda.is_available() else None

                    if self.do_profiling:
                        profiler.disable()
                        print_profiler(profiler)

                elif command == COMMAND_TERMINATE:
                    print_rank(f"COMMMAND_TERMINATE received {rank()}", loglevel=logging.DEBUG)

                    # Some cleanup
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize() if torch.cuda.is_available() else None
                    return

                elif command == COMMAND_SYNC_NODES: # Only for sync calls
                    print_rank(f"COMMMAND_SYNC_NODES received {rank()}", loglevel=logging.DEBUG)

                    gather_objects = [None for i in range(size())]
                    output = [None for _ in gather_objects]
                    dist.all_gather_object(output, gather_objects[rank()])
                    print_rank(f"Worker IDLE {rank()} sent dummy output back", loglevel=logging.DEBUG)

                    # Some cleanup
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize() if torch.cuda.is_available() else None
                else:
                    assert False, "unknown command"

    def trigger_evaluate(self):
        global GLOBAL_MESSAGE

        lr, model_params, nround = GLOBAL_MESSAGE
        server_data = (lr, model_params, int(nround))
        mode = "val"

        # Get client and dataset
        clients = self.val_clients if mode == "val" else self.test_clients
        dataset = self.val_dataset if mode == "val" else self.test_dataset
        clients_queue = clients.copy()
        client_to_process = clients_queue.pop()

        # Execute Client.get_data()
        client_data = client_to_process.get_client_data(dataset)

        # Execute Client.run_testvalidate()
        output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model)
        _, metrics, num_instances = output
        metrics['num']= {'value': float(num_instances), 'higher_is_better': False}
        GLOBAL_MESSAGE = (_, metrics, num_instances)

        # Some cleanup
        torch.cuda.empty_cache()
        torch.cuda.synchronize() if torch.cuda.is_available() else None
    
    def trigger_train(self):
        global GLOBAL_MESSAGE
        lr, model_params, nround, client_idx = GLOBAL_MESSAGE
        server_data = (lr, model_params, int(nround))

        # Instantiate client
        client_to_process = Client([client_idx], self.config, self.config['client_config']['type'] == 'optimization') 
    
        # Execute Client.get_data()
        client_data = client_to_process.get_client_data()

        # Execute Client.process_round()
        GLOBAL_MESSAGE = client_to_process.process_round(client_data, server_data, self.model, self.data_path)

        # Some cleanup
        torch.cuda.empty_cache()
        torch.cuda.synchronize() if torch.cuda.is_available() else None

================================================
FILE: core/metrics.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
In this file we define the wrapper class for 
implementing metrics.
'''
import logging

import numpy as np
import torch

from utils import print_rank

class Metrics():

    def __init__(self):
        super().__init__()

    def compute_metrics(self,dataloader, model):
        '''This method is called by ´run_validation_generic´ function 
        inside trainer.py .
        
        This is just a helper function that computes the metrics returned 
        in the inference function inside ´model.py´.
        '''
        print_rank("Computing metrics")
        return self.call_inference(dataloader,model)

    def call_inference(self, dataloader, model):
        
        metrics, sum_metrics = dict(), dict()
        output_tot = {"probabilities": [], "predictions": [], "labels":[]}
        counter = 0

        with torch.no_grad():
            for _, batch in enumerate(dataloader):
                val_loss = model.loss(batch).item()
                inf_results = model.inference(batch)
                inf_results ['loss'] = {'value': val_loss,'higher_is_better': False}
                output = inf_results.pop('output')
                batch_size = inf_results.pop('batch_size')

                for key in inf_results.keys():
                    if not isinstance(inf_results[key], dict):
                        inf_results[key] = {'value':inf_results[key],'higher_is_better': True}
                    sum_metrics[key] = [] if not key in sum_metrics else sum_metrics[key]

                if isinstance(output, dict):
                    output_tot["probabilities"].append(output["probabilities"])
                    output_tot["predictions"].append(output["predictions"])
                    output_tot["labels"].append(output["labels"])

                for q in inf_results.keys():
                    sum_metrics[q].append(inf_results[q]['value']* batch_size)
                counter += batch_size
                torch.cuda.empty_cache()

        output_tot["probabilities"] = np.concatenate(output_tot["probabilities"]) if output_tot["probabilities"] else []
        output_tot["predictions"] = np.concatenate(output_tot["predictions"]) if output_tot["predictions"] else []
        output_tot["labels"] = np.concatenate(output_tot["labels"]) if output_tot["labels"] else []

        # Post-processing of metrics
        print_rank(f"validation complete {counter}", loglevel=logging.DEBUG)
        model.set_train()

        for k in inf_results.keys():
            metrics[k] = inf_results[k]
            metrics[k]['value'] = sum(sum_metrics[k])/counter

        print_rank(f"validation examples {counter}", loglevel=logging.DEBUG)
        torch.cuda.empty_cache()
        
        return output_tot, metrics


================================================
FILE: core/model.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch as T
from abc import ABC, abstractmethod

class BaseModel(ABC, T.nn.Module):
    '''This is a wrapper class for PyTorch models.'''

    @abstractmethod
    def __init__(self,**kwargs):
        super(BaseModel, self).__init__()
        
    @abstractmethod
    def loss(self, input):
        '''Performs forward step and computes the loss

        Returns:
            torch: Computed loss.
        '''
        pass
    
    @abstractmethod
    def inference(self, input):
        '''Performs forward step and computes metrics
             
        Returns:
            dict: The metrics to be computed. The following keys are
            the minimum required by FLUTE during evaluations rounds: 
                - output
                - acc
                - batch_size

            More metrics can be computed by adding the key with a
            dictionary that includes the fields ´value´ and 
            ´higher_is_better´ as follows:

            {'output':output, 
             'acc': accuracy, 
             'batch_size': n_samples, 
             'f1_score': {'value':f1,'higher_is_better': True}}
        '''
        pass

    def set_eval(self):
        '''Bring the model into evaluation mode'''
        self.eval()

    def set_train(self):
        '''Bring the model into training mode'''
        self.train()


================================================
FILE: core/schema.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# '''
# In this file we define the  schema for the configuration 
# files that will be pass it to an instance of the Validator 
# in e2e_trainer.py 
# '''

{
    'model_config':{
            'required': True,
            'type': 'dict',
            'allow_unknown': True,
            'schema': {
                'model_type': {'required': True, 'type':'string'},
                'model_folder': {'required': True, 'type':'string'},
                'BERT':{
                    'required':False,
                    'type': 'dict',
                    'allow_unknown': True,
                    'schema':{
                        'loader_type': {'required': False, 'type': 'string'},
                        'model': {
                            'required': True,
                            'type': 'dict',
                            'allow_unknown': True,
                            'schema': {
                                'model_name_or_path': {'required': False, 'type':'string'},
                                'model_name': {'required': True, 'type':'string'},
                                'process_line_by_line': {'required': True, 'type':'boolean'},
                            }
                        }
                    }
                },
            }
    },

    'dp_config':{
            'required': True,
            'type': 'dict',
            'allow_unknown': True,
            'schema': {
                'enable_local_dp': {'required': True, 'type':'boolean'},
                'enable_global_dp': {'required': False, 'type':'boolean'},
                'eps': {'required': False, 'type':'float'},
                'delta': {'required': False, 'type':'float'},
                'global_sigma': {'required': False, 'type':'float'},
                'max_grad': {'required': False, 'type':'float'},
                'max_weight': {'required': False, 'type':'float'},
                'weight_scaler': {'required': False, 'type':'float'},
                'min_weight': {'required': False, 'type':'float'},
                }
    },

    'privacy_metrics_config':{
            'required': True,
            'type': 'dict',
            'allow_unknown': True,
            'schema': {
                'apply_metrics': {'required': True, 'type':'boolean'},
                'apply_indices_extraction': {'required': False, 'type':'boolean'},
                'allowed_word_rank': {'required': False, 'type':'integer'},
                'apply_leakage_metric': {'required': False, 'type':'boolean'},
                'max_leakage': {'required': False, 'type':'float'},
                'adaptive_leakage_threshold': {'required': False, 'type':'float'},
                'is_leakage_weighted': {'required': False, 'type':'boolean'},
                'attacker_optimizer_config': {'required': False, 'type':'dict', 'allow_unknown': True},
                }
    },

    'strategy':{
        'required': True,
        'type': 'string'
    },

    'server_config':{
            'required': True,
            'type': 'dict',
            'allow_unknown': True,
            'schema': {
                'wantRL': {'required': True, 'type':'boolean', 'allow_unknown': True},
                'RL': {'required': False, 'type':'dict'},
                'resume_from_checkpoint': {'required': True, 'type':'boolean'},
                'do_profiling': {'required': True, 'type':'boolean'},
                'optimizer_config': {
                    'required': True, 
                    'type':'dict',
                    'allow_unknown': True,
                    'schema': {
                        'type': {'required': True, 'type':'string', 'allowed':['sgd', 'adam','adamax', 'lars', 'LarsSGD', 'lamb', 'adamW']},
                        'lr': {'required': True, 'type':'float'},
                        'weight_decay': {'required': False, 'type':'float'},
                    }
                },
                'annealing_config': {
                    'required': True, 
                    'type':'dict',
                    'allow_unknown': True,
                    'schema': {
                        'type': {'required': True, 'type':'string'},
                        'step_interval': {'required': True, 'type':'string'},
                        'gamma': {'required': True, 'type':'float'},
                        'step_size': {'required': True, 'type':'integer'},
                    }
                },
                'val_freq': {'required': False, 'type':'integer', 'default': 1},
                'rec_freq': {'required': False, 'type':'integer', 'default': 8},
                'initial_val': {'required': False, 'type':'boolean', 'default': True},
                'initial_rec': {'required': False, 'type':'boolean', 'default': False},
                'max_iteration': {'required': False, 'type':'integer', 'default': 10000},
                'num_clients_per_iteration': {'required': False, 'type':'integer', 'default': 1},
                'data_config': {
                    'required': True, 
                    'type':'dict',
                    'allow_unknown': True,
                    'keysrules':{'forbidden':['num_clients']},
                    'schema': {
                        'val': {
                            'required': True, 
                            'type':'dict',
                            'allow_unknown': True,
                            'schema': {
                                'batch_size': {'required': False, 'type':'integer', 'default': 40},
                                'val_data': {'required': True, 'type':'string', 'nullable':True},
                                'tokenizer_type': {'required': False, 'type':'string'},
                                'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
                                'vocab_dict': {'required': False, 'type':'string'},
                                'pin_memory': {'required': False, 'type':'boolean', 'default': True},
                                'num_workers': {'required': False, 'type':'integer', 'default': 1},
                                'num_frames': {'required': False, 'type':'integer', 'default': 0},
                                'max_batch_size': {'required': False, 'type':'integer', 'default': 0},
                                'max_num_words': {'required': False, 'type':'integer'},
                                'max_grad_norm': {'required': False, 'type':'float', 'default': 5.0 },
                                'unsorted_batch': {'required': False, 'type':'boolean', 'default': False},
                                'cache_dir': {'required': False, 'type':'string'},
                            },
                        },
                        'test': {
                            'required': True, 
                            'type':'dict',
                            'allow_unknown': True,
                            'schema': {
                                'batch_size': {'required': False, 'type':'integer', 'default': 40},
                                'test_data': {'required': True, 'type':'string', 'nullable': True},
                                'tokenizer_type': {'required': False, 'type':'string'},
                                'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
                                'vocab_dict': {'required': False, 'type':'string'},
                                'pin_memory': {'required': False, 'type':'boolean', 'default': True},
                                'num_workers': {'required': False, 'type':'integer', 'default': 1},
                                'num_frames': {'required': False, 'type':'integer', 'default': 0},
                                'max_batch_size': {'required': False, 'type':'integer', 'default': 0},
                                'max_num_words': {'required': False, 'type':'integer'},
                                'max_grad_norm': {'required': False, 'type':'float', 'default': 5.0 },
                                'unsorted_batch': {'required': False, 'type':'boolean', 'default': False},
                                'cache_dir': {'required': False, 'type':'string'},
                            },
                        },
                        'train': {
                            'required': False, 
                            'type':'dict',
                            'allow_unknown': True,
                            'schema': {
                                'batch_size': {'required': False, 'type':'integer', 'default': 40},
                                'train_data_server': {'required': False, 'type':'string'},
                                'desired_max_samples': {'required': False, 'type':'integer'},
                                'tokenizer_type': {'required': False, 'type':'string'},
                                'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
                                'vocab_dict': {'required': False, 'type':'string'},
                                'pin_memory': {'required': False, 'type':'boolean', 'default': True},
                                'num_workers': {'required': False, 'type':'integer', 'default': 1},
                                'num_frames': {'required': False, 'type':'integer', 'default': 0},
                                'max_batch_size': {'required': False, 'type':'integer', 'default': 0},
                                'max_num_words': {'required': False, 'type':'integer'},
                                'max_grad_norm': {'required': False, 'type':'float', 'default': 5.0 },
                                'unsorted_batch': {'required': False, 'type':'boolean', 'default': False},
                                'cache_dir': {'required': False, 'type':'string'},
                            }
                        },
                    }
                },
                'type': {
                    'required': False, 
                    'type':'string',
                    'allowed':['model_optimization', 'personalization'],
                    'default': 'model_optimization'
                },
                'aggregate_median': {'required': False, 'type':'string'},
                'initial_lr_client': {'required': True, 'type':'float'},
                'lr_decay_factor': {'required': True, 'type':'float'},
                'weight_train_loss': {'required': True, 'type':'string'},
                'best_model_criterion': {'required': False, 'type':'string', 'default':'loss'},
                'fall_back_to_best_model': {'required': False, 'type':'boolean', 'default': False},
                'softmax_beta': {'required': True, 'type':'float'},
                'server_replay_config': {
                    'required': False, 
                    'type':'dict',
                    'schema':{
                        'server_iterations': {'required': True, 'type':'integer'},
                        'optimizer_config': {
                            'required': True, 
                            'type':'dict',
                            'allow_unknown': True,
                            'schema': {
                                'type': {'required': True, 'type':'string', 'allowed':['sgd', 'adam','adamax', 'lars', 'LarsSGD', 'lamb', 'adamW']},
                                'lr': {'required': True, 'type':'float'},
                                'weight_decay': {'required': False, 'type':'float'},
                                'amsgrad': {'required': False, 'type':'boolean'},
                            }
                        },
                    }
                },
                'nbest_task_scheduler': {
                    'required': False, 
                    'type':'dict',
                    'schema':{
                        'num_tasks': {'required': True, 'type':'integer'}, 
                        'iteration_per_task': {'required': True, 'type':'integer'},
                    }
                },
            }
    },

    'client_config':{
        'required': True,
        'type': 'dict',
        'allow_unknown': True,
        'schema': {
            'meta_learning': {'required': False, 'type':'string'},
            'stats_on_smooth_grad': {'required': False, 'type':'boolean'},
            'ignore_subtask': {'required': True, 'type':'boolean'},
            'num_skips_threshold': {'required': False, 'type':'integer'},
            'copying_train_data': {'required': False, 'type':'boolean'},
            'do_profiling': {'required': True, 'type':'boolean'},
            'data_config': {
                'required': True, 
                'type':'dict',
                'allow_unknown': True,
                'keysrules':{'forbidden':['num_clients']},
                'schema': {
                    'train': {
                        'required': True, 
                        'type':'dict',
                        'allow_unknown': True,
                        'schema': {
                            'batch_size': {'required': False, 'type':'integer', 'default': 40},
                            'list_of_train_data': {'required': True, 'type':'string', 'nullable': True},
                            'tokenizer_type': {'required': False, 'type':'string'},
                            'prepend_datapath': {'required': False, 'type':'boolean', 'default': False},
                            'vocab_dict': {'required': False, 'type':'string'},
                            'pin_memory': {'required': False, 'type':'boolean', 'default': True},
                            'num_workers': {'required': False, 'type':'integer', 'default': 1},
                            'num_frames': {'required': False, 'type':'integer', 'default': 0},
                            'max_batch_size': {'required': False, 'type':'integer', 'default': 0},
                            'max_num_words': {'required': False, 'type':'integer'},
                            'max_grad_norm': {'required': False, 'type':'float', 'default': 5.0 },
                            'unsorted_batch': {'required': False, 'type':'boolean', 'default': False},
                        }
                    },
                }
            },
            'type': {
                'required': False, 
                'type':'string',
                'allowed':['optimization', 'gradient_computation'],
                'default': 'gradient_computation',
            },
            'meta_optimizer_config': {
                'required': False, 
                'type':'dict',
                'allow_unknown': True,
                'schema': {
                    'type': {'required': True, 'type':'string', 'allowed':['sgd', 'adam','adamax', 'lars', 'LarsSGD', 'lamb', 'adamW']},
                    'lr': {'required': True, 'type':'float'},
                }
            },
            'optimizer_config': {
                'required': True, 
                'type':'dict',
                'allow_unknown': True,
                'schema': {
                    'type': {'required': True, 'type':'string', 'allowed':['sgd', 'adam','adamax', 'lars', 'LarsSGD', 'lamb', 'adamW']},
                    'lr': {'required': False, 'type':'float'},
                    'weight_decay': {'required': False, 'type':'float'},
                }
            },
            'annealing_config': {
                'required': False, 
                'type':'dict',
                'allow_unknown': True,
                'schema': {
                    'type': {'required': True, 'type':'string'},
                    'step_interval': {'required': True, 'type':'string'},
                    'gamma': {'required': False, 'type':'float'},
                    'step_size': {'required': False, 'type':'integer'},
                }
            },
            'ss_config': {'required': False, 'type':'dict', 'allow_unknown': True},
        }
    },
}

================================================
FILE: core/server.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
In this file, we define the classes that live inside 'worker 0', the worker
responsible for orchestration and aggregation. The main class is the
OptimizationServer, which sends clients to the other workers to process and
combines the resulting models.
'''

import json
import logging
import os
import random
import shutil
import time
from collections import defaultdict

import numpy as np
import torch

# Internal imports
import core.federated as federated
from core.evaluation import Evaluation
from core.client import Client
from .strategies import select_strategy
from .trainer import (
    ModelUpdater,
    Trainer,
    set_component_wise_lr,
)
from utils import (
    get_lr,
    print_rank,
    update_json_log,
    to_device,
)

# For profiling
import cProfile
import pstats

# AzureML-related libs
from azureml.core import Run
run = Run.get_context()


class OptimizationServer(federated.Server):
    def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
                 config, idx_val_clients, idx_test_clients, single_worker):
        '''Implement Server's orchestration and aggregation.

        This is the main Server class, that actually implements orchestration
        and aggregation, inheriting from `federated.Server`, which deals with
        communication only.

        The `train` method is central in FLUTE, as it defines good part of what
        happens during training.

        Args:
            num_clients (int): total available clients.
            model (torch.nn.Module): neural network model.
            optimizer (torch.optim.Optimizer): optimizer.
            ss_scheduler: scheduled sampling scheduler.
            data_path (str): points to where data is.
            model_path (str): points to where pretrained model is.
            server_train_dataloader (torch.utils.data.DataLoader): dataloader for training
            config (dict): JSON style configuration parameters
            idx_val_clients (list): validation client ids
            idx_test_clients (list): testing clients ids
        '''

        super().__init__()

        # Initialize all attributes from arguments
        self.client_idx_list = list(range(num_clients))
        self.config = config
        server_config = config['server_config']
        decoder_config = config.get('decoder_config', None)

        self.max_iteration = server_config['max_iteration']
        self.do_clustering = server_config.get('clustering', False)
        self.send_dicts = server_config.get('send_dicts', False)

        self.num_clients_per_iteration = [int(x) for x in server_config['num_clients_per_iteration'].split(',')] \
            if isinstance(server_config['num_clients_per_iteration'], str) \
            else [server_config['num_clients_per_iteration']]

        self.val_freq = server_config['val_freq']
        self.req_freq = server_config['rec_freq']

        self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients, idx_test_clients, single_worker)

        # TODO: does this need to be adjusted for custom metrics?
        self.metrics = dict()

        self.model_backup_freq = server_config.get('model_backup_freq', 100)
        self.worker_trainer_config = server_config.get('trainer_config', {})

        self.aggregate_median = server_config['aggregate_median']
        self.initial_lr_client = server_config.get('initial_lr_client', -1.0)
        self.lr_decay_factor = server_config.get('lr_decay_factor', 1.0)

        self.model_type = config['model_config']['model_type']
        self.quant_thresh = config['client_config'].get('quant_thresh', None)
        self.quant_bits = config['client_config'].get('quant_bits', 10)

        self.list_of_train_data = config['client_config']['data_config']['train']['list_of_train_data']
        self.data_path = data_path
        self.single_worker = single_worker

        # Get max grad norm from data config
        if 'train' in server_config['data_config']:
            max_grad_norm = server_config['data_config']['train'].get('max_grad_norm', None)
        else:
            max_grad_norm = None

        # Creating an instance to update the model with stats aggregated from workers
        self.worker_trainer = ModelUpdater(
            model=model,
            optimizer=optimizer,
            ss_scheduler=ss_scheduler,
            train_dataloader=server_train_dataloader,
            val_dataloader=None,
            max_grad_norm=max_grad_norm,
            anneal_config=server_config['annealing_config'],
            model_type=self.model_type,
            decoder_config=decoder_config
        )
        self.metrics['worker_trainer'] = self.worker_trainer
        # Creating an instance for the server-side trainer (runs mini-batch SGD)
        self.server_replay_iterations = None
        self.server_trainer = None
        if server_train_dataloader is not None:
            assert 'server_replay_config' in server_config, 'server_replay_config is not set'
            assert 'optimizer_config' in server_config[
                'server_replay_config'], 'server-side replay training optimizer is not set'
            self.server_optimizer_config = server_config['server_replay_config']['optimizer_config']
            self.server_trainer_config = server_config['server_replay_config'].get('trainer_config', {})
            self.server_replay_iterations = server_config['server_replay_config']['server_iterations']
            self.server_trainer = Trainer(
                model=model,
                optimizer=None,
                ss_scheduler=ss_scheduler,
                train_dataloader=server_train_dataloader,
                server_replay_config=server_config['server_replay_config'],
                max_grad_norm=server_config['server_replay_config']\
                                            .get('max_grad_norm',server_config['data_config']['train']\
                                                .get('max_grad_norm',None)),
                anneal_config=server_config['server_replay_config'].get('annealing_config', None),
                ignore_subtask = server_config['server_replay_config'].get('ignore_subtask', False)
            )

        self.skip_model_update = False  # will not update the model if True

        self.train_loss = 0.0
        self.model_path = model_path
        self.best_model_criterion = server_config['best_model_criterion']
        self.fall_back_to_best_model = server_config['fall_back_to_best_model']
        self.last_model_path = os.path.join(self.model_path, 'latest_model.tar')
        self.best_model_path = os.path.join(self.model_path,
            'best_val_{}_model.tar'.format(self.best_model_criterion))
        self.log_path = os.path.join(self.model_path, 'status_log.json')
        self.cur_iter_no = 0  # keep the iteration number for Tensor board plotting
        self.lr_weight = 1.0

        self.losses = []
        self.no_label_updates = 0  # no. label updates

        # Update the parameters above if the log file
        if server_config.get('resume_from_checkpoint', False):
            self.load_saved_status()

        # Decoding config
        self.decoder_config = decoder_config
        self.spm_model = server_config['data_config']['test'].get('spm_model', None)

        self.do_profiling = server_config.get('do_profiling', False)

        StrategyClass = select_strategy(config['strategy'])
        self.strategy = StrategyClass('server', self.config, self.model_path)
        print_rank(f'Server successfully instantiated strategy {self.strategy}', loglevel=logging.DEBUG)

    def load_saved_status(self):
        '''Load checkpoint from disk'''

        # Check if model is on disk, if so loads it onto trainer
        if os.path.exists(self.last_model_path):
            print_rank('Resuming from checkpoint model {}'.format(self.last_model_path))
            self.worker_trainer.load(self.last_model_path, update_lr_scheduler=True, update_ss_scheduler=True)
            if self.server_trainer is not None:
                self.server_trainer.model = self.worker_trainer.model  # make sure that the models are in sync

        # Check if log is on disk, if so loads it onto current stats
        if os.path.exists(self.log_path):
            with open(self.log_path, 'r') as logfp:  # loading the iteration no., best loss and CER
                elems = json.load(logfp)
                self.cur_iter_no = elems.get('i', 0)
                self.metrics['best_val_loss'] = elems.get('best_val_loss', float('inf'))
                self.metrics['best_val_acc'] = elems.get('best_val_acc', 0)
                self.metrics['best_test_loss'] = elems.get('best_test_loss', float('inf'))
                self.metrics['best_test_acc'] = elems.get('best_test_acc', 0)
                self.lr_weight = elems.get('weight', 1.0)
                self.no_label_updates = elems.get('num_label_updates', 0)
                print_rank(f'Resuming from status_log: cur_iter: {self.cur_iter_no}')

    def run(self):
        '''Trigger training.

        This is a simple wrapper to the `train` method.
        '''
        print_rank('server started')
        self.train()
        print_rank('server terminated')

    def train(self):
        '''Main method for training.'''

        self.run_stats = {
            'secsPerClientRound': [],
            'secsPerClient': [],
            'secsPerClientTraining': [],
            'secsPerClientSetup': [],
            'secsPerClientFull': [],
            'secsPerRoundHousekeeping': [],
            'secsPerRoundTotal': [],
            'communicationCosts': []
        }

        run.log('Max iterations', self.max_iteration)
        try:
            self.worker_trainer.model = to_device(self.worker_trainer.model)

            # Do an initial validation round to understand the pretrained model's validation accuracy
            # Skip if we resumed from a checkpoint (cur_iter_no > 0)
            eval_list = []
            if self.cur_iter_no == 0:

                if self.config['server_config']['initial_rec']:
                    eval_list.append('test')
                if self.config['server_config']['initial_val']:
                    eval_list.append('val')
                    run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))

                print_rank("Running {} at itr={}".format(eval_list, self.cur_iter_no))
                self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                eval_list = [] # some cleanup

            # Dump all the information in aggregate_metric
            print_rank('Saving Model Before Starting Training', loglevel=logging.INFO)
            for token in ['best_val_loss', 'best_val_acc', 'best_test_acc', 'latest']:
                self.worker_trainer.save(
                    model_path=self.model_path,
                    token=token,
                    config=self.config['server_config']
                )

            # Training loop
            self.worker_trainer.model.train()
            for i in range(self.cur_iter_no, self.max_iteration):
                begin = time.time()
                metrics_payload = {}

                def log_metric(k, v):
                    metrics_payload[k] = v

                print_rank('==== iteration {}'.format(i))
                log_metric('Current iteration', i)

                # Initial value for the learning rate of the worker
                initial_lr = self.initial_lr_client * self.lr_weight
                print_rank('Client learning rate {}'.format(initial_lr))

                # Run training on clients
                self.worker_trainer.model.zero_grad()
                self.train_loss = []

                if self.send_dicts: # Send state dictionaries
                    glob_payload = [self.worker_trainer.model.state_dict()[param_key].to(torch.device('cpu')) for param_key in self.worker_trainer.model.state_dict()]
                else: # Send parameters
                    glob_payload = [p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()]
                
                server_data = (initial_lr, glob_payload, i)

                # Random number of clients per iteration
                if len(self.num_clients_per_iteration) > 1:
                    num_clients_curr_iter = random.randint(
                        self.num_clients_per_iteration[0],
                        self.num_clients_per_iteration[1]
                    )
                else:
                    num_clients_curr_iter = self.num_clients_per_iteration[0]
                log_metric('Clients for round', num_clients_curr_iter)

                # Perform annealing in quantization threshold
                if self.quant_thresh is not None:
                    self.config['client_config']['quant_thresh'] *= self.config['client_config'].get('quant_anneal', 1.0)
                    self.quant_thresh = self.config['client_config']['quant_thresh']
                    log_metric('Quantization Thresh.', self.config['client_config']['quant_thresh'])

                #  Create the pool of clients -- sample from this pool to assign to workers
                sampled_idx_clients = random.sample(self.client_idx_list,
                    num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list
                
                # Initialize stats
                clients_begin = time.time()

                client_losses = []
                client_mag_grads = []
                client_mean_grads = []
                client_var_grads = []
                client_norm_grads = []

                self.run_stats['secsPerClient'].append([])
                self.run_stats['secsPerClientFull'].append([])
                self.run_stats['secsPerClientTraining'].append([])
                self.run_stats['secsPerClientSetup'].append([])
                self.run_stats['communicationCosts'].append([])

                # Check if we want privacy metrics
                apply_privacy_metrics = self.config.get('privacy_metrics_config', None) and \
                    self.config['privacy_metrics_config']['apply_metrics']
                adaptive_leakage = apply_privacy_metrics and \
                    self.config['privacy_metrics_config'].get('adaptive_leakage_threshold', None)
                if apply_privacy_metrics:
                    privacy_metrics_stats = defaultdict(list)

                # Initialize profiler
                profiler = None
                if self.do_profiling:
                    profiler = cProfile.Profile()
                    profiler.enable()

                # Reset gradient for the model before assigning the new gradients
                self.worker_trainer.model.zero_grad()
                
                print_rank(f"Clients sampled from server {sampled_idx_clients}", loglevel=logging.DEBUG)
                for client_output in self.process_clients(sampled_idx_clients, server_data, self.single_worker):
                    # Process client output
                    client_timestamp = client_output['ts']
                    client_stats = client_output['cs']
                    client_loss = client_output['tl']
                    client_mag_grad = client_output['mg']
                    client_mean_grad = client_output['ng']
                    client_var_grad = client_output['vg']
                    client_norm_grad = client_output['rg']
                    client_payload = client_output['pl']

                    if apply_privacy_metrics:
                        privacy_stats = client_output['ps']
                        for metric, value in privacy_stats.items():
                            privacy_metrics_stats[metric].append(value)

                    self.run_stats['communicationCosts'][-1].append(time.time() - client_timestamp)

                    # Get actual pseudo-gradients for aggregation
                    payload_processed = self.strategy.process_individual_payload(self.worker_trainer, client_payload)
                    if not payload_processed:
                        print_rank('Dropping client', loglevel=logging.DEBUG)
                        num_clients_curr_iter -= 1
                        continue

                    # Aggregate stats
                    self.train_loss.append(client_loss)
                    client_losses.append(client_loss)
                    client_mag_grads.append(client_mag_grad.item())
                    client_mean_grads.append(client_mean_grad.item())
                    client_var_grads.append(client_var_grad.item())
                    client_norm_grads.append(client_norm_grad.item())

                    # Mark the end of client processing
                    client_end = time.time()

                    self.run_stats['secsPerClientFull'][-1].append(client_stats['full cost'])
                    self.run_stats['secsPerClientTraining'][-1].append(client_stats['training'])
                    self.run_stats['secsPerClientSetup'][-1].append(client_stats['setup'])
                    self.run_stats['secsPerClient'][-1].append(client_end - clients_begin)

                # Tear down profiler
                if self.do_profiling:
                    profiler.disable()
                    stats = pstats.Stats(profiler)
                    stats.sort_stats('cumulative').print_stats()

                # Prepare output
                client_mag_grads = np.array(client_mag_grads)
                client_mean_grads = np.array(client_mean_grads)
                client_var_grads = np.array(client_var_grads)
                client_norm_grads = np.array(client_norm_grads)

                client_stats = (client_mag_grads, client_mean_grads, client_var_grads)

                dump_norm_stats = self.config.get('dump_norm_stats', False)
                if dump_norm_stats:
                    with open(os.path.join(self.model_path, 'norm_stats.txt'), 'a', encoding='utf-8') as outF:
                        outF.write('{}\n'.format(json.dumps(list(client_norm_grads))))

                # Print the privacy metrics
                if apply_privacy_metrics:
                    for metric, values in privacy_metrics_stats.items():
                        if metric == 'Dropped clients':
                            log_metric(metric, sum(values))
                        else:
                            log_metric(metric, max(values))

                if type(adaptive_leakage) is float:
                    values = privacy_metrics_stats['Practical epsilon (Max leakage)']
                    new_threshold = list(sorted(values))[int(adaptive_leakage*len(values))]
                    print_rank('Updating leakage threshold to {}'.format(new_threshold))
                    self.config['privacy_metrics_config']['max_allowed_leakage'] = new_threshold

                # Mark that all clients have been processed
                end = time.time()
                self.run_stats['secsPerClientRound'].append(end - begin)
                begin = end

                # Log the training loss to tensorboard/AML
                log_metric('Training loss', sum(self.train_loss))

                # Combine payloads
                self.losses = self.strategy.combine_payloads(
                    worker_trainer=self.worker_trainer,
                    curr_iter=i,
                    num_clients_curr_iter=num_clients_curr_iter,
                    total_clients = len(self.client_idx_list),
                    client_stats=client_stats,
                    logger=log_metric,
                )
                
                # Run a couple of iterations of training data on the server
                if self.server_trainer is not None:
                    print_rank('Running replay iterations on server')

                    if 'updatable_names' in self.server_trainer_config:
                        set_component_wise_lr(
                            self.worker_trainer.model,
                            self.server_optimizer_config,
                            self.server_trainer_config['updatable_names']
                        )
                    self.server_trainer.prepare_iteration(self.worker_trainer.model)
                    self.server_trainer.train_desired_samples(self.server_replay_iterations)
                    self.worker_trainer.model.load_state_dict(self.server_trainer.model.state_dict())
                    torch.cuda.empty_cache()

                # Update a sampling scheduler
                print_rank('Run ss scheduler')
                self.worker_trainer.run_ss_scheduler()

                # Run inference and score on val/test depending on the iter. number
                if ((i+1) % self.val_freq) == 0:
                    eval_list.append("val")
                if ((i+1) % self.req_freq) == 0 :
                    eval_list.append("test")
                
                if len(eval_list)> 0:
                    print_rank('Running {} at itr={}'.format(eval_list,i+1))
                    self.metrics['worker_trainer'] = self.worker_trainer
                    if hasattr(self.strategy,'tmp_unsup'):
                        self.metrics['tmp_sup'] = self.strategy.tmp_sup
                        self.metrics['tmp_unsup'] = self.strategy.tmp_unsup
                    self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                    self.losses = self.evaluation.losses
                    eval_list = []

                # Create a schedule for the initial_lr (for the worker)
                if 'val' in eval_list:
                    run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))
                    if not (self.losses[0] < self.metrics['best_val_loss']):
                        self.lr_weight *= self.lr_decay_factor
                        print_rank('LOG: Client weight of learning rate {}..'.format(self.lr_weight))

                # Backup the current best models
                self.backup_models(i)

                # Fall back to the best model if the option is enabled
                self.fall_back_to_prev_best_status()

                # Logging the latest best values only after the 1st val/test round has been executed
                if len(self.metrics) > 1:
                    update_json_log(
                        self.log_path,
                        {
                            'i': i + 1,
                            'best_val_loss': float(self.metrics['best_val_loss']),
                            'best_val_acc': float(self.metrics['best_val_acc']),
                            'best_test_loss': float(self.metrics['best_test_loss']),
                            'best_test_acc': float(self.metrics['best_test_acc']),
                            'weight': float(self.lr_weight),
                            'num_label_updates': int(self.no_label_updates)
                        },
                    )

                end = time.time()

                # Aggregate stats
                self.run_stats['secsPerRoundHousekeeping'].append(end - begin)
                self.run_stats['secsPerRoundTotal'].append(self.run_stats['secsPerClientRound'][-1] + \
                    self.run_stats['secsPerRoundHousekeeping'][-1])

                log_metric('secsPerRoundTotal', self.run_stats['secsPerRoundTotal'][-1])
                if self.do_profiling:
                    log_metric('secsPerClientRound', self.run_stats['secsPerClientRound'][-1])
                    log_metric('secsPerRoundHousekeeping', self.run_stats['secsPerRoundHousekeeping'][-1])

                    metrics_for_stats = [
                        'secsPerClient',
                        'secsPerClientTraining',
                        'secsPerClientFull',
                        'secsPerClientSetup',
                        'communicationCosts',
                    ]

                    for metric in metrics_for_stats:
                        log_metric(f'{metric}Mean', np.mean(self.run_stats[metric][-1]))
                        log_metric(f'{metric}Median', np.median(self.run_stats[metric][-1]))
                        log_metric(f'{metric}Max', max(self.run_stats[metric][-1]))

                    for k in self.run_stats:
                        if k in metrics_for_stats:
                            print_rank('{}: {}'.format(k, max(self.run_stats[k][-1])), loglevel=logging.DEBUG)
                        else:
                            print_rank('{}: {}'.format(k, self.run_stats[k][-1]), loglevel=logging.DEBUG)

                # Log all the metrics
                for k in metrics_payload:
                    run.log(k, metrics_payload[k])

        finally:  # perform cleanup even if error was raised above
            self.terminate_workers(terminate=(not self.do_clustering))

    def backup_models(self, i):
        '''Save the current best models.

        Save CER model, the best loss model and the best WER model. This occurs
        at a specified period.

        Args:
            i: no. of iterations.
        '''

        # Always save the latest model
        self.worker_trainer.save(
            model_path=self.model_path,
            token='latest',
            config=self.config['server_config'],
        )

        if (i % self.model_backup_freq) == 0:  # save the current best models
            self.worker_trainer.save(
                model_path=self.model_path,
                token='epoch{}'.format(i),
                config=self.config['server_config']
            )

            for bodyname in ['best_val_acc', 'best_val_loss', 'best_test_acc']:
                src_model_path = os.path.join(self.model_path, '{}_model.tar'.format(bodyname))
                if os.path.exists(src_model_path):
                    dst_model_path = os.path.join(self.model_path, 'epoch{}_{}_model.tar'.format(i, bodyname))
                    shutil.copyfile(src_model_path, dst_model_path)
                    print_rank('Saved {}'.format(dst_model_path))

    def fall_back_to_prev_best_status(self):
        '''Go back to the past best status and switch to the recent best model.'''

        if self.fall_back_to_best_model:
            print_rank('falling back to model {}'.format(self.best_model_path))

            # Save current learning rate
            tmp_lr = get_lr(self.worker_trainer.optimizer)

            # Load previous best model
            self.worker_trainer.load(self.best_model_path, update_lr_scheduler=False, update_ss_scheduler=False)

            # Update previous learning rate on optimizer
            for g in self.worker_trainer.optimizer.param_groups:
                g['lr'] = tmp_lr

            if self.server_trainer is not None:
                self.server_trainer.model = self.worker_trainer.model  # make sure that the models are in sync


def select_server(server_type):
    '''Select a server type using different possible strings.

    Right now this just returns `OptimizationServer`, but this
    function could be useful when there are multiple choices of
    server.

    Args:
        server_type (str): indicates server choice.
        config (dict): config parsed from YAML, passed so that
            parameters can be used to select a given server.
    '''
    if server_type == "personalization":
        from experiments.cv.server import PersonalizationServer
        return PersonalizationServer
    else:
        return OptimizationServer


================================================
FILE: core/strategies/__init__.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .base import BaseStrategy
from .fedavg import FedAvg
from .dga import DGA
from .fedlabels import FedLabels

def select_strategy(strategy):
    ''' Selects the aggregation strategy class
    
    NOTE: FedProx uses FedAvg weights during aggregation, 
    which are proportional to the number of samples in 
    each client.
    '''
    if strategy.lower() == 'dga':
        return DGA
    elif strategy.lower() in ['fedavg', 'fedprox']:
        return FedAvg
    elif strategy.lower() == 'fedlabels':
        return FedLabels
    else:
        raise ValueError(f'cannot use strategy f{strategy}')

================================================
FILE: core/strategies/base.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import abstractmethod


@abstractmethod
class BaseStrategy:
    def __init__(self, mode, config, model_path=None):
        '''Federated learning strategy

        Args:
            mode (str): which part the instantiated object should play,
                typically either :code:`client` or :code:`server`.
            config (dict): initial config dict.
            model_path (str): where to find model, needed for debugging only.
        '''
        pass

    def generate_client_payload(self, trainer):
        '''Generate client payload

        Args:
            trainer (core.Trainer object): trainer on client.

        Returns:
            dict containing payloads in some specified format.
        '''
        pass

    def process_individual_payload(self, worker_trainer, payload):
        '''Process client payload
        
        Args:
            worker_trainer (core.Trainer object): trainer on server
                (aka model updater).
            payload (dict): whatever is generated by
                :code:`generate_client_payload`.

        Returns:
            True if processed succesfully, False otherwise.
        '''
        pass

    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
        '''Combine payloads to update model
        
        Args:
            worker_trainer (core.Trainer object): trainer on server
                (aka model updater).
            curr_iter (int): current iteration.
            num_clients_curr_iter (int): number of clients on current iteration.
            total_clients (int): size of total pool of clients (for privacy accounting)
            client_stats (dict): stats being collected.
            logger (callback): function called to log quantities.
        '''
        pass

================================================
FILE: core/strategies/dga.py
================================================
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import json
import logging
import math
import os

import numpy as np
import torch

from extensions import privacy, RL, quant_model
from utils import compute_grad_cosines, print_rank, to_device
from core.strategies import BaseStrategy
from core.strategies.utils import (
    aggregate_gradients_inplace,
    filter_weight,
)

from azureml.core import Run
run = Run.get_context()

MIN_WEIGHT = 1e-7


class DGA(BaseStrategy):
    '''Dynamic Gradient Aggregation'''

    def __init__(self, mode, config, model_path=None):
        ''' Dynamic Gradient Aggregation (DGA) strategy.

        For more info see arXiv:2106.07578.

        Args:
            mode (str): which part the instantiated object should play,
                typically either :code:`client` or :code:`server`.
            config (dict): initial config dict.
            model_path (str): where to find model, needed for debugging only.
        '''

        super().__init__(mode=mode, config=config, model_path=model_path)

        if mode not in ['client', 'server']:
            raise ValueError('mode in strategy must be either `client` or `server`')

        self.config = config
        self.model_path = model_path
        self.mode = mode

        # Parse config
        self.model_config = config['model_config']
        self.client_config = config['client_config']
        self.server_config = config['server_config']

        self.dp_config = config.get('dp_config', None)

        if mode == 'client':
            self.stats_on_smooth_grad = self.client_config.get('stats_on_smooth_grad', False)
            self.quant_threshold = self.client_config.get('quant_thresh', None)
            self.quant_bits = self.client_config.get('quant_bits', 10)
        elif mode == 'server':
            self.dump_norm_stats = self.config.get('dump_norm_stats', False)
            self.aggregate_fast = self.server_config.get('fast_aggregation', False)
            self.want_rl = self.server_config.get('wantRL', False)
            self.stale_prob = self.server_config.get('stale_prob', 0.0)

            self.skip_model_update = False

            # Do some checks and create objects based on configs
            if self.aggregate_fast:
                print_rank('It is NOT possible to enable RL with fast_aggregation, RL is set to False', loglevel=logging.INFO)
                self.want_rl = False

                print_rank('It is NOT possible in Current Implementation to have stale gradients with fast_aggregation, stale_prob is set to 0.0', loglevel=logging.INFO)
                self.stale_prob = 0.0

            if self.want_rl:
                self.rl = RL(config=self.server_config)

            # Initialize accumulators
            self.client_parameters_stack = []
            self.client_parameters_stack_stale = []
            self.client_weights = []

            self.weight_sum_stale = 0.0

    def generate_client_payload(self, trainer):
        '''Generate client payload

        Args:
            trainer (core.Trainer object): trainer on client.

        Returns:
            dict containing payloads in some specified format.
        '''

        if self.mode != 'client':
            raise RuntimeError('this method can only be invoked by the client')

        # Get weights for aggregation, potentially using DGA
        weight = 1.0
        add_weight_noise = False

        # Reset gradient stats and recalculate them on the smooth/pseudo gradient
        if self.stats_on_smooth_grad:
            trainer.reset_gradient_power()
            trainer.estimate_sufficient_stats()

        # If we are using softmax based on training loss, it needs DP noise
        if self.config['server_config']['aggregate_median'] == 'softmax':
            # This matters when DP is required
            add_weight_noise = True

            if 'weight_train_loss' not in self.config['server_config'] or self.config['server_config']['weight_train_loss'] == 'train_loss':
                training_weight = trainer.train_loss / trainer.num_samples
            elif self.config['server_config']['weight_train_loss'] == 'mag_var_loss':
                training_weight = trainer.sufficient_stats['var']
            elif self.config['server_config']['weight_train_loss'] == 'mag_mean_loss':
                training_weight = trainer.sufficient_stats['mean']
            else:
                training_weight = trainer.sufficient_stats['mag']

            try:
                weight = math.exp(-self.config['server_config']['softmax_beta'] * training_weight)
            except:
                print_rank('There is an issue with the weight -- Reverting to {}'.format(MIN_WEIGHT), loglevel=logging.DEBUG)
                weight = MIN_WEIGHT
            weight = filter_weight(weight)

        # Add local DP noise here.
        # When weight == 0, something went wrong. So we'll skip adding noise and return a zero gradient.
        if weight > 0.0 and self.dp_config is not None and self.dp_config.get('enable_local_dp', False):
            weight = privacy.apply_local_dp(trainer, weight, self.dp_config, add_weight_noise)

        # In all other cases we can compute the weight after adding noise
        if not add_weight_noise:
            assert self.config['server_config']['aggregate_median'] == 'mean'
            assert weight == 1.0

        # Weight the gradient and remove gradients of the layers we want to freeze
        for n, p in trainer.model.named_parameters():
            p.grad = weight * p.grad
            if self.model_config.get('freeze_layer', None) and n == self.model_config['freeze_layer']:
                print_rank('Setting gradient to zero for layer: {}'.format(n), loglevel=logging.INFO)
                p.grad.mul_(0)

        # Gradient quantization step -- if quant_threshold is None, the code returns without doing anything
        quant_model(trainer.model, quant_threshold=self.quant_threshold, quant_bits=self.quant_bits, global_stats=False)

        payload = {}
        payload['weight'] = weight
        payload['gradients'] = [p.grad.to(torch.device('cpu')) for p in trainer.model.parameters()]

        return payload

    def process_individual_payload(self, worker_trainer, payload):
        '''Process client payload

        Args:
            worker_trainer (core.Trainer object): trainer on server
                (aka model updater).
            payload (dict): whatever is generated by
                :code:`generate_client_payload`.

        Returns:
            True if processed succesfully, False otherwise.
        '''

        if self.mode != 'server':
            raise RuntimeError('this method can only be invoked by the server')

        if payload['weight'] == 0.0:
            return False

        self.client_weights.append(payload['weight'])
        if self.aggregate_fast:
            aggregate_gradients_inplace(worker_trainer.model, payload['gradients'])
        else:
            self.client_parameters_stack.append(payload['gradients'])
        return True

    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
        '''Combine payloads to update model

        Args:
            worker_trainer (core.Trainer object): trainer on server
                (aka model updater).
            curr_iter (int): current iteration.
            num_clients_curr_iter (int): number of clients on current iteration.
            total_clients (int): size of total pool of clients (for privacy accounting)
            client_stats (dict): stats being collected.
            logger (callback): function called to log quantities.

        Returns:
            losses, computed for use with LR scheduler.
        '''

        if self.mode != 'server':
            raise RuntimeError('this method can only be invoked by the server')

        if self.want_rl:
            rl_model = self._run_rl_inference(self.client_weights, *client_stats)

        # Aggregation step
        if self.dump_norm_stats:
            cps_copy = [[g.clone().detach() for g in x] for x in self.client_parameters_stack]
        weight_sum = self._aggregate_gradients(worker_trainer, num_clients_curr_iter, self.client_weights, metric_logger=logger)
        print_rank('Sum of weights: {}'.format(weight_sum), loglevel=logging.DEBUG)

        torch.cuda.empty_cache()

        # Normalize with weight_sum
        for p in worker_trainer.model.parameters():
            p.grad /= weight_sum

        if self.dump_norm_stats:
            cosines = compute_grad_cosines(cps_copy, [p.grad.clone().detach() for p in worker_trainer.model.parameters()])
            with open(os.path.join(self.model_path, 'cosines.txt'), 'a', encoding='utf-8') as outfile:
                outfile.write('{}\n'.format(json.dumps(cosines)))

        # DP-specific steps
        privacy.apply_global_dp(self.config, worker_trainer.model, num_clients_curr_iter=num_clients_curr_iter, select_grad=True, metric_logger=logger)
        eps = privacy.update_privacy_accountant(self.config, total_clients, curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
        if eps:
            print_rank(f'DP result: {eps}')

        if self.skip_model_update is True:
            print_rank('Skipping model update')
            return

        # Run optimization with gradient/model aggregated from clients
        print_rank('Updating model')
        worker_trainer.update_model()
        print_rank('Updating learning rate scheduler')
        losses = worker_trainer.run_lr_scheduler(force_run_val=False)

        if self.want_rl:
            self._run_rl_training(curr_iter, rl_model, self.client_weights, *client_stats, logger)

        return losses

    def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, client_weights, metric_logger=None):
        '''Go through stored gradients, aggregate and put them inside model.

        Args:
            num_clients_curr_iter (int): how many clients were processed.
            client_weights: weight for each client.
            metric_logger (callback, optional): callback used for logging.
                Defaults to None, in which case AML logger is used.

        Returns:
            float: sum of weights for all clients.
        '''

        weight_sum = 0
        if metric_logger is None:
            metric_logger = run.log

        if not self.aggregate_fast:
            metric_logger('Stale Gradients Ratio', len(self.client_parameters_stack_stale) / num_clients_curr_iter)
            if len(self.client_parameters_stack_stale) > 0:
                weight_sum = self.weight_sum_stale
                for client_parameters in self.client_parameters_stack_stale:
                    # Model parameters are already multiplied with weight on client, we only have to sum them up
                    aggregate_gradients_inplace(worker_trainer.model, client_parameters)
                self.client_parameters_stack_stale = []
                self.weight_sum_stale = 0

            for client_weight, client_parameters in zip(client_weights, self.client_parameters_stack):
                if np.random.random() > self.stale_prob:
                    # Model parameters are already multiplied with weight on client, we only have to sum them up
                    aggregate_gradients_inplace(worker_trainer.model, client_parameters)
                else:
              
Download .txt
gitextract_qg20kqyy/

├── .flake8
├── .github/
│   └── workflows/
│       ├── build_docs.yml
│       └── codeql.yml
├── .gitignore
├── .gitmodules
├── CHANGELOG.md
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.TXT
├── NOTICE.txt
├── README.md
├── SECURITY.md
├── azure-pipelines.yml
├── configs/
│   ├── hello_world_mlm_bert_json.yaml
│   └── hello_world_nlg_gru_json.yaml
├── core/
│   ├── __init__.py
│   ├── client.py
│   ├── config.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── evaluation.py
│   ├── federated.py
│   ├── metrics.py
│   ├── model.py
│   ├── schema.py
│   ├── server.py
│   ├── strategies/
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── dga.py
│   │   ├── fedavg.py
│   │   ├── fedlabels.py
│   │   └── utils.py
│   └── trainer.py
├── doc/
│   └── sphinx/
│       ├── Makefile
│       ├── advanced.rst
│       ├── class_reference.rst
│       ├── conf.py
│       ├── index.rst
│       ├── launch.rst
│       ├── make.bat
│       ├── overview.rst
│       ├── reference.rst
│       ├── requirements.txt
│       └── scenarios.rst
├── e2e_trainer.py
├── experiments/
│   ├── __init__.py
│   ├── classif_cnn/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── cifar_dataset.py
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   └── utils/
│   │       ├── centralized_training.py
│   │       └── download_and_convert_data.py
│   ├── cv/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── data.py
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   ├── model_vgg.py
│   │   └── server.py
│   ├── cv_cnn_femnist/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocess.py
│   │   └── model.py
│   ├── cv_lr_mnist/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocessing.py
│   │   └── model.py
│   ├── cv_resnet_fedcifar100/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocessing.py
│   │   ├── group_normalization.py
│   │   └── model.py
│   ├── ecg_cnn/
│   │   ├── .gitignore
│   │   ├── centralized_model.ipynb
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   ├── readme.md
│   │   └── utils/
│   │       └── preprocess.py
│   ├── fednewsrec/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocess_mind.py
│   │   ├── fednewsrec_model.py
│   │   ├── model.py
│   │   └── utils.py
│   ├── mlm_bert/
│   │   ├── README.md
│   │   ├── config.py
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   └── utils/
│   │       ├── trainer_pt_utils.py
│   │       └── trainer_utils.py
│   ├── nlg_gru/
│   │   ├── README.md
│   │   ├── config.py
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   └── dataset.py
│   │   ├── model.py
│   │   └── utils/
│   │       └── utility.py
│   ├── nlp_rnn_fedshakespeare/
│   │   ├── README.md
│   │   ├── config.yaml
│   │   ├── dataloaders/
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   └── preprocessing.py
│   │   └── model.py
│   └── semisupervision/
│       ├── README.md
│       ├── config.yaml
│       ├── dataloaders/
│       │   ├── RandAugment.py
│       │   ├── cifar_dataset.py
│       │   ├── dataloader.py
│       │   └── dataset.py
│       └── model.py
├── extensions/
│   ├── RL/
│   │   └── RL.py
│   ├── __init__.py
│   ├── privacy/
│   │   ├── __init__.py
│   │   ├── analysis.py
│   │   ├── dp_kmeans.py
│   │   └── metrics.py
│   └── quantization/
│       └── quant.py
├── requirements.txt
├── testing/
│   ├── README.md
│   ├── build_vocab.py
│   ├── create_data.py
│   ├── hello_world_classif_cnn.yaml
│   ├── hello_world_ecg_cnn.yaml
│   ├── hello_world_mlm_bert.yaml
│   ├── hello_world_nlg_gru.yaml
│   └── test_e2e_trainer.py
└── utils/
    ├── __init__.py
    ├── data_utils.py
    ├── dataloaders_utils.py
    ├── optimizers/
    │   ├── adamW.py
    │   ├── lamb.py
    │   └── lars.py
    ├── preprocessing/
    │   ├── create-hdf5.py
    │   ├── create-json.py
    │   └── from_json_to_hdf5.py
    └── utils.py
Download .txt
SYMBOL INDEX (773 symbols across 91 files)

FILE: core/client.py
  class Client (line 49) | class Client:
    method __init__ (line 54) | def __init__(self, client_id, config, send_gradients):
    method get_client_data (line 70) | def get_client_data(self, dataset=None):
    method get_train_dataset (line 77) | def get_train_dataset(data_path, client_train_config, task):
    method get_data (line 102) | def get_data(clients, dataset):
    method run_testvalidate (line 127) | def run_testvalidate(client_data, server_data, mode, model):
    method process_round (line 227) | def process_round(client_data, server_data, model, data_path, eps=1e-7):

FILE: core/config.py
  function from_dict (line 32) | def from_dict(cls, config):
  class Config (line 39) | class Config(MutableMapping):
    method get (line 41) | def get(self, k: str, default=None):
    method lookup (line 47) | def lookup(self, s: str, default=None):
    method __getitem__ (line 57) | def __getitem__(self, k):
    method __setitem__ (line 60) | def __setitem__(self, k, v):
    method __delitem__ (line 63) | def __delitem__(self, k):
    method __iter__ (line 66) | def __iter__(self):
    method __len__ (line 69) | def __len__(self):
    method __contains__ (line 72) | def __contains__(self, k):
    method pop (line 75) | def pop(self, k, default=None):
  class ModelConfig (line 83) | class ModelConfig(Config):
    method from_dict (line 101) | def from_dict(config) -> ModelConfig:
  class BERTModelConfig (line 120) | class BERTModelConfig(Config):
    method from_dict (line 158) | def from_dict(config) -> BERTModelConfig:
  class BERTTrainingConfig (line 163) | class BERTTrainingConfig(Config):
    method from_dict (line 183) | def from_dict(config) -> BERTTrainingConfig:
  class BERTConfig (line 188) | class BERTConfig(Config):
    method from_dict (line 204) | def from_dict(config) -> BERTConfig:
  class PrivacyConfig (line 217) | class PrivacyConfig(Config):
    method from_dict (line 279) | def from_dict(config) -> PrivacyConfig:
  class PrivacyMetricsConfig (line 284) | class PrivacyMetricsConfig(Config):
    method from_dict (line 318) | def from_dict(config) -> PrivacyMetricsConfig:
  class OptimizerConfig (line 330) | class OptimizerConfig(Config):
    method from_dict (line 345) | def from_dict(config) -> OptimizerConfig:
  class AnnealingConfig (line 355) | class AnnealingConfig(Config):
    method from_dict (line 374) | def from_dict(config) -> AnnealingConfig:
  class DatasetConfig (line 379) | class DatasetConfig(Config):
    method from_dict (line 422) | def from_dict(config) -> DatasetConfig:
  class DataConfig (line 427) | class DataConfig(Config):
    method from_dict (line 447) | def from_dict(config) -> DataConfig:
  class ServerReplayConfig (line 458) | class ServerReplayConfig(Config):
    method from_dict (line 473) | def from_dict(config) -> ServerReplayConfig:
  class RLConfig (line 482) | class RLConfig(Config):
    method from_dict (line 531) | def from_dict(config) -> RLConfig:
  class ServerConfig (line 544) | class ServerConfig(Config):
    method from_dict (line 626) | def from_dict(config) -> ServerConfig:
  class ClientConfig (line 651) | class ClientConfig(Config):
    method from_dict (line 693) | def from_dict(config) -> ClientConfig:
  class FLUTEConfig (line 713) | class FLUTEConfig(Config):
    method validate (line 736) | def validate(config):
    method from_dict (line 763) | def from_dict(config) -> FLUTEConfig:

FILE: core/dataloader.py
  class BaseDataLoader (line 7) | class BaseDataLoader(ABC, PyTorchDataLoader):
    method create_loader (line 10) | def create_loader(self):

FILE: core/dataset.py
  class BaseDataset (line 7) | class BaseDataset(ABC, PyTorchDataset):
    method __init__ (line 11) | def __init__(self,**kwargs):
    method __getitem__ (line 15) | def __getitem__(self, idx, **kwargs):
    method __len__ (line 20) | def __len__(self):
    method load_data (line 25) | def load_data(self,**kwargs):

FILE: core/evaluation.py
  class Evaluation (line 21) | class Evaluation():
    method __init__ (line 23) | def __init__(self, config, model_path, process_testvalidate, idx_val_c...
    method run (line 35) | def run(self, eval_list, req, metric_logger=None):
    method initialize_req (line 113) | def initialize_req(self, req):
    method run_distributed_inference (line 128) | def run_distributed_inference(self, mode, model):
    method run_distributed_evaluation (line 146) | def run_distributed_evaluation(self, mode, clients, model):
  function make_eval_clients (line 185) | def make_eval_clients(dataset, config):

FILE: core/federated.py
  function encode_string (line 27) | def encode_string(word, string_to_int = True):
  function rank (line 45) | def rank():
  function local_rank (line 49) | def local_rank():
  function size (line 53) | def size():
  function _recv (line 57) | def _recv(x, src=0):
  function _recv_gradients (line 71) | def _recv_gradients(src):
  function _send (line 89) | def _send(x, dst=0):
  function _send_metrics (line 98) | def _send_metrics(output):
  function _send_gradients (line 112) | def _send_gradients(gradients, dst):
  function _send_train_output (line 126) | def _send_train_output(output):
  function build_grads_dict (line 147) | def build_grads_dict(node):
  function build_metrics_dict (line 190) | def build_metrics_dict(node):
  function receive_workers_output (line 216) | def receive_workers_output(node_request_map, results_list, free_nodes, c...
  function append_async_requests (line 242) | def append_async_requests(node_request_map, node):
  function sync_idle_nodes (line 251) | def sync_idle_nodes(client_queue, free_nodes):
  class Server (line 264) | class Server:
    method dispatch_clients (line 282) | def dispatch_clients(clients, server_data, command, mode=None, do_prof...
    method process_clients (line 413) | def process_clients(clients, server_data, single_worker):
    method process_testvalidate (line 427) | def process_testvalidate(clients, server_data, mode, single_worker):
    method terminate_workers (line 443) | def terminate_workers(terminate=True):
  class Worker (line 451) | class Worker:
    method __init__ (line 470) | def __init__(self, model=None, data_path=None, do_profiling=False, val...
    method run (line 482) | def run(self):
    method trigger_evaluate (line 634) | def trigger_evaluate(self):
    method trigger_train (line 660) | def trigger_train(self):

FILE: core/metrics.py
  class Metrics (line 14) | class Metrics():
    method __init__ (line 16) | def __init__(self):
    method compute_metrics (line 19) | def compute_metrics(self,dataloader, model):
    method call_inference (line 29) | def call_inference(self, dataloader, model):

FILE: core/model.py
  class BaseModel (line 7) | class BaseModel(ABC, T.nn.Module):
    method __init__ (line 11) | def __init__(self,**kwargs):
    method loss (line 15) | def loss(self, input):
    method inference (line 24) | def inference(self, input):
    method set_eval (line 45) | def set_eval(self):
    method set_train (line 49) | def set_train(self):

FILE: core/server.py
  class OptimizationServer (line 47) | class OptimizationServer(federated.Server):
    method __init__ (line 48) | def __init__(self, num_clients, model, optimizer, ss_scheduler, data_p...
    method load_saved_status (line 183) | def load_saved_status(self):
    method run (line 206) | def run(self):
    method train (line 215) | def train(self):
    method backup_models (line 530) | def backup_models(self, i):
    method fall_back_to_prev_best_status (line 561) | def fall_back_to_prev_best_status(self):
  function select_server (line 581) | def select_server(server_type):

FILE: core/strategies/__init__.py
  function select_strategy (line 9) | def select_strategy(strategy):

FILE: core/strategies/base.py
  class BaseStrategy (line 8) | class BaseStrategy:
    method __init__ (line 9) | def __init__(self, mode, config, model_path=None):
    method generate_client_payload (line 20) | def generate_client_payload(self, trainer):
    method process_individual_payload (line 31) | def process_individual_payload(self, worker_trainer, payload):
    method combine_payloads (line 45) | def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr...

FILE: core/strategies/dga.py
  class DGA (line 27) | class DGA(BaseStrategy):
    method __init__ (line 30) | def __init__(self, mode, config, model_path=None):
    method generate_client_payload (line 88) | def generate_client_payload(self, trainer):
    method process_individual_payload (line 157) | def process_individual_payload(self, worker_trainer, payload):
    method combine_payloads (line 183) | def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr...
    method _aggregate_gradients (line 243) | def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, ...
    method _run_rl_inference (line 286) | def _run_rl_inference(self, client_weights, client_mag_grads, client_m...
    method _run_rl_training (line 348) | def _run_rl_training(self, iter, rl_model, client_weights, client_mag_...

FILE: core/strategies/fedavg.py
  class FedAvg (line 20) | class FedAvg(BaseStrategy):
    method __init__ (line 23) | def __init__(self, mode, config, model_path=None):
    method generate_client_payload (line 61) | def generate_client_payload(self, trainer):
    method process_individual_payload (line 93) | def process_individual_payload(self, worker_trainer, payload):
    method combine_payloads (line 119) | def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr...
    method _aggregate_gradients (line 168) | def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, ...

FILE: core/strategies/fedlabels.py
  class FedLabels (line 20) | class FedLabels(BaseStrategy):
    method __init__ (line 23) | def __init__(self, mode, config, model_path=None):
    method generate_client_payload (line 60) | def generate_client_payload(self, trainer):
    method process_individual_payload (line 94) | def process_individual_payload(self, worker_trainer, payload):
    method combine_payloads (line 120) | def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr...
    method _aggregate_gradients (line 170) | def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, ...
  function aggregate_gradients_inplace (line 218) | def aggregate_gradients_inplace(keys, values, first, tmp, ratio):

FILE: core/strategies/utils.py
  function filter_weight (line 11) | def filter_weight(weight):
  function aggregate_gradients_inplace (line 21) | def aggregate_gradients_inplace(model, gradients):

FILE: core/trainer.py
  class TrainerBase (line 30) | class TrainerBase:
    method __init__ (line 47) | def __init__(
    method epoch_boundary (line 68) | def epoch_boundary(self):
    method train_desired_samples (line 72) | def train_desired_samples(self, desired_max_samples, apply_privacy_met...
    method save (line 75) | def save(self):
    method load (line 78) | def load(self):
  class ModelUpdater (line 82) | class ModelUpdater(TrainerBase):
    method __init__ (line 101) | def __init__(
    method update_model (line 127) | def update_model(self):
    method run_lr_scheduler (line 139) | def run_lr_scheduler(self, force_run_val=False):
    method run_ss_scheduler (line 157) | def run_ss_scheduler(self):
    method save (line 163) | def save(self, model_path, token=None, config=None):
    method load (line 176) | def load(self, save_path, update_lr_scheduler, update_ss_scheduler):
  class Trainer (line 200) | class Trainer(TrainerBase):
    method __init__ (line 223) | def __init__(
    method reset_gradient_power (line 263) | def reset_gradient_power(self):
    method accumulate_gradient_power (line 271) | def accumulate_gradient_power(self):
    method estimate_sufficient_stats (line 294) | def estimate_sufficient_stats(self):
    method train_desired_samples (line 314) | def train_desired_samples(self, desired_max_samples=None, apply_privac...
    method run_train_epoch (line 341) | def run_train_epoch(self, desired_max_samples=None, apply_privacy_metr...
    method run_train_epoch_fedprox (line 416) | def run_train_epoch_fedprox(self, desired_max_samples=None, apply_priv...
    method run_train_epoch_sup (line 503) | def run_train_epoch_sup(self, desired_max_samples=None, apply_privacy_...
    method get_model (line 621) | def get_model(self):
    method prepare_iteration (line 624) | def prepare_iteration(self, model=None):
    method reset_optimizer (line 640) | def reset_optimizer(self, optimizer_state_dict, annealing_config=None):
    method save (line 653) | def save(self, model_path, token=None, config=None):
    method load (line 666) | def load(self, save_path, update_lr_scheduler, update_ss_scheduler):
  function run_validation_generic (line 690) | def run_validation_generic(model, val_dataloader):
  function set_component_wise_lr (line 725) | def set_component_wise_lr(model, optimizer_config, updatable_names):
  function save_model (line 753) | def save_model(model_path, config, model, optimizer, lr_scheduler, ss_sc...

FILE: e2e_trainer.py
  function log_run_properties (line 40) | def log_run_properties(config: FLUTEConfig):
  function run_worker (line 77) | def run_worker(model_path, config, task, data_path, local_rank, backend):

FILE: experiments/__init__.py
  function make_model (line 8) | def make_model(model_config, dataloader_type=None, input_dim=-1, output_...

FILE: experiments/classif_cnn/dataloaders/cifar_dataset.py
  class CIFAR10 (line 7) | class CIFAR10:
    method __init__ (line 8) | def __init__(self) :
  function _process (line 26) | def _process(dataset, n_users):

FILE: experiments/classif_cnn/dataloaders/dataloader.py
  class DataLoader (line 9) | class DataLoader(BaseDataLoader):
    method __init__ (line 10) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 28) | def collate_fn(self, batch):

FILE: experiments/classif_cnn/dataloaders/dataset.py
  class Dataset (line 8) | class Dataset(BaseDataset):
    method __init__ (line 9) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 28) | def __getitem__(self, idx):
    method __len__ (line 31) | def __len__(self):
    method load_data (line 34) | def load_data(self, data, test_only):

FILE: experiments/classif_cnn/model.py
  class Net (line 11) | class Net(nn.Module):
    method __init__ (line 14) | def __init__(self):
    method forward (line 23) | def forward(self, x):
  class CNN (line 33) | class CNN(BaseModel):
    method __init__ (line 36) | def __init__(self, model_config):
    method loss (line 40) | def loss(self, input: torch.Tensor) -> torch.Tensor:
    method inference (line 47) | def inference(self, input):

FILE: experiments/classif_cnn/utils/centralized_training.py
  class Net (line 43) | class Net(nn.Module):
    method __init__ (line 44) | def __init__(self):
    method forward (line 53) | def forward(self, x):

FILE: experiments/classif_cnn/utils/download_and_convert_data.py
  function _dump_dict_to_hdf5 (line 10) | def _dump_dict_to_hdf5(data_dict: dict, hdf5_file: h5py.File):
  function _process_and_save_to_disk (line 26) | def _process_and_save_to_disk(dataset, n_users, file_format, output):

FILE: experiments/cv/data.py
  class DataPartitioner (line 16) | class DataPartitioner(object):
    method __init__ (line 19) | def __init__(self, data, sizes=None, rnd=0, alpha=0, num_c=10,
    method get_lab_distr (line 35) | def get_lab_distr(self):
    method return_partition (line 39) | def return_partition(self, partition, flag='data', is_train_set=True):
    method __use_fixed_lab_distr__ (line 67) | def __use_fixed_lab_distr__(self, data, lab_distr, ratio, rnd, num_c):
    method __getDirichletData__ (line 119) | def __getDirichletData__(self, data, psizes, alpha, num_c, rnd):
  function partition_dataset (line 174) | def partition_dataset(rnd, img_size, image, total_num_clients, image_pat...
  function prepare_dataset (line 232) | def prepare_dataset(rnd=2020, img_size=40, image='cifar', total_num_clie...
  function _dump_dict_to_hdf5 (line 253) | def _dump_dict_to_hdf5(data_dict: dict, hdf5_file: h5py.File):
  function _process_and_save_to_disk (line 270) | def _process_and_save_to_disk(dataset, save_to_disk, file_format, output...
  function get_transform (line 305) | def get_transform(transform, img_size=32):

FILE: experiments/cv/dataloaders/dataloader.py
  class DataLoader (line 10) | class DataLoader(BaseDataLoader):
    method __init__ (line 11) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/cv/dataloaders/dataset.py
  class Dataset (line 9) | class Dataset(BaseDataset):
    method __init__ (line 10) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 30) | def __getitem__(self, idx):
    method __len__ (line 33) | def __len__(self):
    method load_data (line 36) | def load_data(self, data, test_only):

FILE: experiments/cv/model.py
  function conv3x3 (line 40) | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: in...
  function conv1x1 (line 46) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  class BasicBlock (line 51) | class BasicBlock(nn.Module):
    method __init__ (line 54) | def __init__(
    method forward (line 82) | def forward(self, x: Tensor) -> Tensor:
  class Bottleneck (line 101) | class Bottleneck(nn.Module):
    method __init__ (line 110) | def __init__(
    method forward (line 137) | def forward(self, x: Tensor) -> Tensor:
  class ResNet (line 160) | class ResNet(BaseModel):
    method __init__ (line 161) | def __init__(
    method _make_layer (line 217) | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], plan...
    method forward (line 242) | def forward(self, inputs):
    method get_logit (line 262) | def get_logit(self, x = None, evalis = True, logmax=False):
    method inference (line 288) | def inference(self, inputs):
    method loss (line 305) | def loss(self, inputs):
    method copy_state_dict (line 316) | def copy_state_dict(self, state_dict):
    method get_model (line 320) | def get_model(self):
  function _resnet (line 324) | def _resnet(
  function resnet18 (line 344) | def resnet18(config, pretrained: bool = False, progress: bool = True, **...
  function resnet34 (line 359) | def resnet34(config, pretrained: bool = False, progress: bool = True, **...
  function resnet50 (line 374) | def resnet50(config, pretrained: bool = False, progress: bool = True, **...
  function resnet101 (line 389) | def resnet101(config, pretrained: bool = False, progress: bool = True, *...
  function resnet152 (line 403) | def resnet152(config, pretrained: bool = False, progress: bool = True, *...
  function resnext50_32x4d (line 418) | def resnext50_32x4d(config, pretrained: bool = False, progress: bool = T...
  function resnext101_32x8d (line 435) | def resnext101_32x8d(config, pretrained: bool = False, progress: bool = ...
  function wide_resnet50_2 (line 452) | def wide_resnet50_2(config, pretrained: bool = False, progress: bool = T...
  function wide_resnet101_2 (line 473) | def wide_resnet101_2(config, pretrained: bool = False, progress: bool = ...

FILE: experiments/cv/model_vgg.py
  class VGG (line 23) | class VGG(nn.Module):
    method __init__ (line 27) | def __init__(self, vgg, num_class, topK_results=None):
    method forward (line 49) | def forward(self, inputs):
    method loss (line 57) | def loss(self, inputs):
    method inference (line 66) | def inference(self, inputs):
    method get_logit (line 81) | def get_logit(self, inputs = None, evalis = True, logmax=False):
    method copy_state_dict (line 106) | def copy_state_dict(self, state_dict):
    method set_eval (line 109) | def set_eval(self):
    method set_train (line 115) | def set_train(self):
  function make_layers (line 122) | def make_layers(cfg, n_channels=3, batch_norm=True):
  function vgg11 (line 147) | def vgg11(config):
  function vgg11_bn (line 153) | def vgg11_bn(config):
  function vgg13 (line 159) | def vgg13(config):
  function vgg13_bn (line 165) | def vgg13_bn(config):
  function vgg16 (line 171) | def vgg16(config):
  function vgg16_bn (line 177) | def vgg16_bn(config):
  function vgg19 (line 183) | def vgg19(config):
  function vgg19_bn (line 189) | def vgg19_bn(config):

FILE: experiments/cv/server.py
  class PersonalizationServer (line 9) | class PersonalizationServer(OptimizationServer):
    method __init__ (line 10) | def __init__(self, num_clients, model, optimizer, ss_scheduler, data_p...

FILE: experiments/cv_cnn_femnist/dataloaders/dataloader.py
  class DataLoader (line 10) | class DataLoader(BaseDataLoader):
    method __init__ (line 11) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/cv_cnn_femnist/dataloaders/dataset.py
  class Dataset (line 8) | class Dataset(BaseDataset):
    method __init__ (line 9) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 33) | def __getitem__(self, idx):
    method __len__ (line 36) | def __len__(self):
    method load_data (line 39) | def load_data(self, data, test_only):

FILE: experiments/cv_cnn_femnist/dataloaders/preprocess.py
  class FEMNIST (line 19) | class FEMNIST:
    method __init__ (line 20) | def __init__(self) :
  function download_files (line 45) | def download_files(data_cache_dir):

FILE: experiments/cv_cnn_femnist/model.py
  class CNN_DropOut (line 12) | class CNN_DropOut(torch.nn.Module):
    method __init__ (line 53) | def __init__(self, only_digits=True):
    method forward (line 66) | def forward(self, x):
  class CNN (line 82) | class CNN(BaseModel):
    method __init__ (line 85) | def __init__(self, model_config):
    method loss (line 89) | def loss(self, input: torch.Tensor) -> torch.Tensor:
    method inference (line 97) | def inference(self, input):

FILE: experiments/cv_lr_mnist/dataloaders/dataloader.py
  class DataLoader (line 10) | class DataLoader(BaseDataLoader):
    method __init__ (line 11) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/cv_lr_mnist/dataloaders/dataset.py
  class Dataset (line 8) | class Dataset(BaseDataset):
    method __init__ (line 9) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 34) | def __getitem__(self, idx):
    method __len__ (line 37) | def __len__(self):
    method load_data (line 40) | def load_data(self, data, test_only):

FILE: experiments/cv_lr_mnist/dataloaders/preprocessing.py
  class MNIST (line 19) | class MNIST:
    method __init__ (line 20) | def __init__(self) :
  function download_mnist (line 29) | def download_mnist(data_cache_dir):
  function read_data (line 42) | def read_data(train_data_dir, test_data_dir):

FILE: experiments/cv_lr_mnist/model.py
  class LogisticRegression (line 12) | class LogisticRegression(torch.nn.Module):
    method __init__ (line 13) | def __init__(self, input_dim, output_dim):
    method forward (line 17) | def forward(self, x):
  class LR (line 23) | class LR(BaseModel):
    method __init__ (line 26) | def __init__(self, model_config):
    method loss (line 30) | def loss(self, input: torch.Tensor) -> torch.Tensor:
    method inference (line 38) | def inference(self, input):

FILE: experiments/cv_resnet_fedcifar100/dataloaders/dataloader.py
  class DataLoader (line 10) | class DataLoader(BaseDataLoader):
    method __init__ (line 11) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/cv_resnet_fedcifar100/dataloaders/dataset.py
  class Dataset (line 8) | class Dataset(BaseDataset):
    method __init__ (line 9) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 33) | def __getitem__(self, idx):
    method __len__ (line 36) | def __len__(self):
    method load_data (line 39) | def load_data(self, data, test_only):

FILE: experiments/cv_resnet_fedcifar100/dataloaders/preprocessing.py
  class FEDCIFAR100 (line 20) | class FEDCIFAR100:
    method __init__ (line 21) | def __init__(self) :
  function download_files (line 46) | def download_files(data_cache_dir):

FILE: experiments/cv_resnet_fedcifar100/group_normalization.py
  function group_norm (line 10) | def group_norm(
  class _GroupNorm (line 99) | class _GroupNorm(_BatchNorm):
    method __init__ (line 100) | def __init__(
    method _check_input_dim (line 115) | def _check_input_dim(self, input):
    method forward (line 118) | def forward(self, input):
  class GroupNorm2d (line 134) | class GroupNorm2d(_GroupNorm):
    method _check_input_dim (line 163) | def _check_input_dim(self, input):
  class GroupNorm3d (line 168) | class GroupNorm3d(_GroupNorm):
    method _check_input_dim (line 173) | def _check_input_dim(self, input):

FILE: experiments/cv_resnet_fedcifar100/model.py
  function conv3x3 (line 26) | def conv3x3(in_planes, out_planes, stride=1):
  function norm2d (line 33) | def norm2d(planes, num_channels_per_group=32):
  class BasicBlock (line 43) | class BasicBlock(nn.Module):
    method __init__ (line 46) | def __init__(self, inplanes, planes, stride=1, downsample=None, group_...
    method forward (line 56) | def forward(self, x):
  class Bottleneck (line 75) | class Bottleneck(nn.Module):
    method __init__ (line 78) | def __init__(self, inplanes, planes, stride=1, downsample=None, group_...
    method forward (line 92) | def forward(self, x):
  class ResNet (line 115) | class ResNet(nn.Module):
    method __init__ (line 116) | def __init__(self, block, layers, num_classes=1000, group_norm=0):
    method _make_layer (line 154) | def _make_layer(self, block, planes, blocks, stride=1, group_norm=0):
    method forward (line 176) | def forward(self, x):
  function resnet18 (line 194) | def resnet18(pretrained=False, **kwargs):
  function resnet34 (line 205) | def resnet34(pretrained=False, **kwargs):
  function resnet50 (line 216) | def resnet50(pretrained=False, **kwargs):
  function resnet101 (line 227) | def resnet101(pretrained=False, **kwargs):
  function resnet152 (line 238) | def resnet152(pretrained=False, **kwargs):
  class RESNET (line 248) | class RESNET(BaseModel):
    method __init__ (line 251) | def __init__(self, model_config):
    method loss (line 255) | def loss(self, input: torch.Tensor) -> torch.Tensor:
    method inference (line 263) | def inference(self, input):

FILE: experiments/ecg_cnn/dataloaders/dataloader.py
  class DataLoader (line 9) | class DataLoader(BaseDataLoader):
    method __init__ (line 10) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/ecg_cnn/dataloaders/dataset.py
  class Dataset (line 9) | class Dataset(BaseDataset):
    method __init__ (line 10) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 29) | def __getitem__(self, idx):
    method __len__ (line 33) | def __len__(self):
    method load_data (line 36) | def load_data(self,data):

FILE: experiments/ecg_cnn/model.py
  class Swish (line 15) | class Swish(nn.Module):
    method forward (line 16) | def forward(self, x):
  class ConvNormPool (line 19) | class ConvNormPool(nn.Module):
    method __init__ (line 21) | def __init__(
    method forward (line 69) | def forward(self, input):
  class RNN (line 88) | class RNN(nn.Module):
    method __init__ (line 90) | def __init__(
    method forward (line 108) | def forward(self, input):
  class Net (line 112) | class Net(nn.Module):
    method __init__ (line 113) | def __init__(
    method forward (line 140) | def forward(self, input):
  class SuperNet (line 153) | class SuperNet(BaseModel):
    method __init__ (line 155) | def __init__(self, model_config):
    method loss (line 159) | def loss(self, input: torch.Tensor):
    method inference (line 165) | def inference(self, input):

FILE: experiments/ecg_cnn/utils/preprocess.py
  function _dump_dict_to_hdf5 (line 11) | def _dump_dict_to_hdf5(data_dict: dict, hdf5_file: h5py.File):
  function _process_and_save_to_disk (line 27) | def _process_and_save_to_disk(dataset, n_users, output):
  class HeartDataSet (line 58) | class HeartDataSet:
    method __init__ (line 59) | def __init__(self, heartdata, cutoff):
    method __len__ (line 63) | def __len__(self):
  function resampleSet (line 69) | def resampleSet(train_df):

FILE: experiments/fednewsrec/dataloaders/dataloader.py
  class DataLoader (line 9) | class DataLoader(BaseDataLoader):
    method __init__ (line 10) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/fednewsrec/dataloaders/dataset.py
  class Dataset (line 9) | class Dataset(BaseDataset):
    method __init__ (line 10) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 32) | def __getitem__(self, idx):
    method __len__ (line 35) | def __len__(self):
    method load_data (line 38) | def load_data(self, data, test_only):

FILE: experiments/fednewsrec/dataloaders/preprocess_mind.py
  class MIND (line 16) | class MIND:
    method __init__ (line 17) | def __init__(self, root_data_path, embedding_path) :
  function GetUserDataFunc (line 64) | def GetUserDataFunc(news_title,train_user_id_sample,train_user,train_ses...
  function newsample (line 81) | def newsample(nnn,ratio):
  function read_news (line 87) | def read_news(root_data_path,modes):
  function get_doc_input (line 130) | def get_doc_input(news,news_index,category,subcategory,word_dict):
  function load_matrix (line 146) | def load_matrix(embedding_path,word_dict):
  function read_clickhistory (line 164) | def read_clickhistory(root_data_path,mode):
  function parse_user (line 193) | def parse_user(session,news_index):
  function get_train_input (line 211) | def get_train_input(session,uid_click_talbe,news_index):
  function get_test_input (line 250) | def get_test_input(session,news_index):

FILE: experiments/fednewsrec/fednewsrec_model.py
  class AttentivePooling (line 12) | class AttentivePooling(nn.Module):
    method __init__ (line 13) | def __init__(self, dim1: int, dim2: int):
    method forward (line 25) | def forward(self, x):
    method fromTensorFlow (line 33) | def fromTensorFlow(self, tfmodel):
  class Attention (line 44) | class Attention(nn.Module):
    method __init__ (line 46) | def __init__(self, input_dim, nb_head, size_per_head, **kwargs):
    method fromTensorFlow (line 63) | def fromTensorFlow(self, tf, criteria = lambda l: l.name.startswith('a...
    method forward (line 74) | def forward(self, x):
  class Permute (line 113) | class Permute(nn.Module):
    method __init__ (line 114) | def __init__(self, *dims):
    method forward (line 118) | def forward(self, x):
  class SwapTrailingAxes (line 121) | class SwapTrailingAxes(nn.Module):
    method __init__ (line 122) | def __init__(self):
    method forward (line 125) | def forward(self, x):
  class DocEncoder (line 128) | class DocEncoder(nn.Module):
    method __init__ (line 129) | def __init__(self):
    method fromTensorFlow (line 153) | def fromTensorFlow(self, tfDoc):
    method forward (line 189) | def forward(self, x):
  class VecTail (line 200) | class VecTail(nn.Module):
    method __init__ (line 201) | def __init__(self, n):
    method forward (line 205) | def forward(self, x):
  class UserEncoder (line 208) | class UserEncoder(nn.Module):
    method __init__ (line 209) | def __init__(self):
    method forward (line 227) | def forward(self, news_vecs_input):
    method fromTensorFlow (line 257) | def fromTensorFlow(self, tfU):
  class TimeDistributed (line 284) | class TimeDistributed(nn.Module):
    method __init__ (line 285) | def __init__(self, module): #, batch_first=False):
    method forward (line 290) | def forward(self, x):
  class FedNewsRec (line 315) | class FedNewsRec(nn.Module):
    method __init__ (line 316) | def __init__(self, title_word_embedding_matrix):
    method forward (line 329) | def forward(self, click_title, can_title):
    method news_encoder (line 358) | def news_encoder(self, news_title):

FILE: experiments/fednewsrec/model.py
  class FEDNEWS (line 19) | class FEDNEWS(BaseModel):
    method __init__ (line 22) | def __init__(self, model_config):
    method loss (line 32) | def loss(self, input: torch.Tensor) -> torch.Tensor:
    method inference (line 47) | def inference(self, input):
    method read_news (line 71) | def read_news(self, root_data_path, modes):
    method load_matrix (line 114) | def load_matrix(self, embedding_path,word_dict):

FILE: experiments/fednewsrec/utils.py
  function mrr_score (line 3) | def mrr_score(y_true, y_score):
  function ndcg_score (line 9) | def ndcg_score(y_true, y_score, k=10):
  function dcg_score (line 14) | def dcg_score(y_true, y_score, k=10):

FILE: experiments/mlm_bert/config.py
  class BERTModelConfig (line 9) | class BERTModelConfig(Config):
    method from_dict (line 47) | def from_dict(config) -> BERTModelConfig:
  class BERTTrainingConfig (line 52) | class BERTTrainingConfig(Config):
    method from_dict (line 72) | def from_dict(config) -> BERTTrainingConfig:
  class BERTSpecificConfig (line 77) | class BERTSpecificConfig(Config):
    method from_dict (line 93) | def from_dict(config) -> BERTSpecificConfig:
  class BERTConfig (line 106) | class BERTConfig(ModelConfig):
    method from_dict (line 113) | def from_dict(config) -> ModelConfig:

FILE: experiments/mlm_bert/dataloaders/dataloader.py
  class DataLoader (line 13) | class DataLoader(BaseDataLoader):
    method __init__ (line 18) | def __init__(self, mode, data, num_workers=0,  **kwargs):
    method get_user (line 93) | def get_user(self):

FILE: experiments/mlm_bert/dataloaders/dataset.py
  class Dataset (line 11) | class Dataset(BaseDataset):
    method __init__ (line 16) | def __init__(self, data, args, tokenizer=None, test_only=False, user_i...
    method __len__ (line 66) | def __len__(self):
    method __getitem__ (line 69) | def __getitem__(self, idx):
    method load_data (line 86) | def load_data(self, orig_strct, user_idx):
    method process_x (line 107) | def process_x(self, raw_x_batch):
    method process_user (line 122) | def process_user(self, user, user_data):
    method post_process_list (line 142) | def post_process_list(self):
  function LineByLineTextDataset (line 204) | def LineByLineTextDataset(tokenizer, input_lines, truncation=True, max_l...

FILE: experiments/mlm_bert/model.py
  class BERT (line 39) | class BERT(BaseModel):
    method __init__ (line 40) | def __init__(self, model_config, **kwargs):
    method copy_state_dict (line 151) | def copy_state_dict(self, state_dict):
    method get_model (line 154) | def get_model(self):
    method _prepare_inputs (line 158) | def _prepare_inputs(self, inputs):
    method forward (line 172) | def forward(self, inputs):
    method loss (line 177) | def loss(self, inputs):
    method compute_loss (line 199) | def compute_loss(self, inputs_orig, return_outputs=False):
    method inference (line 243) | def inference(
    method prediction_loop (line 282) | def prediction_loop(
    method _gather_and_numpify (line 369) | def _gather_and_numpify(self, tensors, name):
    method prediction_step (line 379) | def prediction_step(
    method floating_point_ops (line 444) | def floating_point_ops(self, inputs):
    method set_eval (line 462) | def set_eval(self):
    method set_train (line 469) | def set_train(self):

FILE: experiments/mlm_bert/utils/trainer_pt_utils.py
  function torch_pad_and_concatenate (line 47) | def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
  function numpy_pad_and_concatenate (line 62) | def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
  function nested_concat (line 77) | def nested_concat(tensors, new_tensors, padding_index=-100):
  function nested_numpify (line 95) | def nested_numpify(tensors):
  function nested_detach (line 102) | def nested_detach(tensors):
  function reissue_pt_warnings (line 111) | def reissue_pt_warnings(caught_warnings):
  function nested_new_like (line 121) | def nested_new_like(arrays, num_samples, padding_index=-100):
  function nested_expand_like (line 128) | def nested_expand_like(arrays, new_seq_length, padding_index=-100):
  function nested_truncate (line 138) | def nested_truncate(tensors, limit):
  function _get_first_shape (line 145) | def _get_first_shape(arrays):
  class DistributedTensorGatherer (line 152) | class DistributedTensorGatherer:
    method __init__ (line 186) | def __init__(self, world_size, num_samples, make_multiple_of=None, pad...
    method add_arrays (line 196) | def add_arrays(self, arrays):
    method _nested_set_tensors (line 216) | def _nested_set_tensors(self, storage, arrays):
    method finalize (line 235) | def finalize(self):
  class LabelSmoother (line 248) | class LabelSmoother:
    method __call__ (line 261) | def __call__(self, model_output, labels):
  function get_length_grouped_indices (line 284) | def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None...
  class LengthGroupedSampler (line 317) | class LengthGroupedSampler(Sampler):
    method __init__ (line 323) | def __init__(self, dataset: Dataset, batch_size: int, lengths: Optiona...
    method __len__ (line 335) | def __len__(self):
    method __iter__ (line 338) | def __iter__(self):
  class DistributedLengthGroupedSampler (line 343) | class DistributedLengthGroupedSampler(DistributedSampler):
    method __init__ (line 349) | def __init__(
    method __iter__ (line 394) | def __iter__(self) -> Iterator:
  function _get_learning_rate (line 419) | def _get_learning_rate(self):
  function metrics_format (line 442) | def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
  function log_metrics (line 464) | def log_metrics(self, split, metrics):
  function save_metrics (line 482) | def save_metrics(self, split, metrics):

FILE: experiments/mlm_bert/utils/trainer_utils.py
  function set_seed (line 31) | def set_seed(seed: int):
  class EvalPrediction (line 46) | class EvalPrediction(NamedTuple):
  class PredictionOutput (line 58) | class PredictionOutput(NamedTuple):
  class ComputeMetrics (line 64) | class ComputeMetrics:
    method __init__ (line 65) | def __init__(self, p: EvalPrediction, mask=None):
    method compute_metrics (line 70) | def compute_metrics(p: EvalPrediction, mask=None):

FILE: experiments/nlg_gru/config.py
  class GRUConfig (line 9) | class GRUConfig(ModelConfig):
    method from_dict (line 32) | def from_dict(config) -> GRUConfig:

FILE: experiments/nlg_gru/dataloaders/dataloader.py
  class DataLoader (line 12) | class DataLoader(BaseDataLoader):
    method __init__ (line 17) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 64) | def collate_fn(self, batch):

FILE: experiments/nlg_gru/dataloaders/dataset.py
  class Dataset (line 12) | class Dataset(BaseDataset):
    method __init__ (line 17) | def __init__(self, data, min_num_words=2, max_num_words=25, test_only=...
    method __len__ (line 32) | def __len__(self):
    method __getitem__ (line 37) | def __getitem__(self, idx):
    method load_data (line 49) | def load_data(self, orig_strct, user_idx):
    method process_x (line 65) | def process_x(self, user_data):

FILE: experiments/nlg_gru/model.py
  class GRU2 (line 11) | class GRU2(T.nn.Module):
    method __init__ (line 12) | def __init__(self, input_size, hidden_size, input_bias, hidden_bias):
    method _forward_cell (line 19) | def _forward_cell(self, input : Tensor, hidden : Tensor) -> Tensor:
    method forward (line 30) | def forward(self, input : Tensor) -> Tuple[Tensor, Tensor]:
  class Embedding (line 39) | class Embedding(T.nn.Module):
    method __init__ (line 40) | def __init__(self, vocab_size, embedding_size):
    method forward (line 49) | def forward(self, input : Tensor, embed : bool) -> Tensor:
  class GRU (line 57) | class GRU(BaseModel): #DLM_2_0
    method __init__ (line 58) | def __init__(self, model_config, OOV_correct=False, dropout=0.0, topK_...
    method forward (line 73) | def forward(self, input : T.Tensor) -> Tuple[Tensor, Tensor]:
    method loss (line 84) | def loss(self, input : T.Tensor) -> T.Tensor:
    method inference (line 102) | def inference(self, input):

FILE: experiments/nlg_gru/utils/utility.py
  function load_vocab (line 19) | def load_vocab(url):
  function to_indices (line 36) | def to_indices(vocab, batch, ndim=2, oov_idx=0, pad_idx=-1):
  function case_backoff_batch (line 89) | def case_backoff_batch(batch, vocab):
  function encode_data (line 113) | def encode_data(data_dict, vocab):

FILE: experiments/nlp_rnn_fedshakespeare/dataloaders/dataloader.py
  class DataLoader (line 10) | class DataLoader(BaseDataLoader):
    method __init__ (line 11) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/nlp_rnn_fedshakespeare/dataloaders/dataset.py
  class Dataset (line 8) | class Dataset(BaseDataset):
    method __init__ (line 9) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 33) | def __getitem__(self, idx):
    method __len__ (line 36) | def __len__(self):
    method load_data (line 39) | def load_data(self, data, test_only):

FILE: experiments/nlp_rnn_fedshakespeare/dataloaders/preprocessing.py
  function preprocess (line 38) | def preprocess(sentences, max_seq_len=SEQUENCE_LENGTH):
  function char_to_id (line 63) | def char_to_id(char):
  function get_word_dict (line 70) | def get_word_dict():
  function split (line 79) | def split(dataset):
  function download_files (line 85) | def download_files(data_cache_dir):
  class FEDSHAKESPEARE (line 102) | class FEDSHAKESPEARE:
    method __init__ (line 103) | def __init__(self) :

FILE: experiments/nlp_rnn_fedshakespeare/model.py
  class nlp_rnn_fedshakespeare (line 12) | class nlp_rnn_fedshakespeare(nn.Module):
    method __init__ (line 13) | def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256):
    method forward (line 26) | def forward(self, input_seq):
  class RNN (line 40) | class RNN(BaseModel):
    method __init__ (line 43) | def __init__(self, model_config):
    method loss (line 47) | def loss(self, input: torch.Tensor) -> torch.Tensor:
    method inference (line 55) | def inference(self, input):

FILE: experiments/semisupervision/dataloaders/RandAugment.py
  function ShearX (line 16) | def ShearX(img, v):  # [-0.3, 0.3]
  function ShearY (line 23) | def ShearY(img, v):  # [-0.3, 0.3]
  function TranslateX (line 30) | def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
  function TranslateXabs (line 38) | def TranslateXabs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
  function TranslateY (line 45) | def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
  function TranslateYabs (line 53) | def TranslateYabs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
  function Rotate (line 60) | def Rotate(img, v):  # [-30, 30]
  function AutoContrast (line 67) | def AutoContrast(img, _):
  function Invert (line 71) | def Invert(img, _):
  function Equalize (line 75) | def Equalize(img, _):
  function Flip (line 79) | def Flip(img, _):  # not from the paper
  function Solarize (line 83) | def Solarize(img, v):  # [0, 256]
  function SolarizeAdd (line 88) | def SolarizeAdd(img, addition=0, threshold=128):
  function Posterize (line 97) | def Posterize(img, v):  # [4, 8]
  function Contrast (line 103) | def Contrast(img, v):  # [0.1,1.9]
  function Color (line 108) | def Color(img, v):  # [0.1,1.9]
  function Brightness (line 113) | def Brightness(img, v):  # [0.1,1.9]
  function Sharpness (line 118) | def Sharpness(img, v):  # [0.1,1.9]
  function Cutout (line 123) | def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
  function CutoutAbs (line 132) | def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
  function SamplePairing (line 154) | def SamplePairing(imgs):  # [0, 0.4]
  function Identity (line 163) | def Identity(img, v):
  function augment_list (line 167) | def augment_list(grey):  # 16 oeprations and their ranges
  class Lighting (line 231) | class Lighting(object):
    method __init__ (line 234) | def __init__(self, alphastd, eigval, eigvec):
    method __call__ (line 239) | def __call__(self, img):
  class CutoutDefault (line 252) | class CutoutDefault(object):
    method __init__ (line 256) | def __init__(self, length):
    method __call__ (line 259) | def __call__(self, img):
  class RandAugment (line 277) | class RandAugment:
    method __init__ (line 278) | def __init__(self, n, m, grey=False):
    method __call__ (line 283) | def __call__(self, img):

FILE: experiments/semisupervision/dataloaders/cifar_dataset.py
  class CIFAR100 (line 23) | class CIFAR100:
    method __init__ (line 24) | def __init__(self, user_idx=None, test_only=None, args=None, read_data...
  function save_json (line 86) | def save_json(dict, filename):
  function _process (line 91) | def _process(dataset, train=True):
  function partition_imagedataset (line 115) | def partition_imagedataset(X_train, Y_train, X_unlabel_train, Y_unlabel_...
  function __getDirichletData__ (line 172) | def __getDirichletData__(y, args):

FILE: experiments/semisupervision/dataloaders/dataloader.py
  class DataLoader (line 10) | class DataLoader(BaseDataLoader):
    method __init__ (line 11) | def __init__(self, mode, num_workers=0, **kwargs):
    method collate_fn (line 29) | def collate_fn(self, batch):

FILE: experiments/semisupervision/dataloaders/dataset.py
  class Dataset (line 8) | class Dataset(BaseDataset):
    method __init__ (line 9) | def __init__(self, data, test_only=False, user_idx=0, **kwargs):
    method __getitem__ (line 30) | def __getitem__(self, idx):
    method __len__ (line 33) | def __len__(self):
    method load_data (line 36) | def load_data(self, data, test_only, sup_config):

FILE: experiments/semisupervision/model.py
  class BasicBlock (line 15) | class BasicBlock(nn.Module):
    method __init__ (line 18) | def __init__(self, in_planes, planes, stride=1):
    method forward (line 35) | def forward(self, x):
  class Bottleneck (line 43) | class Bottleneck(nn.Module):
    method __init__ (line 46) | def __init__(self, in_planes, planes, stride=1):
    method forward (line 65) | def forward(self, x):
  class ResNet (line 74) | class ResNet(nn.Module):
    method __init__ (line 75) | def __init__(self, block, num_blocks, num_classes=10, inchannels = 3):
    method _make_layer (line 88) | def _make_layer(self, block, planes, num_blocks, stride):
    method forward (line 96) | def forward(self, x):
  function ResNet18 (line 108) | def ResNet18(num_classes=10):
  function ResNet18_emnist (line 111) | def ResNet18_emnist(num_classes=62, inchannel = 1):
  function ResNet18_organ (line 114) | def ResNet18_organ(num_classes=11, inchannel = 1):
  function ResNet18_path (line 117) | def ResNet18_path(num_classes=9, inchannel = 3):
  function ResNet18_blood (line 120) | def ResNet18_blood(num_classes=8, inchannel = 3):
  function ResNet34 (line 123) | def ResNet34(num_classes=10):
  function ResNet50 (line 126) | def ResNet50(num_classes=10):
  function ResNet101 (line 129) | def ResNet101(num_classes=10):
  function ResNet152 (line 132) | def ResNet152(num_classes=10):
  function test (line 135) | def test():
  class Res (line 141) | class Res(BaseModel):
    method __init__ (line 144) | def __init__(self, model_config):
    method forward (line 148) | def forward(self,x):
    method loss (line 151) | def loss(self, input: torch.Tensor) -> torch.Tensor:
    method inference (line 166) | def inference(self, input):

FILE: extensions/RL/RL.py
  class SequenceWise (line 19) | class SequenceWise(nn.Module):
    method __init__ (line 20) | def __init__(self, module):
    method forward (line 29) | def forward(self, x):
    method __repr__ (line 37) | def __repr__(self):
  class BatchRNN (line 44) | class BatchRNN(nn.Module):
    method __init__ (line 45) | def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirect...
    method forward (line 65) | def forward(self, x):
  class NeuralNetwork (line 79) | class NeuralNetwork(nn.Module):
    method __init__ (line 80) | def __init__(self, params, wantLSTM=False, batch_norm=False):
    method forward (line 135) | def forward(self, x):
  class RL (line 149) | class RL:
    method __init__ (line 150) | def __init__(self, config=None):
    method set_losses (line 177) | def set_losses(self, losses):
    method set_weights (line 180) | def set_weights(self, weights):
    method forward (line 183) | def forward(self, state=None):
    method train (line 205) | def train(self, batch=None):
    method make_model (line 268) | def make_model(self):
    method load_saved_status (line 286) | def load_saved_status(self):
    method load (line 301) | def load(self):
    method save (line 314) | def save(self, i):

FILE: extensions/privacy/__init__.py
  function compute_LDP_noise_std (line 15) | def compute_LDP_noise_std(eps, max_sensitivity, delta):
  function _beta2betainc_ratio (line 19) | def _beta2betainc_ratio(a, x):
  function _log_m1 (line 23) | def _log_m1(d, alpha, gamma):
  function _log_m2 (line 27) | def _log_m2(p, tau, alpha):
  function _efficient_m (line 31) | def _efficient_m(d, gamma, p):
  function privacy_parameters (line 37) | def privacy_parameters(eps0, eps, d):
  function private_unit2 (line 51) | def private_unit2(grad, gamma, prob):
  function add_gaussian_noise (line 68) | def add_gaussian_noise(grad, eps, max_grad, delta):
  function add_private_unit2_noise (line 75) | def add_private_unit2_noise(eps, grad):
  function scalar_DP (line 82) | def scalar_DP(r, eps, k, r_max):
  function laplace_noise (line 101) | def laplace_noise(max_sens, eps, vocab_size):
  function unroll_network (line 105) | def unroll_network(named_params, select_grad=False):
  function update_network (line 118) | def update_network(named_params, params_ids, flat_params, apply_to_grad=...
  function apply_global_dp (line 128) | def apply_global_dp(config, model, num_clients_curr_iter, select_grad=Tr...
  function apply_local_dp (line 154) | def apply_local_dp(trainer, weight, dp_config, add_weight_noise):
  function update_privacy_accountant (line 204) | def update_privacy_accountant(config, num_clients, curr_iter, num_client...

FILE: extensions/privacy/analysis.py
  function _log_add (line 43) | def _log_add(logx: float, logy: float) -> float:
  function _log_sub (line 60) | def _log_sub(logx: float, logy: float) -> float:
  function _compute_log_a_for_int_alpha (line 88) | def _compute_log_a_for_int_alpha(q: float, sigma: float, alpha: int) -> ...
  function _compute_log_a_for_frac_alpha (line 124) | def _compute_log_a_for_frac_alpha(q: float, sigma: float, alpha: float) ...
  function _compute_log_a (line 178) | def _compute_log_a(q: float, sigma: float, alpha: float) -> float:
  function _log_erfc (line 203) | def _log_erfc(x: float) -> float:
  function _compute_rdp (line 218) | def _compute_rdp(q: float, sigma: float, alpha: float) -> float:
  function compute_rdp (line 245) | def compute_rdp(
  function get_privacy_spent (line 272) | def get_privacy_spent(

FILE: extensions/privacy/dp_kmeans.py
  function sample (line 14) | def sample(ndim, r, num_samples=1):
  function sphere_packing_initialization (line 23) | def sphere_packing_initialization(n_clusters, n_dim, min_cluster_radius,
  function add_gaussian_noise (line 50) | def add_gaussian_noise(centers_new, weight_in_clusters, eps,
  function DPKMeans (line 75) | def DPKMeans(n_dim, eps, max_cluster_l2, max_sample_weight=1.0,
  function resetKMeans (line 191) | def resetKMeans():

FILE: extensions/privacy/metrics.py
  function extract_indices_from_embeddings (line 10) | def extract_indices_from_embeddings(gradients, batch, embed_size, vocab_...
  function compute_perplexity (line 25) | def compute_perplexity(encoded_batch, model):
  function practical_epsilon_leakage (line 33) | def practical_epsilon_leakage(original_params, model, encoded_batches, i...

FILE: extensions/quantization/quant.py
  function quant_model (line 9) | def quant_model(
  function find_min_max_gradient (line 53) | def find_min_max_gradient(
  function quant_bins (line 76) | def quant_bins(

FILE: testing/build_vocab.py
  function build_counter (line 8) | def build_counter(train_data, initial_counter=None):
  function build_vocab (line 30) | def build_vocab(counter, vocab_size=10000):
  function load_leaf_data (line 48) | def load_leaf_data(file_path):
  function save_vocab (line 56) | def save_vocab(vocab, target_dir):
  function main (line 64) | def main():
  function parse_args (line 87) | def parse_args():

FILE: testing/create_data.py
  function get_arg_parser (line 19) | def get_arg_parser() -> argparse.ArgumentParser:
  function reduce_users (line 24) | def reduce_users(file):
  function _process_and_save_to_disk (line 35) | def _process_and_save_to_disk(dataset, n_users, exp, output):
  function _dump_dict_to_hdf5 (line 57) | def _dump_dict_to_hdf5(data_dict: dict, hdf5_file: h5py.File):
  class HeartDataSet (line 73) | class HeartDataSet:
    method __init__ (line 74) | def __init__(self, heartdata, cutoff):
    method __len__ (line 78) | def __len__(self):
  function main (line 81) | def main():

FILE: testing/test_e2e_trainer.py
  function get_info (line 11) | def get_info(task):
  function run_pipeline (line 27) | def run_pipeline(data_path, output_path, config_path, task):
  function test_nlg_gru (line 50) | def test_nlg_gru():
  function test_ecg_cnn (line 56) | def test_ecg_cnn():
  function test_mlm_bert (line 63) | def test_mlm_bert():
  function test_classif_cnn (line 71) | def test_classif_cnn():

FILE: utils/data_utils.py
  class BatchSampler (line 9) | class BatchSampler(sampler.Sampler):
    method __init__ (line 16) | def __init__(self, dataset, batch_size, randomize=True, drop_last=False):
    method __iter__ (line 31) | def __iter__(self):
    method __len__ (line 38) | def __len__(self):
  class DynamicBatchSampler (line 42) | class DynamicBatchSampler(sampler.Sampler):
    method __init__ (line 50) | def __init__(self, sampler, frames_threshold, max_batch_size=0, unsort...
    method __iter__ (line 113) | def __iter__(self):
    method __len__ (line 118) | def __len__(self):

FILE: utils/dataloaders_utils.py
  function get_exp_dataloader (line 9) | def get_exp_dataloader(task):
  function make_train_dataloader (line 25) | def make_train_dataloader(data_config, data_path, clientx, task=None, ve...
  function make_val_dataloader (line 58) | def make_val_dataloader(data_config, data_path, task=None, data_strct=No...
  function make_test_dataloader (line 72) | def make_test_dataloader(data_config, data_path, task=None, data_strct=N...
  function get_dataset (line 85) | def get_dataset(data_path, config, task, mode, test_only=False, user_idx...
  function get_data_config (line 100) | def get_data_config(config, mode):

FILE: utils/optimizers/adamW.py
  class AdamW (line 8) | class AdamW(Optimizer):
    method __init__ (line 17) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weig...
    method step (line 30) | def step(self, closure=None):

FILE: utils/optimizers/lamb.py
  function log_lamb_rs (line 15) | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token...
  function log_lamb_rs (line 29) | def log_lamb_rs(optimizer, event_writer, token_count):
  class LAMB (line 33) | class LAMB(Optimizer):
    method __init__ (line 54) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
    method step (line 69) | def step(self, closure=None):

FILE: utils/optimizers/lars.py
  class LarsSGDV1 (line 10) | class LarsSGDV1(torch.optim.SGD):
    method __init__ (line 16) | def __init__(self, params, lr, momentum=0, dampening=0,
    method step (line 22) | def step(self, closure=None):
  class LarsSGD (line 74) | class LarsSGD(torch.optim.SGD):
    method __init__ (line 80) | def __init__(self, params, lr, momentum=0, dampening=0,
    method step (line 86) | def step(self, closure=None):

FILE: utils/preprocessing/create-hdf5.py
  function local_time (line 12) | def local_time():

FILE: utils/preprocessing/create-json.py
  function local_time (line 11) | def local_time():

FILE: utils/preprocessing/from_json_to_hdf5.py
  function local_time (line 11) | def local_time():

FILE: utils/utils.py
  function make_optimizer (line 27) | def make_optimizer(optimizer_config, model):
  function get_lr (line 67) | def get_lr(optimizer):
  function get_lr_all (line 72) | def get_lr_all(optimizer):
  function softmax (line 78) | def softmax(X, theta = 1.0, axis = None):
  class AverageMeter (line 117) | class AverageMeter(object):
    method __init__ (line 121) | def __init__(self, metric_name):
    method add (line 125) | def add(self, top, bottom):
    method get_macro_average (line 129) | def get_macro_average(self):
    method get_micro_average (line 134) | def get_micro_average(self):
    method get_average (line 138) | def get_average(self, l):
    method reset (line 141) | def reset(self):
    method display_results (line 144) | def display_results(self, loglevel=logging.INFO):
  function make_lr_scheduler (line 151) | def make_lr_scheduler(annealing_config, optimizer, num_batches=1):
  class RampupKeepExpdecayKeepLRScheduler (line 189) | class RampupKeepExpdecayKeepLRScheduler(torch.optim.lr_scheduler._LRSche...
    method __init__ (line 192) | def __init__(self, optimizer, peak_lr=0.001, floor_lr=0.00001, sr=1000...
    method step (line 207) | def step(self, epoch=None):
    method get_lr (line 212) | def get_lr(self):
  class ScheduledSamplingScheduler (line 228) | class ScheduledSamplingScheduler():
    method __init__ (line 236) | def __init__(self, model, ramp_start, ramp_stop,
    method step (line 245) | def step(self):
    method state_dict (line 256) | def state_dict(self):
    method load_state_dict (line 259) | def load_state_dict(self, state_dict):
  class NBestTaskScheduler (line 263) | class NBestTaskScheduler():
    method __init__ (line 269) | def __init__(self, num_tasks, iteration_per_task):
    method current_num_tasks (line 276) | def current_num_tasks(self):
    method no_label_updates (line 279) | def no_label_updates(self):
    method set_iteration_no (line 283) | def set_iteration_no(self, iter_no):
    method step (line 286) | def step(self):
  function init_logging (line 299) | def init_logging(log_dir, loglevel=logging.DEBUG):
  function print_cuda_stats (line 310) | def print_cuda_stats():
  function print_rank (line 319) | def print_rank(str, loglevel=logging.INFO):
  function print_profiler (line 324) | def print_profiler(profiler, loglevel=logging.INFO):
  function write_yaml (line 335) | def write_yaml(save_path, config):
  function torch_save (line 339) | def torch_save(save_path, state_or_model):
  function write_tokens (line 342) | def write_tokens(save_path, token_list):
  function try_except_save (line 348) | def try_except_save(save_fn, **kwargs):
  function write_nbest_jsonl (line 362) | def write_nbest_jsonl(uttid2jsonl, uttid2hypos, uttid2scores, outputpath...
  function write_multitask_jsonl (line 401) | def write_multitask_jsonl(uttid2jsonl, uttid2hypos, uttid2scores, output...
  function load_eval_result_jsonl (line 451) | def load_eval_result_jsonl(resultjsonl, uttid2hypos=OrderedDict(), uttid...
  function find_pretrained_model (line 486) | def find_pretrained_model(model_path, config):
  function flatten_grads_model (line 497) | def flatten_grads_model(learner) -> np.ndarray:
  function flatten_grads_array (line 502) | def flatten_grads_array(param_array)->np.array:
  function dist_weights_to_model (line 511) | def dist_weights_to_model(weights, parameters):
  function dist_params_to_model (line 521) | def dist_params_to_model(grads, model):
  function reshape_params_to_model (line 531) | def reshape_params_to_model(grads, model):
  function to_device (line 543) | def to_device(x):
  function update_json_log (line 546) | def update_json_log(log_path, status_info):
  function scrub_empty_clients (line 563) | def scrub_empty_clients(data_strct):
  function compute_grad_cosines (line 585) | def compute_grad_cosines(grads, model_grad):
  function convex_inference (line 598) | def convex_inference(model_global, model_personal, alpha):
  function alpha_update (line 605) | def alpha_update(model_global, model_personal, alpha, eta):
  function get_label_VAT (line 620) | def get_label_VAT(local_logits, server_logits, thre, comp):
Condensed preview — 151 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (827K chars).
[
  {
    "path": ".flake8",
    "chars": 22,
    "preview": "[flake8]\nignore = E501"
  },
  {
    "path": ".github/workflows/build_docs.yml",
    "chars": 955,
    "preview": "name: Build docs\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\n  workflow_dispatch:\n\njobs:"
  },
  {
    "path": ".github/workflows/codeql.yml",
    "chars": 1131,
    "preview": "# This is based on the standard CodeQL workflow provided by Github\nname: \"CodeQL\"\n\non:\n  push:\n    branches: [ \"main\" ]\n"
  },
  {
    "path": ".gitignore",
    "chars": 87,
    "preview": "__pycache__/\n.vscode/\ndoc/sphinx/_build\ntesting/logs.txt\ntesting/outputs\ntesting/mockup"
  },
  {
    "path": ".gitmodules",
    "chars": 113,
    "preview": "[submodule \"utils/dp-accountant\"]\n\tpath = utils/dp-accountant\n\turl = https://github.com/microsoft/prv_accountant\n"
  },
  {
    "path": "CHANGELOG.md",
    "chars": 2498,
    "preview": "# Changelog\n\nAll notable changes to this project will be documented in this file.\n\n## [0.1.0] - 2021-11-22\n\nWe're super "
  },
  {
    "path": "CITATION.cff",
    "chars": 363,
    "preview": "cff-version: 1.2.0\nmessage: \"To cite Microsoft FLUTE in academic papers, please cite it as below.\"\nauthors:\n  - name: \"M"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 444,
    "preview": "# Microsoft Open Source Code of Conduct\n\nThis project has adopted the [Microsoft Open Source Code of Conduct](https://op"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1038,
    "preview": "# Contributing\n\nThis project welcomes contributions and suggestions. Most contributions require you to\nagree to a Contri"
  },
  {
    "path": "LICENSE.TXT",
    "chars": 1073,
    "preview": "Copyright (c) Microsoft Corporation.\n\nMIT License\n\nPermission is hereby granted, free of charge, to any person obtaining"
  },
  {
    "path": "NOTICE.txt",
    "chars": 3489,
    "preview": "THIRD-PARTY SOFTWARE NOTICES AND INFORMATION\nDo Not Translate or Localize\n\nThis software incorporates components from th"
  },
  {
    "path": "README.md",
    "chars": 12530,
    "preview": "# FLUTE\n\nWelcome to FLUTE (Federated Learning Utilities for Testing and Experimentation), a platform for conducting high"
  },
  {
    "path": "SECURITY.md",
    "chars": 2757,
    "preview": "<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->\n\n## Security\n\nMicrosoft takes the security of our software products an"
  },
  {
    "path": "azure-pipelines.yml",
    "chars": 699,
    "preview": "trigger:\n- main\n\npool:\n  vmImage: 'windows-latest'\n\nsteps:\n- task: CredScan@2\n  inputs:\n    toolMajorVersion: 'V2'\n\n- ta"
  },
  {
    "path": "configs/hello_world_mlm_bert_json.yaml",
    "chars": 7265,
    "preview": "# Basic configuration file for running mlm_bert example using json files.\n# Parameters needed to initialize the model\nmo"
  },
  {
    "path": "configs/hello_world_nlg_gru_json.yaml",
    "chars": 7225,
    "preview": "# Basic configuration file for running nlg_gru example using json files.\n# Parameters needed to initialize the model\nmod"
  },
  {
    "path": "core/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "core/client.py",
    "chars": 24070,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n'''\nThe Client object is short-lived, instantia"
  },
  {
    "path": "core/config.py",
    "chars": 29514,
    "preview": "# Note this import requires python 3.7+\n# Do we want to commit to this?\nfrom __future__ import annotations\nfrom dataclas"
  },
  {
    "path": "core/dataloader.py",
    "chars": 363,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom torch.utils.data import DataLoader as PyT"
  },
  {
    "path": "core/dataset.py",
    "chars": 736,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom torch.utils.data import Dataset as PyTorc"
  },
  {
    "path": "core/evaluation.py",
    "chars": 8853,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n'''\nIn this file we define the functions for ru"
  },
  {
    "path": "core/federated.py",
    "chars": 29176,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport os\nimport cProfile\nimport logging\nimpor"
  },
  {
    "path": "core/metrics.py",
    "chars": 2811,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n'''\nIn this file we define the wrapper class fo"
  },
  {
    "path": "core/model.py",
    "chars": 1413,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch as T\nfrom abc import ABC, abstrac"
  },
  {
    "path": "core/schema.py",
    "chars": 15989,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n# '''\n# In this file we define the  schema for "
  },
  {
    "path": "core/server.py",
    "chars": 27601,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n'''\nIn this file, we define the classes that li"
  },
  {
    "path": "core/strategies/__init__.py",
    "chars": 674,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom .base import BaseStrategy\nfrom .fedavg im"
  },
  {
    "path": "core/strategies/base.py",
    "chars": 1904,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom abc import abstractmethod\n\n\n@abstractmeth"
  },
  {
    "path": "core/strategies/dga.py",
    "chars": 17476,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport copy\nimport json\nimport logging\nimport "
  },
  {
    "path": "core/strategies/fedavg.py",
    "chars": 7013,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport json\nimport logging\nimport os\n\nimport t"
  },
  {
    "path": "core/strategies/fedlabels.py",
    "chars": 9024,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport json\nimport logging\nimport os\n\nimport t"
  },
  {
    "path": "core/strategies/utils.py",
    "chars": 1017,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport logging\n\nimport numpy as np\n\nfrom utils"
  },
  {
    "path": "core/trainer.py",
    "chars": 30371,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport logging\nimport os\nimport re\nimport copy"
  },
  {
    "path": "doc/sphinx/Makefile",
    "chars": 634,
    "preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the "
  },
  {
    "path": "doc/sphinx/advanced.rst",
    "chars": 127,
    "preview": "Advanced Topics\n===============\n\nPrivacy\n-------\n\nAggregation Options\n-------------------\n\n\nOptimizer Options\n----------"
  },
  {
    "path": "doc/sphinx/class_reference.rst",
    "chars": 454,
    "preview": "\n\nClass Reference\n===============\n\nFLUTE Core\n~~~~~~~~~~\n\ncore/server\n-----------\n\n.. automodule:: core.server\n   :membe"
  },
  {
    "path": "doc/sphinx/conf.py",
    "chars": 2010,
    "preview": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common op"
  },
  {
    "path": "doc/sphinx/index.rst",
    "chars": 505,
    "preview": ".. FLUTE documentation master file, created by\n   sphinx-quickstart on Sat Jun 19 09:15:36 2021.\n   You can adapt this f"
  },
  {
    "path": "doc/sphinx/launch.rst",
    "chars": 4640,
    "preview": "Launch FLUTE\n================\n\nLocal run\n------------\n\nInstall the requirements stated inside of requirements.txt. Ideal"
  },
  {
    "path": "doc/sphinx/make.bat",
    "chars": 760,
    "preview": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-bu"
  },
  {
    "path": "doc/sphinx/overview.rst",
    "chars": 2675,
    "preview": "FLUTE Overview\n============\n\nFLUTE: Federated Learning Utilities and Tools for Experimentation is a high-performance ope"
  },
  {
    "path": "doc/sphinx/reference.rst",
    "chars": 1144,
    "preview": "Option Reference\n================\n\nCommand Line Arguments\n----------------------\n\nYAML Configuration\n------------------\n"
  },
  {
    "path": "doc/sphinx/requirements.txt",
    "chars": 31,
    "preview": "sphinx_rtd_theme\njinja2==3.0.3\n"
  },
  {
    "path": "doc/sphinx/scenarios.rst",
    "chars": 10403,
    "preview": "Adding New Scenarios\n====================\n\nData Preparation\n------------\nFLUTE provides the abstract class `BaseDataset`"
  },
  {
    "path": "e2e_trainer.py",
    "chars": 9765,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n'''\nThis is the main script to run on each NCC"
  },
  {
    "path": "experiments/__init__.py",
    "chars": 1644,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nfrom utils import print_rank, pri"
  },
  {
    "path": "experiments/classif_cnn/.gitignore",
    "chars": 24,
    "preview": "utils/data\n*.hdf5\n*.json"
  },
  {
    "path": "experiments/classif_cnn/README.md",
    "chars": 3145,
    "preview": "# Simple example of a CNN on CIFAR-10\n\nOur objective here is to bring a simple experiment from the Pytorch tutorials,\nmo"
  },
  {
    "path": "experiments/classif_cnn/config.yaml",
    "chars": 3412,
    "preview": "# Basic configuration file for running classif_cnn example using torchvision CIFAR10 dataset.\n# Parameters needed to ini"
  },
  {
    "path": "experiments/classif_cnn/dataloaders/cifar_dataset.py",
    "chars": 1931,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\nimport time\nimport torchvision\nimport torchvisi"
  },
  {
    "path": "experiments/classif_cnn/dataloaders/dataloader.py",
    "chars": 863,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\n\nfrom core.dataloader import Base"
  },
  {
    "path": "experiments/classif_cnn/dataloaders/dataset.py",
    "chars": 1744,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nfrom core.dataset import Ba"
  },
  {
    "path": "experiments/classif_cnn/model.py",
    "chars": 2250,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nfrom torch import nn\nfrom torch.n"
  },
  {
    "path": "experiments/classif_cnn/utils/centralized_training.py",
    "chars": 3121,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n'''Simple example of a CNN on CIFAR-10\n\nThis i"
  },
  {
    "path": "experiments/classif_cnn/utils/download_and_convert_data.py",
    "chars": 3157,
    "preview": "import h5py\nimport json\nimport time\n\nimport torchvision\nimport torchvision.transforms as transforms\nimport tqdm\n\n\ndef _d"
  },
  {
    "path": "experiments/cv/README.md",
    "chars": 2658,
    "preview": "# Simple example of ResNet model using personalization\n\nOur objective here is to bring a simple experiment of Computer V"
  },
  {
    "path": "experiments/cv/config.yaml",
    "chars": 3122,
    "preview": "model_config:\n    model_type: resnet50 #vgg11                                  # class w/ `loss` and `inference` methods"
  },
  {
    "path": "experiments/cv/data.py",
    "chars": 14348,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport logging\nimport h5py\nimport json\nimport "
  },
  {
    "path": "experiments/cv/dataloaders/dataloader.py",
    "chars": 897,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nimport numpy as np\n\nfrom core.dat"
  },
  {
    "path": "experiments/cv/dataloaders/dataset.py",
    "chars": 2602,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\n\nfrom core.dataset import B"
  },
  {
    "path": "experiments/cv/model.py",
    "chars": 18058,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n'''\nModified from https://github.com/pytorch/vi"
  },
  {
    "path": "experiments/cv/model_vgg.py",
    "chars": 5854,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n'''\nModified from https://github.com/pytorch/v"
  },
  {
    "path": "experiments/cv/server.py",
    "chars": 727,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n'''\nIn this file, we define the local server th"
  },
  {
    "path": "experiments/cv_cnn_femnist/README.md",
    "chars": 2801,
    "preview": "## FedML Benchmark\n\n### Examples\n\nThe example in this folder was taken from [FedML](https://github.com/FedML-AI/FedML/tr"
  },
  {
    "path": "experiments/cv_cnn_femnist/config.yaml",
    "chars": 3384,
    "preview": "# Basic configuration file for running classif_cnn example using torchvision CIFAR10 dataset.\n# Parameters needed to ini"
  },
  {
    "path": "experiments/cv_cnn_femnist/dataloaders/dataloader.py",
    "chars": 926,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nimport numpy as np \n\nfrom core.da"
  },
  {
    "path": "experiments/cv_cnn_femnist/dataloaders/dataset.py",
    "chars": 2065,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nfrom core.dataset import Ba"
  },
  {
    "path": "experiments/cv_cnn_femnist/dataloaders/preprocess.py",
    "chars": 2460,
    "preview": "import os\nimport h5py\nimport wget\nimport tarfile\n\ndata_cache_dir = \"./data\"\nDEFAULT_TRAIN_FILE = \"fed_emnist_train.h5\"\nD"
  },
  {
    "path": "experiments/cv_cnn_femnist/model.py",
    "chars": 4505,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom core.model import BaseModel\n\n''' \n    The CN"
  },
  {
    "path": "experiments/cv_lr_mnist/README.md",
    "chars": 3791,
    "preview": "## FedML Benchmark\n\n### Examples\n\nThe example in this folder was taken from [FedML](https://github.com/FedML-AI/FedML/tr"
  },
  {
    "path": "experiments/cv_lr_mnist/config.yaml",
    "chars": 3417,
    "preview": "# Basic configuration file for running classif_cnn example using torchvision CIFAR10 dataset.\n# Parameters needed to ini"
  },
  {
    "path": "experiments/cv_lr_mnist/dataloaders/dataloader.py",
    "chars": 922,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nimport numpy as np\n\nfrom core.dat"
  },
  {
    "path": "experiments/cv_lr_mnist/dataloaders/dataset.py",
    "chars": 2051,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nfrom core.dataset import Ba"
  },
  {
    "path": "experiments/cv_lr_mnist/dataloaders/preprocessing.py",
    "chars": 2426,
    "preview": "import os\nimport wget\nimport zipfile\nimport numpy as np\nimport json\n\nFEDML_DATA_MNIST_URL = \"https://fedcv.s3.us-west-1."
  },
  {
    "path": "experiments/cv_lr_mnist/model.py",
    "chars": 1820,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom core.model import BaseModel\n\n''' \n    The Lo"
  },
  {
    "path": "experiments/cv_resnet_fedcifar100/README.md",
    "chars": 2831,
    "preview": "## FedML Benchmark\n\n### Examples\n\nThe example in this folder was taken from [FedML](https://github.com/FedML-AI/FedML/tr"
  },
  {
    "path": "experiments/cv_resnet_fedcifar100/config.yaml",
    "chars": 3390,
    "preview": "# Basic configuration file for running classif_cnn example using torchvision CIFAR10 dataset.\n# Parameters needed to ini"
  },
  {
    "path": "experiments/cv_resnet_fedcifar100/dataloaders/dataloader.py",
    "chars": 932,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nimport numpy as np\n\nfrom core.dat"
  },
  {
    "path": "experiments/cv_resnet_fedcifar100/dataloaders/dataset.py",
    "chars": 2076,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nfrom core.dataset import Ba"
  },
  {
    "path": "experiments/cv_resnet_fedcifar100/dataloaders/preprocessing.py",
    "chars": 2484,
    "preview": "import os\nimport wget\nimport zipfile\nimport tarfile\nimport h5py\n\ndata_cache_dir = \"./data\"\nDEFAULT_TRAIN_FILE = \"fed_cif"
  },
  {
    "path": "experiments/cv_resnet_fedcifar100/group_normalization.py",
    "chars": 5547,
    "preview": "import torch.nn.functional as F\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\n\"\"\" This group normalization script w"
  },
  {
    "path": "experiments/cv_resnet_fedcifar100/model.py",
    "chars": 8953,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\nfrom torch.nn import functional"
  },
  {
    "path": "experiments/ecg_cnn/.gitignore",
    "chars": 31,
    "preview": "./data\n./raw_data\n*.hdf5\n*.json"
  },
  {
    "path": "experiments/ecg_cnn/centralized_model.ipynb",
    "chars": 63171,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 1,\n      \"metadata\": {\n        \"gather\": {\n    "
  },
  {
    "path": "experiments/ecg_cnn/config.yaml",
    "chars": 3178,
    "preview": "# Basic configuration file for running ecg_cnn example using json files.\n# Parameters needed to initialize the model\nmod"
  },
  {
    "path": "experiments/ecg_cnn/dataloaders/dataloader.py",
    "chars": 889,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom experiments.ecg_cnn.dataloaders.dataset i"
  },
  {
    "path": "experiments/ecg_cnn/dataloaders/dataset.py",
    "chars": 2302,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport h5py\nimport numpy as np\n\nfrom core.data"
  },
  {
    "path": "experiments/ecg_cnn/model.py",
    "chars": 5411,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n'''The model architecture used was first creat"
  },
  {
    "path": "experiments/ecg_cnn/readme.md",
    "chars": 6347,
    "preview": "# Example of CNN-LSTM model on Arrhythmia dataset\n\nThe objective of this experiment is to show the capabilities of FLUTE"
  },
  {
    "path": "experiments/ecg_cnn/utils/preprocess.py",
    "chars": 3991,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport h5py\nimport time\nimport tqdm\nimport csv"
  },
  {
    "path": "experiments/fednewsrec/README.md",
    "chars": 1004,
    "preview": "### Data\n\nIn order to run this experiment, you need to previously download the MIND dataset [here](https://msnews.github"
  },
  {
    "path": "experiments/fednewsrec/config.yaml",
    "chars": 3327,
    "preview": "# Parameters needed to initialize the model\nmodel_config:\n    model_type: FEDNEWS                                    # c"
  },
  {
    "path": "experiments/fednewsrec/dataloaders/dataloader.py",
    "chars": 1371,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nimport numpy as np\nfrom core.data"
  },
  {
    "path": "experiments/fednewsrec/dataloaders/dataset.py",
    "chars": 2039,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nimport torch\nfrom core.data"
  },
  {
    "path": "experiments/fednewsrec/dataloaders/preprocess_mind.py",
    "chars": 9505,
    "preview": "from nltk.tokenize import word_tokenize\nimport random\nimport os\nimport numpy as np\nimport torch\n\nMAX_SENTENCE = 30\nMAX_A"
  },
  {
    "path": "experiments/fednewsrec/fednewsrec_model.py",
    "chars": 14908,
    "preview": "import torch\nimport torch.nn as nn\nimport numpy as np\n\nnpratio = 4\n\n''' \n    The FedNewsRec model is taken from FedNewsR"
  },
  {
    "path": "experiments/fednewsrec/model.py",
    "chars": 5033,
    "preview": "import os\nimport torch\nfrom torch.nn import CrossEntropyLoss\nfrom torch.nn import functional as F\nimport numpy as np\nfro"
  },
  {
    "path": "experiments/fednewsrec/utils.py",
    "chars": 605,
    "preview": "import numpy as np\n\ndef mrr_score(y_true, y_score):\n    order = np.argsort(y_score)[::-1]\n    y_true = np.take(y_true, o"
  },
  {
    "path": "experiments/mlm_bert/README.md",
    "chars": 1255,
    "preview": "# Simple example of a MLM task on Reddit Dataset\n\nInstructions on how to run the experiment, given below.\n\n## Preparing "
  },
  {
    "path": "experiments/mlm_bert/config.py",
    "chars": 3594,
    "preview": "from __future__ import annotations\nfrom dataclasses import dataclass\nimport sys\nsys.path.append('../../')\nfrom core.conf"
  },
  {
    "path": "experiments/mlm_bert/dataloaders/dataloader.py",
    "chars": 4370,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom transformers.data.data_collator import de"
  },
  {
    "path": "experiments/mlm_bert/dataloaders/dataset.py",
    "chars": 10165,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom core.dataset import BaseDataset\nfrom tran"
  },
  {
    "path": "experiments/mlm_bert/model.py",
    "chars": 19576,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch as T\nfrom utils import print_rank"
  },
  {
    "path": "experiments/mlm_bert/utils/trainer_pt_utils.py",
    "chars": 21208,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n# coding=utf-8\n# Copyright 2020-present the Hu"
  },
  {
    "path": "experiments/mlm_bert/utils/trainer_utils.py",
    "chars": 3065,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n# coding=utf-8\n# Copyright 2020-present the Hu"
  },
  {
    "path": "experiments/nlg_gru/README.md",
    "chars": 1371,
    "preview": "# Simple example of a NLG task on Reddit Dataset\n\nInstructions on how to run the experiment, given below.\n\n## Preparing "
  },
  {
    "path": "experiments/nlg_gru/config.py",
    "chars": 1005,
    "preview": "from __future__ import annotations\nfrom dataclasses import dataclass\nimport sys\nsys.path.append('../../')\nfrom core.conf"
  },
  {
    "path": "experiments/nlg_gru/dataloaders/dataloader.py",
    "chars": 3454,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport random\nimport torch\nimport numpy as np\n"
  },
  {
    "path": "experiments/nlg_gru/dataloaders/dataset.py",
    "chars": 2819,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nimport logging\nimport json\n"
  },
  {
    "path": "experiments/nlg_gru/model.py",
    "chars": 5412,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch as T\nfrom torch import Tensor\nfro"
  },
  {
    "path": "experiments/nlg_gru/utils/utility.py",
    "chars": 6224,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport os\nimport json\nimport time\nfrom argpars"
  },
  {
    "path": "experiments/nlp_rnn_fedshakespeare/README.md",
    "chars": 2831,
    "preview": "## FedML Benchmark\n\n### Examples\n\nThe example in this folder was taken from [FedML](https://github.com/FedML-AI/FedML/tr"
  },
  {
    "path": "experiments/nlp_rnn_fedshakespeare/config.yaml",
    "chars": 3385,
    "preview": "# Basic configuration file for running classif_cnn example using torchvision CIFAR10 dataset.\n# Parameters needed to ini"
  },
  {
    "path": "experiments/nlp_rnn_fedshakespeare/dataloaders/dataloader.py",
    "chars": 933,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nimport numpy as np\n\nfrom core.dat"
  },
  {
    "path": "experiments/nlp_rnn_fedshakespeare/dataloaders/dataset.py",
    "chars": 2077,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nfrom core.dataset import Ba"
  },
  {
    "path": "experiments/nlp_rnn_fedshakespeare/dataloaders/preprocessing.py",
    "chars": 4794,
    "preview": "import logging\nimport os\nimport wget\nimport tarfile\nimport h5py\nimport collections\nimport numpy as np\n\ndata_cache_dir = "
  },
  {
    "path": "experiments/nlp_rnn_fedshakespeare/model.py",
    "chars": 2657,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom core.model import BaseModel\n\n''' \n    The CN"
  },
  {
    "path": "experiments/semisupervision/README.md",
    "chars": 599,
    "preview": "### Data\n\nIn order to run this experiment, you need to previously run the script [cifar_dataset.py](dataloaders/cifar_da"
  },
  {
    "path": "experiments/semisupervision/config.yaml",
    "chars": 3858,
    "preview": "# Basic configuration file for running semisupervision with data loaded on-the-fly\n# Parameters needed to initialize the"
  },
  {
    "path": "experiments/semisupervision/dataloaders/RandAugment.py",
    "chars": 7759,
    "preview": "'''\nCode in this file is adapted from rpmcruz/autoaugment\nhttps://github.com/rpmcruz/autoaugment/blob/master/transformat"
  },
  {
    "path": "experiments/semisupervision/dataloaders/cifar_dataset.py",
    "chars": 9006,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport os \nimport time\nimport json\n\nimport tor"
  },
  {
    "path": "experiments/semisupervision/dataloaders/dataloader.py",
    "chars": 934,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport torch\nimport numpy as np\n\nfrom core.dat"
  },
  {
    "path": "experiments/semisupervision/dataloaders/dataset.py",
    "chars": 1878,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nfrom core.dataset import Ba"
  },
  {
    "path": "experiments/semisupervision/model.py",
    "chars": 6400,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom core.model import"
  },
  {
    "path": "extensions/RL/RL.py",
    "chars": 12219,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport logging\nimport os\nimport json\nimport ra"
  },
  {
    "path": "extensions/__init__.py",
    "chars": 149,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom extensions.RL.RL import *\nfrom extensions"
  },
  {
    "path": "extensions/privacy/__init__.py",
    "chars": 10382,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport numpy as np\nimport torch as T\nimport lo"
  },
  {
    "path": "extensions/privacy/analysis.py",
    "chars": 9405,
    "preview": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\n\n\"\"\"\n*Borrowed from Facebo"
  },
  {
    "path": "extensions/privacy/dp_kmeans.py",
    "chars": 8046,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport sys\nimport numpy as np\nfrom scipy.speci"
  },
  {
    "path": "extensions/privacy/metrics.py",
    "chars": 3873,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport logging\nimport numpy as np\nimport torch"
  },
  {
    "path": "extensions/quantization/quant.py",
    "chars": 3561,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport logging\nimport torch\nfrom utils import "
  },
  {
    "path": "requirements.txt",
    "chars": 206,
    "preview": "torch==1.11.0\nmpi4py\neasydict\nscipy\npsutil\ntransformers\ntorchvision\npandas\nh5py\nsphinx_rtd_theme\nazureml-core\nazureml-de"
  },
  {
    "path": "testing/README.md",
    "chars": 881,
    "preview": "## Information\n\nThe tests are designed to evaluate the operation of the tasks, not the performance. Therefore, we are us"
  },
  {
    "path": "testing/build_vocab.py",
    "chars": 2741,
    "preview": "\"\"\"Builds vocabulary file from data.\"\"\"\n\nimport argparse\nimport collections\nimport json\nimport os\n\ndef build_counter(tra"
  },
  {
    "path": "testing/create_data.py",
    "chars": 7611,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport os\nimport csv\nimport json\nimport random"
  },
  {
    "path": "testing/hello_world_classif_cnn.yaml",
    "chars": 3134,
    "preview": "# Basic configuration file for running classif_cnn example using hdf5 files.\n# Parameters needed to initialize the model"
  },
  {
    "path": "testing/hello_world_ecg_cnn.yaml",
    "chars": 3201,
    "preview": "# Basic configuration file for running ecg_cnn example using json files.\n# Parameters needed to initialize the model\nmod"
  },
  {
    "path": "testing/hello_world_mlm_bert.yaml",
    "chars": 6252,
    "preview": "# Basic configuration file for running mlm_bert example using json files.\n# Parameters needed to initialize the model\nmo"
  },
  {
    "path": "testing/hello_world_nlg_gru.yaml",
    "chars": 4647,
    "preview": "# Basic configuration file for running nlg_gru example using json files.\n# Parameters needed to initialize the model\nmod"
  },
  {
    "path": "testing/test_e2e_trainer.py",
    "chars": 2309,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport subprocess\nimport os\nimport platform\nim"
  },
  {
    "path": "utils/__init__.py",
    "chars": 168,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nfrom .utils import *\nfrom utils.optimizers.lar"
  },
  {
    "path": "utils/data_utils.py",
    "chars": 4553,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport random\nimport logging\nfrom torch.utils."
  },
  {
    "path": "utils/dataloaders_utils.py",
    "chars": 4620,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport os\nimport logging\nfrom importlib.machin"
  },
  {
    "path": "utils/optimizers/adamW.py",
    "chars": 4159,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport math\nimport torch\nfrom torch.optim impo"
  },
  {
    "path": "utils/optimizers/lamb.py",
    "chars": 5294,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n\"\"\"Lamb optimizer.\"\"\"\n\nimport collections\nimpo"
  },
  {
    "path": "utils/optimizers/lars.py",
    "chars": 4187,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\n\"\"\"distoptim.hit package\"\"\"\nimport logging\nimp"
  },
  {
    "path": "utils/preprocessing/create-hdf5.py",
    "chars": 1835,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport h5py\nimport time\nfrom tqdm import tqdm\n"
  },
  {
    "path": "utils/preprocessing/create-json.py",
    "chars": 1450,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport json\nimport time\nfrom tqdm import tqdm\n"
  },
  {
    "path": "utils/preprocessing/from_json_to_hdf5.py",
    "chars": 915,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport json\nimport h5py\nfrom tqdm import tqdm\n"
  },
  {
    "path": "utils/utils.py",
    "chars": 25462,
    "preview": "# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT license.\n\nimport os\nimport sys\nimport numpy as np\nimport"
  }
]

About this extraction

This page contains the full source code of the microsoft/msrflute GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 151 files (775.6 KB), approximately 202.7k tokens, and a symbol index with 773 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!