Repository: iQua/flsim Branch: master Commit: 59d39e6f50bd Files: 27 Total size: 73.0 KB Directory structure: gitextract_y9rb20wz/ ├── .gitignore ├── LICENSE ├── README.md ├── client.py ├── config.py ├── configs/ │ ├── CIFAR-10/ │ │ └── cifar-10.json │ ├── FashionMNIST/ │ │ └── fashionmnist.json │ ├── MNIST/ │ │ └── mnist.json │ └── config.json.template ├── environment.yml ├── load_data.py ├── models/ │ ├── CIFAR-10/ │ │ └── fl_model.py │ ├── FashionMNIST/ │ │ └── fl_model.py │ ├── MNIST/ │ │ └── fl_model.py │ └── fl_model.py ├── run.py ├── scripts/ │ ├── analyze_logs.py │ └── pca.py ├── server/ │ ├── __init__.py │ ├── accavg.py │ ├── directed.py │ ├── kcenter.py │ ├── kmeans.py │ ├── magavg.py │ └── server.py └── utils/ ├── dists.py └── kcenter.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # Data files data/ # Global model files **/global ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # FLSim ## About Welcome to **FLSim**, a PyTorch-based federated learning simulation framework, created for experimental research in a paper accepted by [IEEE INFOCOM 2020](https://infocom2020.ieee-infocom.org): [Hao Wang](https://www.haow.ca), Zakhary Kaplan, [Di Niu](https://sites.ualberta.ca/~dniu/Homepage/Home.html), [Baochun Li](http://iqua.ece.toronto.edu/bli/index.html). "Optimizing Federated Learning on Non-IID Data with Reinforcement Learning," in the Proceedings of IEEE INFOCOM, Beijing, China, April 27-30, 2020. ## Installation To install **FLSim**, all that needs to be done is to clone this repository to the desired directory. ### Dependencies **FLSim** uses [Anaconda](https://www.anaconda.com/distribution/) to manage Python and its dependencies, listed in [`environment.yml`](environment.yml).
To install the `fl-py37` Python environment, set up Anaconda (or Miniconda), then download the environment dependencies with: ```shell conda env create -f environment.yml ``` ## Usage Before using the repository, make sure to activate the `fl-py37` environment with: ```shell conda activate fl-py37 ``` ### Simulation To start a simulation, run [`run.py`](run.py) from the repository's root directory: ```shell python run.py --config=config.json --log=INFO ``` ##### `run.py` flags * `--config` (`-c`): path to the configuration file to be used. * `--log` (`-l`): level of logging info to be written to console, defaults to `INFO`. ##### `config.json` files **FLSim** uses a JSON file to manage the configuration parameters for a federated learning simulation. Provided in the repository is a generic template and three preconfigured simulation files for the CIFAR-10, FashionMNIST, and MNIST datasets. For a detailed list of configuration options, see the [wiki page](https://github.com/iQua/flsim/wiki/Configuration). If you have any questions, please feel free to contact Hao Wang (haowang@ece.utoronto.ca) ================================================ FILE: client.py ================================================ import logging import torch import torch.nn as nn import torch.optim as optim class Client(object): """Simulated federated learning client.""" def __init__(self, client_id): self.client_id = client_id def __repr__(self): return 'Client #{}: {} samples in labels: {}'.format( self.client_id, len(self.data), set([label for _, label in self.data])) # Set non-IID data configurations def set_bias(self, pref, bias): self.pref = pref self.bias = bias def set_shard(self, shard): self.shard = shard # Server interactions def download(self, argv): # Download from the server. 
try: return argv.copy() except: return argv def upload(self, argv): # Upload to the server try: return argv.copy() except: return argv # Federated learning phases def set_data(self, data, config): # Extract from config do_test = self.do_test = config.clients.do_test test_partition = self.test_partition = config.clients.test_partition # Download data self.data = self.download(data) # Extract trainset, testset (if applicable) data = self.data if do_test: # Partition for testset if applicable self.trainset = data[:int(len(data) * (1 - test_partition))] self.testset = data[int(len(data) * (1 - test_partition)):] else: self.trainset = data def configure(self, config): import fl_model # pylint: disable=import-error # Extract from config model_path = self.model_path = config.paths.model # Download from server config = self.download(config) # Extract machine learning task from config self.task = config.fl.task self.epochs = config.fl.epochs self.batch_size = config.fl.batch_size # Download most recent global model path = model_path + '/global' self.model = fl_model.Net() self.model.load_state_dict(torch.load(path)) self.model.eval() # Create optimizer self.optimizer = fl_model.get_optimizer(self.model) def run(self): # Perform federated learning task { "train": self.train() }[self.task] def get_report(self): # Report results to server. 
return self.upload(self.report) # Machine learning tasks def train(self): import fl_model # pylint: disable=import-error logging.info('Training on client #{}'.format(self.client_id)) # Perform model training trainloader = fl_model.get_trainloader(self.trainset, self.batch_size) fl_model.train(self.model, trainloader, self.optimizer, self.epochs) # Extract model weights and biases weights = fl_model.extract_weights(self.model) # Generate report for server self.report = Report(self) self.report.weights = weights # Perform model testing if applicable if self.do_test: testloader = fl_model.get_testloader(self.testset, 1000) self.report.accuracy = fl_model.test(self.model, testloader) def test(self): # Perform model testing raise NotImplementedError class Report(object): """Federated learning client report.""" def __init__(self, client): self.client_id = client.client_id self.num_samples = len(client.data) ================================================ FILE: config.py ================================================ from collections import namedtuple import json class Config(object): """Configuration module.""" def __init__(self, config): self.paths = "" # Load config file with open(config, 'r') as config: self.config = json.load(config) # Extract configuration self.extract() def extract(self): config = self.config # -- Clients -- fields = ['total', 'per_round', 'label_distribution', 'do_test', 'test_partition'] defaults = (0, 0, 'uniform', False, None) params = [config['clients'].get(field, defaults[i]) for i, field in enumerate(fields)] self.clients = namedtuple('clients', fields)(*params) assert self.clients.per_round <= self.clients.total # -- Data -- fields = ['loading', 'partition', 'IID', 'bias', 'shard'] defaults = ('static', 0, True, None, None) params = [config['data'].get(field, defaults[i]) for i, field in enumerate(fields)] self.data = namedtuple('data', fields)(*params) # Determine correct data loader assert self.data.IID ^ bool(self.data.bias) ^ 
bool(self.data.shard) if self.data.IID: self.loader = 'basic' elif self.data.bias: self.loader = 'bias' elif self.data.shard: self.loader = 'shard' # -- Federated learning -- fields = ['rounds', 'target_accuracy', 'task', 'epochs', 'batch_size'] defaults = (0, None, 'train', 0, 0) params = [config['federated_learning'].get(field, defaults[i]) for i, field in enumerate(fields)] self.fl = namedtuple('fl', fields)(*params) # -- Model -- self.model = config['model'] # -- Paths -- fields = ['data', 'model', 'reports'] defaults = ('./data', './models', None) params = [config['paths'].get(field, defaults[i]) for i, field in enumerate(fields)] # Set specific model path params[fields.index('model')] += '/' + self.model self.paths = namedtuple('paths', fields)(*params) # -- Server -- self.server = config['server'] ================================================ FILE: configs/CIFAR-10/cifar-10.json ================================================ { "clients": { "total": 100, "per_round": 10 }, "data": { "loading": "static", "partition": { "size": 600 }, "IID": true }, "federated_learning": { "rounds": 10000, "target_accuracy": 0.99, "task": "train", "epochs": 5, "batch_size": 10 }, "model": "CIFAR-10", "paths": { "data": "./data", "model": "./models" }, "server": "basic" } ================================================ FILE: configs/FashionMNIST/fashionmnist.json ================================================ { "clients": { "total": 100, "per_round": 10 }, "data": { "loading": "static", "partition": { "size": 600 }, "IID": true }, "federated_learning": { "rounds": 10000, "target_accuracy": 0.99, "task": "train", "epochs": 5, "batch_size": 10 }, "model": "FashionMNIST", "paths": { "data": "./data", "model": "./models" }, "server": "basic" } ================================================ FILE: configs/MNIST/mnist.json ================================================ { "clients": { "total": 100, "per_round": 10 }, "data": { "loading": "static", "partition": { "size": 600 
}, "IID": true }, "federated_learning": { "rounds": 10000, "target_accuracy": 0.99, "task": "train", "epochs": 5, "batch_size": 10 }, "model": "MNIST", "paths": { "data": "./data", "model": "./models" }, "server": "basic" } ================================================ FILE: configs/config.json.template ================================================ { "clients": { "total": 1000, "per_round": 20, "label_distribution": "uniform", "do_test": false, "test_partition": 0.2 }, "data": { "loading": "dynamic", "partition": { "size": 600, "range": [ 50, 200 ] }, "IID": false, "bias": { "primary": 0.8, "secondary": false } }, "federated_learning": { "rounds": 200, "target_accuracy": 0.95, "task": "train", "epochs": 5, "batch_size": 10 }, "model": "MNIST", "paths": { "data": "./data", "model": "./models", "reports": "reports.pkl" }, "server": "basic" } ================================================ FILE: environment.yml ================================================ name: fl-py37 channels: - pytorch - defaults dependencies: - astroid=2.2.5 - autopep8=1.4.4 - blas=1.0 - ca-certificates=2019.5.15 - certifi=2019.6.16 - cffi=1.12.3 - cycler=0.10.0 - freetype=2.9.1 - intel-openmp=2019.4 - isort=4.3.20 - joblib=0.13.2 - jpeg=9b - kiwisolver=1.1.0 - lazy-object-proxy=1.4.1 - libedit=3.1.20181209 - libffi=3.2.1 - libpng=1.6.37 - libtiff=4.0.10 - matplotlib=3.1.0 - mccabe=0.6.1 - mkl=2019.4 - mkl-service=2.0.2 - mkl_fft=1.0.12 - mkl_random=1.0.2 - ncurses=6.1 - ninja=1.9.0 - numpy=1.16.4 - numpy-base=1.16.4 - olefile=0.46 - openssl=1.1.1c - pandas=0.24.2 - pillow=6.0.0 - pip=19.1.1 - pycodestyle=2.5.0 - pycparser=2.19 - pylint=2.3.1 - pyparsing=2.4.0 - python=3.7.3 - python-dateutil=2.8.0 - pytorch=1.1.0 - pytz=2019.1 - readline=7.0 - rope=0.14.0 - scikit-learn=0.21.2 - scipy=1.2.1 - setuptools=41.0.1 - six=1.12.0 - sqlite=3.28.0 - tk=8.6.8 - torchvision=0.3.0 - tornado=6.0.2 - wheel=0.33.4 - wrapt=1.11.1 - xz=5.2.4 - zlib=1.2.11 - zstd=1.3.7 prefix: 
/Users/zakharykaplan/.miniconda3/envs/fl-py37 ================================================ FILE: load_data.py ================================================ import logging import random from torchvision import datasets, transforms import utils.dists as dists class Generator(object): """Generate federated learning training and testing data.""" # Abstract read function def read(self, path): # Read the dataset, set: trainset, testset, labels raise NotImplementedError # Group the data by label def group(self): # Create empty dict of labels grouped_data = {label: [] for label in self.labels} # pylint: disable=no-member # Populate grouped data dict for datapoint in self.trainset: # pylint: disable=all _, label = datapoint # Extract label label = self.labels[label] grouped_data[label].append( # pylint: disable=no-member datapoint) self.trainset = grouped_data # Overwrite trainset with grouped data # Run data generation def generate(self, path): self.read(path) self.trainset_size = len(self.trainset) # Extract trainset size self.group() return self.trainset class Loader(object): """Load and pass IID data partitions.""" def __init__(self, config, generator): # Get data from generator self.config = config self.trainset = generator.trainset self.testset = generator.testset self.labels = generator.labels self.trainset_size = generator.trainset_size # Store used data seperately self.used = {label: [] for label in self.labels} self.used['testset'] = [] def extract(self, label, n): if len(self.trainset[label]) > n: extracted = self.trainset[label][:n] # Extract data self.used[label].extend(extracted) # Move data to used del self.trainset[label][:n] # Remove from trainset return extracted else: logging.warning('Insufficient data in label: {}'.format(label)) logging.warning('Dumping used data for reuse') # Unmark data as used for label in self.labels: self.trainset[label].extend(self.used[label]) self.used[label] = [] # Extract replenished data return self.extract(label, n) 
def get_partition(self, partition_size): # Get an partition uniform across all labels # Use uniform distribution dist = dists.uniform(partition_size, len(self.labels)) partition = [] # Extract data according to distribution for i, label in enumerate(self.labels): partition.extend(self.extract(label, dist[i])) # Shuffle data partition random.shuffle(partition) return partition def get_testset(self): # Return the entire testset return self.testset class BiasLoader(Loader): """Load and pass 'preference bias' data partitions.""" def get_partition(self, partition_size, pref): # Get a non-uniform partition with a preference bias # Extract bias configuration from config bias = self.config.data.bias['primary'] secondary = self.config.data.bias['secondary'] # Calculate sizes of majorty and minority portions majority = int(partition_size * bias) minority = partition_size - majority # Calculate number of minor labels len_minor_labels = len(self.labels) - 1 if secondary: # Distribute to random secondary label dist = [0] * len_minor_labels dist[random.randint(0, len_minor_labels - 1)] = minority else: # Distribute among all minority labels dist = dists.uniform(minority, len_minor_labels) # Add majority data to distribution dist.insert(self.labels.index(pref), majority) partition = [] # Extract data according to distribution for i, label in enumerate(self.labels): partition.extend(self.extract(label, dist[i])) # Shuffle data partition random.shuffle(partition) return partition class ShardLoader(Loader): """Load and pass 'shard' data partitions.""" def create_shards(self): # Extract shard configuration from config per_client = self.config.data.shard['per_client'] # Determine correct total shards, shard size total = self.config.clients.total * per_client shard_size = int(self.trainset_size / total) data = [] # Flatten data for _, items in self.trainset.items(): data.extend(items) shards = [data[(i * shard_size):((i + 1) * shard_size)] for i in range(total)] random.shuffle(shards) 
self.shards = shards self.used = [] logging.info('Created {} shards of size {}'.format( len(shards), shard_size)) def extract_shard(self): shard = self.shards[0] self.used.append(shard) del self.shards[0] return shard def get_partition(self): # Get a partition shard # Extract number of shards per client per_client = self.config.data.shard['per_client'] # Create data partition partition = [] for i in range(per_client): partition.extend(self.extract_shard()) # Shuffle data partition random.shuffle(partition) return partition ================================================ FILE: models/CIFAR-10/fl_model.py ================================================ import load_data import logging import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms # Training settings lr = 0.01 momentum = 0.9 log_interval = 10 # Cuda settings use_cuda = torch.cuda.is_available() device = torch.device( # pylint: disable=no-member 'cuda' if use_cuda else 'cpu') class Generator(load_data.Generator): """Generator for CIFAR-10 dataset.""" # Extract CIFAR-10 data using torchvision datasets def read(self, path): self.trainset = datasets.CIFAR10( path, train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])) self.testset = datasets.CIFAR10( path, train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ])) self.labels = list(self.trainset.classes) class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = 
F.relu(self.fc2(x)) x = self.fc3(x) return x def get_optimizer(model): return optim.SGD(model.parameters(), lr=lr, momentum=momentum) def get_trainloader(trainset, batch_size): return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True) def get_testloader(testset, batch_size): return torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True) def extract_weights(model): weights = [] for name, weight in model.to(torch.device('cpu')).named_parameters(): # pylint: disable=no-member if weight.requires_grad: weights.append((name, weight.data)) return weights def load_weights(model, weights): updated_state_dict = {} for name, weight in weights: updated_state_dict[name] = weight model.load_state_dict(updated_state_dict, strict=False) def train(model, trainloader, optimizer, epochs): model.to(device) model.train() criterion = nn.CrossEntropyLoss() for epoch in range(1, epochs + 1): for batch_id, data in enumerate(trainloader): # get the inputs; data is a list of [inputs, labels] inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() if batch_id % log_interval == 0: logging.debug('Epoch: [{}/{}]\tLoss: {:.6f}'.format( epoch, epochs, loss.item())) def test(model, testloader): model.to(device) model.eval() correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max( # pylint: disable=no-member outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total logging.debug('Accuracy: {:.2f}%'.format(100 * accuracy)) return accuracy ================================================ FILE: models/FashionMNIST/fl_model.py ================================================ 
import load_data import logging import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms # Training settings lr = 0.01 momentum = 0.5 log_interval = 10 # Cuda settings use_cuda = torch.cuda.is_available() device = torch.device ( # pylint: disable=no-member 'cuda' if use_cuda else 'cpu') class Generator(load_data.Generator): """Generator for FashionMNIST dataset.""" # Extract FashionMNIST data using torchvision datasets def read(self, path): self.trainset = datasets.FashionMNIST( path, train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize( (0.1307,), (0.3081,)) ])) self.testset = datasets.FashionMNIST( path, train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize( (0.1307,), (0.3081,)) ])) self.labels = list(self.trainset.classes) class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)) self.layer2 = nn.Sequential( nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)) self.fc = nn.Linear(7 * 7 * 32, 10) def forward(self, x): out = self.layer1(x) out = self.layer2(out) out = out.reshape(out.size(0), -1) out = self.fc(out) return out def get_optimizer(model): return optim.Adam(model.parameters(), lr=lr) def get_trainloader(trainset, batch_size): return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True) def get_testloader(testset, batch_size): return torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True) def extract_weights(model): weights = [] for name, weight in model.to(torch.device('cpu')).named_parameters(): # pylint: disable=no-member if weight.requires_grad: weights.append((name, weight.data)) return weights def 
load_weights(model, weights): updated_state_dict = {} for name, weight in weights: updated_state_dict[name] = weight model.load_state_dict(updated_state_dict, strict=False) def train(model, trainloader, optimizer, epochs): model.to(device) model.train() criterion = nn.CrossEntropyLoss() for epoch in range(1, epochs + 1): for batch_id, data in enumerate(trainloader): inputs, labels = data inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() if batch_id % log_interval == 0: logging.debug('Epoch: [{}/{}]\tLoss: {:.6f}'.format( epoch, epochs, loss.item())) def test(model, testloader): model.to(device) model.eval() with torch.no_grad(): correct = 0 total = 0 for data in testloader: images, labels = data images, labels = images.to(device), labels.to(device) outputs = model(images) predicted = torch.argmax( # pylint: disable=no-member outputs, dim=1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = correct / total logging.debug('Accuracy: {:.2f}%'.format(100 * accuracy)) return accuracy ================================================ FILE: models/MNIST/fl_model.py ================================================ import load_data import logging import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms # Training settings lr = 0.01 momentum = 0.5 log_interval = 10 # Cuda settings use_cuda = torch.cuda.is_available() device = torch.device( # pylint: disable=no-member 'cuda' if use_cuda else 'cpu') class Generator(load_data.Generator): """Generator for MNIST dataset.""" # Extract MNIST data using torchvision datasets def read(self, path): self.trainset = datasets.MNIST( path, train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize( (0.1307,), (0.3081,)) ])) self.testset = datasets.MNIST( path, train=False, 
transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize( (0.1307,), (0.3081,)) ])) self.labels = list(self.trainset.classes) class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5, 1) self.conv2 = nn.Conv2d(20, 50, 5, 1) self.fc1 = nn.Linear(4 * 4 * 50, 500) self.fc2 = nn.Linear(500, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = F.max_pool2d(x, 2, 2) x = F.relu(self.conv2(x)) x = F.max_pool2d(x, 2, 2) x = x.view(-1, 4 * 4 * 50) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) def get_optimizer(model): return optim.SGD(model.parameters(), lr=lr, momentum=momentum) def get_trainloader(trainset, batch_size): return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True) def get_testloader(testset, batch_size): return torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True) def extract_weights(model): weights = [] for name, weight in model.to(torch.device('cpu')).named_parameters(): # pylint: disable=no-member if weight.requires_grad: weights.append((name, weight.data)) return weights def load_weights(model, weights): updated_state_dict = {} for name, weight in weights: updated_state_dict[name] = weight model.load_state_dict(updated_state_dict, strict=False) def train(model, trainloader, optimizer, epochs): model.to(device) model.train() for epoch in range(1, epochs + 1): for batch_id, (image, label) in enumerate(trainloader): image, label = image.to(device), label.to(device) optimizer.zero_grad() output = model(image) loss = F.nll_loss(output, label) loss.backward() optimizer.step() if batch_id % log_interval == 0: logging.debug('Epoch: [{}/{}]\tLoss: {:.6f}'.format( epoch, epochs, loss.item())) def test(model, testloader): model.to(device) model.eval() test_loss = 0 correct = 0 total = len(testloader.dataset) with torch.no_grad(): for image, label in testloader: image, label = image.to(device), label.to(device) output = model(image) # 
sum up batch loss test_loss += F.nll_loss(output, label, reduction='sum').item() # get the index of the max log-probability pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(label.view_as(pred)).sum().item() accuracy = correct / total logging.debug('Accuracy: {:.2f}%'.format(100 * accuracy)) return accuracy ================================================ FILE: models/fl_model.py ================================================ # pylint: skip-file import load_data import logging import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms # Training settings lr = 0.01 # CHECKME momentum = 0.5 # CHECKME log_interval = 10 # CHECKME class Generator(load_data.Generator): # CHECKME """Generator for UNNAMED dataset.""" # Extract UNNAMED data using torchvision datasets def read(self, path): self.trainset = datasets.UNNAMED( path, train=True, download=True, transform=transforms.Compose([ """ Add transforms here... """ ])) self.testset = datasets.UNNAMED( path, train=False, transform=transforms.Compose([ """ Add transforms here... 
""" ])) self.labels = list(self.trainset.classes) class Net(nn.Module): # CHECKME def __init__(self): super(Net, self).__init__() raise NotImplementedError def forward(self, x): raise NotImplementedError def get_optimizer(model): # CHECKME return optim.SGD(model.parameters(), lr=lr, momentum=momentum) def get_trainloader(trainset, batch_size): # CHECKME return torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True) def get_testloader(testset, batch_size): # CHECKME return torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True) def extract_weights(model): # CHECKME weights = [] for UNNAMED, weight in model.UNNAMEDd_parameters(): if weight.requires_grad: weights.append((UNNAMED, weight.data)) return weights def load_weights(model, weights): # CHECKME updated_weights_dict = {} for UNNAMED, weight in weights: updated_weights_dictUNNAMED = weight model.load_state_dict(updated_weights_dict, strict=False) def train(model, trainloader, optimizer, epochs): # CHECKME """ Set up for training here... """ for epoch in range(1, epochs + 1): for batch_id, (image, label) in enumerate(trainloader): """ Train model here... """ if batch_id % log_interval == 0: logging.debug('Epoch: [{}/{}]\tLoss: {:.6f}'.format( epoch, epochs, loss.item())) def test(model, testloader): # CHECKME """ Set up for testing here... """ correct = 0 total = 0 with torch.no_grad(): for image, label in testloader: """ Test model here... 
""" accuracy = correct / total logging.debug('Accuracy: {:.2f}%'.format(100 * accuracy)) return accuracy ================================================ FILE: run.py ================================================ import argparse import client import config import logging import os import server # Set up parser parser = argparse.ArgumentParser() parser.add_argument('-c', '--config', type=str, default='./config.json', help='Federated learning configuration file.') parser.add_argument('-l', '--log', type=str, default='INFO', help='Log messages level.') args = parser.parse_args() # Set logging logging.basicConfig( format='[%(levelname)s][%(asctime)s]: %(message)s', level=getattr(logging, args.log.upper()), datefmt='%H:%M:%S') def main(): """Run a federated learning simulation.""" # Read configuration file fl_config = config.Config(args.config) # Initialize server fl_server = { "basic": server.Server(fl_config), "accavg": server.AccAvgServer(fl_config), "directed": server.DirectedServer(fl_config), "kcenter": server.KCenterServer(fl_config), "kmeans": server.KMeansServer(fl_config), "magavg": server.MagAvgServer(fl_config), # "dqn": server.DQNServer(fl_config), # DQN server disabled # "dqntrain": server.DQNTrainServer(fl_config), # DQN server disabled }[fl_config.server] fl_server.boot() # Run federated learning fl_server.run() # Delete global model os.remove(fl_config.paths.model + '/global') if __name__ == "__main__": main() ================================================ FILE: scripts/analyze_logs.py ================================================ import argparse from datetime import datetime import re # Set up parser parser = argparse.ArgumentParser() parser.add_argument('--log', type=str, help='Simmulation log file.') args = parser.parse_args() # Read log with open(args.log, 'r') as f: log = [x for x in f.readlines() if x != '\n'] # Extract time def extract_time(line): return datetime.strptime([x for x in re.split('\[|\]', line) if x][1], '%H:%M:%S') # Extract 
lines training = [] for line in log: if 'Round 1/' in line: training.append(line) training.append(log[-1]) # Calculate duration training_duration = (extract_time(training[1]) - extract_time(training[0])).seconds print('{}: training time: {} s'.format(args.log, training_duration)) ================================================ FILE: scripts/pca.py ================================================ import argparse import client import config import logging import os import pickle from sklearn.decomposition import PCA import server # Set logging logging.basicConfig( format='[%(levelname)s][%(asctime)s]: %(message)s', level=logging.INFO, datefmt='%H:%M:%S') # Set up parser parser = argparse.ArgumentParser() parser.add_argument('-c', '--config', type=str, default='./template.json', help='Configuration file for server.') parser.add_argument('-o', '--output', type=str, default='./output.pkl', help='Output pickle file') args = parser.parse_args() def main(): """Extract PCA vectors from FL clients.""" # Read configuration file fl_config = config.Config(args.config) # Initialize server fl_server = server.KMeansServer(fl_config) fl_server.boot() # Run client profiling fl_server.profile_clients() # Extract clients, reports, weights clients = [client for client in group for group in [ fl_server.clients[profile] for profile in fl_server.clients.keys()]] reports = [client.get_report() for client in clients] weights = [report.weights for report in reports] # Flatten weights def flatten_weights(weights): weight_vecs = [] for _, weight in weights: weight_vecs.extend(weight.flatten()) return weight_vecs logging.info('Flattening weights...') weight_vecs = [flatten_weights(weight) for weight in weights] # Perform PCA on weight vectors logging.info('Assembling output...') output = [(clients[i].client_id, clients[i].pref, weight) for i, weight in enumerate(weight_vecs)] logging.into('Writing output to binary...') with open(args.output, 'wb') as f: pickle.dump(output, f) 
logging.info('Done!') if __name__ == "__main__": main() ================================================ FILE: server/__init__.py ================================================ from .server import Server from .accavg import AccAvgServer from .directed import DirectedServer from .kcenter import KCenterServer from .kmeans import KMeansServer from .magavg import MagAvgServer # from .dqn import DQNServer # DQN server disbled # from .dqn import DQNTrainServer # DQN server disabled ================================================ FILE: server/accavg.py ================================================ from server import Server import numpy as np import torch class AccAvgServer(Server): """Federated learning server that performs accuracy weighted federated averaging.""" # Federated learning phases def aggregation(self, reports): return self.accuracy_fed_avg(reports) # Report aggregation def accuracy_fed_avg(self, reports): import fl_model # pylint: disable=import-error # Extract updates from reports updates = self.extract_client_updates(reports) # Extract client accuracies accuracies = np.array([report.accuracy for report in reports]) # Determine weighting based on accuracies factor = 8 # Exponentiation factor w = accuracies**factor / sum(accuracies**factor) # Perform weighted averaging avg_update = [torch.zeros(x.size()) # pylint: disable=no-member for _, x in updates[0]] for i, update in enumerate(updates): for j, (_, delta) in enumerate(update): # Use weighted average by magnetude of updates avg_update[j] += delta * w[i] # Extract baseline model weights baseline_weights = fl_model.extract_weights(self.model) # Load updated weights into model updated_weights = [] for i, (name, weight) in enumerate(baseline_weights): updated_weights.append((name, weight + avg_update[i])) return updated_weights # Server operations def set_client_data(self, client): super().set_client_data(client) # Send each client a testing partition client.testset = 
client.download(self.loader.get_testset()) client.do_test = True # Tell client to perform testing ================================================ FILE: server/directed.py ================================================ import logging from server import Server import numpy as np from threading import Thread class DirectedServer(Server): """Federated learning server that uses profiles to direct during selection.""" # Run federated learning def run(self): # Perform profiling on all clients self.profiling() # Continue federated learning super().run() # Federated learning phases def selection(self): import fl_model # pylint: disable=import-error clients = self.clients clients_per_round = self.config.clients.per_round profiles = self.profiles w_previous = self.w_previous # Extract directors from profiles directors = [d for _, d in profiles] # Extract most recent model weights w_current = self.flatten_weights(fl_model.extract_weights(self.model)) model_direction = w_current - w_previous # Normalize model direction model_direction = model_direction / \ np.sqrt(np.dot(model_direction, model_direction)) # Update previous model weights self.w_previous = w_current # Generate client director scores (closer direction is better) scores = [np.dot(director, model_direction) for director in directors] # Apply punishment for repeatedly selected clients p = self.punishment scores = [x * (0.9)**p[i] for i, x in enumerate(scores)] # Select clients with highest scores sample_clients_index = [] for _ in range(clients_per_round): top_score_index = scores.index(max(scores)) sample_clients_index.append(top_score_index) # Overwrite to avoid reselection scores[top_score_index] = min(scores) - 1 # Extract selected sample clients sample_clients = [clients[i] for i in sample_clients_index] # Update punishment factors self.punishment = [ p[i] + 1 if i in sample_clients_index else 0 for i in range(len(clients))] return sample_clients def profiling(self): import fl_model # pylint: 
disable=import-error # Use all clients for profiling clients = self.clients # Configure clients for training self.configuration(clients) # Train on clients to generate profile weights threads = [Thread(target=client.train) for client in self.clients] [t.start() for t in threads] [t.join() for t in threads] # Recieve client reports reports = self.reporting(clients) # Extract weights from reports weights = [report.weights for report in reports] weights = [self.flatten_weights(weight) for weight in weights] # Extract initial model weights w0 = self.flatten_weights(fl_model.extract_weights(self.model)) # Save as initial previous model weights self.w_previous = w0.copy() # Update initial model using results of profiling # Perform weight aggregation logging.info('Aggregating updates') updated_weights = self.aggregation(reports) # Load updated weights fl_model.load_weights(self.model, updated_weights) # Calculate direction vectors (directors) directors = [(w - w0) for w in weights] # Normalize directors to unit length directors = [d / np.sqrt(np.dot(d, d)) for d in directors] # Initialize punishment factors self.punishment = [0 for _ in range(len(clients))] # Use directors for client profiles self.profiles = [(client, directors[i]) for i, client in enumerate(clients)] return self.profiles ================================================ FILE: server/kcenter.py ================================================ import logging import random from server import Server from threading import Thread from utils.kcenter import GreedyKCenter # pylint: disable=no-name-in-module class KCenterServer(Server): """Federated learning server that performs KCenter profiling during selection.""" # Run federated learning def run(self): # Perform profiling on all clients self.profiling() # Designate space for storing used client profiles self.used_profiles = [] # Continue federated learning super().run() # Federated learning phases def selection(self): # Select devices to participate in round 
profiles = self.profiles k = self.config.clients.per_round if len(profiles) < k: # Reuse clients when needed logging.warning('Not enough unused clients') logging.warning('Dumping clients for reuse') self.profiles.extend(self.used_profiles) self.used_profiles = [] # Shuffle profiles random.shuffle(profiles) # Cluster clients based on profile weights weights = [weight for _, weight in profiles] KCenter = GreedyKCenter() KCenter.fit(weights, k) logging.info('KCenter: {} clients, {} centers'.format( len(profiles), k)) # Select clients marked as cluster centers centers_index = KCenter.centers_index sample_profiles = [profiles[i] for i in centers_index] sample_clients = [client for client, _ in sample_profiles] # Mark sample profiles as used self.used_profiles.extend(sample_profiles) for i in sorted(centers_index, reverse=True): del self.profiles[i] return sample_clients def profiling(self): # Use all clients for profiling clients = self.clients # Configure clients for training self.configuration(clients) # Train on clients to generate profile weights threads = [Thread(target=client.train) for client in self.clients] [t.start() for t in threads] [t.join() for t in threads] # Recieve client reports reports = self.reporting(clients) # Extract weights from reports weights = [report.weights for report in reports] weights = [self.flatten_weights(weight) for weight in weights] # Use weights for client profiles self.profiles = [(client, weights[i]) for i, client in enumerate(clients)] return self.profiles ================================================ FILE: server/kmeans.py ================================================ import logging import random from server import Server from sklearn.cluster import KMeans from threading import Thread import utils.dists as dists # pylint: disable=no-name-in-module class KMeansServer(Server): """Federated learning server that performs KMeans profiling during selection.""" # Run federated learning def run(self): # Perform profiling on all 
clients self.profile_clients() # Continue federated learning super().run() # Federated learning phases def selection(self): # Select devices to participate in round clients_per_round = self.config.clients.per_round cluster_labels = self.clients.keys() # Generate uniform distribution for selecting clients dist = dists.uniform(clients_per_round, len(cluster_labels)) # Select clients from KMeans clusters sample_clients = [] for i, cluster in enumerate(cluster_labels): # Select clients according to distribution if len(self.clients[cluster]) >= dist[i]: k = dist[i] else: # If not enough clients in cluster, use all avaliable k = len(self.clients[cluster]) sample_clients.extend(random.sample( self.clients[cluster], k)) # Shuffle selected sample clients random.shuffle(sample_clients) return sample_clients # Output model weights def model_weights(self, clients): # Configure clients to train on local data self.configuration(clients) # Train on local data for profiling purposes threads = [Thread(target=client.train) for client in self.clients] [t.start() for t in threads] [t.join() for t in threads] # Recieve client reports reports = self.reporting(clients) # Extract weights from reports weights = [report.weights for report in reports] return [self.flatten_weights(weight) for weight in weights] def prefs_to_weights(self): prefs = [client.pref for client in self.clients] return list(zip(prefs, self.model_weights(self.clients))) def profiling(self, clients): # Perform clustering weight_vecs = self.model_weights(clients) # Use the number of clusters as there are labels n_clusters = len(self.loader.labels) logging.info('KMeans: {} clients, {} clusters'.format( len(weight_vecs), n_clusters)) kmeans = KMeans( # Use KMeans clustering algorithm n_clusters=n_clusters).fit(weight_vecs) return kmeans.labels_ # Server operations def profile_clients(self): # Perform profiling on all clients kmeans = self.profiling(self.clients) # Group clients by profile grouped_clients = {cluster: [] for 
cluster in range(len(self.loader.labels))} for i, client in enumerate(self.clients): grouped_clients[kmeans[i]].append(client) self.clients = grouped_clients # Replace linear client list with dict def add_client(self): # Add a new client to the server raise NotImplementedError ================================================ FILE: server/magavg.py ================================================ from server import Server import numpy as np import torch class MagAvgServer(Server): """Federated learning server that performs magnetude weighted federated averaging.""" # Federated learning phases def aggregation(self, reports): return self.magnetude_fed_avg(reports) # Report aggregation def magnetude_fed_avg(self, reports): import fl_model # pylint: disable=import-error # Extract updates from reports updates = self.extract_client_updates(reports) # Extract update magnetudes magnetudes = [] for update in updates: magnetude = 0 for _, weight in update: magnetude += weight.norm() ** 2 magnetudes.append(np.sqrt(magnetude)) # Perform weighted averaging avg_update = [torch.zeros(x.size()) # pylint: disable=no-member for _, x in updates[0]] for i, update in enumerate(updates): for j, (_, delta) in enumerate(update): # Use weighted average by magnetude of updates avg_update[j] += delta * (magnetudes[i] / sum(magnetudes)) # Extract baseline model weights baseline_weights = fl_model.extract_weights(self.model) # Load updated weights into model updated_weights = [] for i, (name, weight) in enumerate(baseline_weights): updated_weights.append((name, weight + avg_update[i])) return updated_weights ================================================ FILE: server/server.py ================================================ import client import load_data import logging import numpy as np import pickle import random import sys from threading import Thread import torch import utils.dists as dists # pylint: disable=no-name-in-module class Server(object): """Basic federated learning server.""" 
def __init__(self, config): self.config = config # Set up server def boot(self): logging.info('Booting {} server...'.format(self.config.server)) model_path = self.config.paths.model total_clients = self.config.clients.total # Add fl_model to import path sys.path.append(model_path) # Set up simulated server self.load_data() self.load_model() self.make_clients(total_clients) def load_data(self): import fl_model # pylint: disable=import-error # Extract config for loaders config = self.config # Set up data generator generator = fl_model.Generator() # Generate data data_path = self.config.paths.data data = generator.generate(data_path) labels = generator.labels logging.info('Dataset size: {}'.format( sum([len(x) for x in [data[label] for label in labels]]))) logging.debug('Labels ({}): {}'.format( len(labels), labels)) # Set up data loader self.loader = { 'basic': load_data.Loader(config, generator), 'bias': load_data.BiasLoader(config, generator), 'shard': load_data.ShardLoader(config, generator) }[self.config.loader] logging.info('Loader: {}, IID: {}'.format( self.config.loader, self.config.data.IID)) def load_model(self): import fl_model # pylint: disable=import-error model_path = self.config.paths.model model_type = self.config.model logging.info('Model: {}'.format(model_type)) # Set up global model self.model = fl_model.Net() self.save_model(self.model, model_path) # Extract flattened weights (if applicable) if self.config.paths.reports: self.saved_reports = {} self.save_reports(0, []) # Save initial model def make_clients(self, num_clients): IID = self.config.data.IID labels = self.loader.labels loader = self.config.loader loading = self.config.data.loading if not IID: # Create distribution for label preferences if non-IID dist = { "uniform": dists.uniform(num_clients, len(labels)), "normal": dists.normal(num_clients, len(labels)) }[self.config.clients.label_distribution] random.shuffle(dist) # Shuffle distribution # Make simulated clients clients = [] for 
client_id in range(num_clients): # Create new client new_client = client.Client(client_id) if not IID: # Configure clients for non-IID data if self.config.data.bias: # Bias data partitions bias = self.config.data.bias # Choose weighted random preference pref = random.choices(labels, dist)[0] # Assign preference, bias config new_client.set_bias(pref, bias) elif self.config.data.shard: # Shard data partitions shard = self.config.data.shard # Assign shard config new_client.set_shard(shard) clients.append(new_client) logging.info('Total clients: {}'.format(len(clients))) if loader == 'bias': logging.info('Label distribution: {}'.format( [[client.pref for client in clients].count(label) for label in labels])) if loading == 'static': if loader == 'shard': # Create data shards self.loader.create_shards() # Send data partition to all clients [self.set_client_data(client) for client in clients] self.clients = clients # Run federated learning def run(self): rounds = self.config.fl.rounds target_accuracy = self.config.fl.target_accuracy reports_path = self.config.paths.reports if target_accuracy: logging.info('Training: {} rounds or {}% accuracy\n'.format( rounds, 100 * target_accuracy)) else: logging.info('Training: {} rounds\n'.format(rounds)) # Perform rounds of federated learning for round in range(1, rounds + 1): logging.info('**** Round {}/{} ****'.format(round, rounds)) # Run the federated learning round accuracy = self.round() # Break loop when target accuracy is met if target_accuracy and (accuracy >= target_accuracy): logging.info('Target accuracy reached.') break if reports_path: with open(reports_path, 'wb') as f: pickle.dump(self.saved_reports, f) logging.info('Saved reports: {}'.format(reports_path)) def round(self): import fl_model # pylint: disable=import-error # Select clients to participate in the round sample_clients = self.selection() # Configure sample clients self.configuration(sample_clients) # Run clients using multithreading for better parallelism 
threads = [Thread(target=client.run) for client in sample_clients] [t.start() for t in threads] [t.join() for t in threads] # Recieve client updates reports = self.reporting(sample_clients) # Perform weight aggregation logging.info('Aggregating updates') updated_weights = self.aggregation(reports) # Load updated weights fl_model.load_weights(self.model, updated_weights) # Extract flattened weights (if applicable) if self.config.paths.reports: self.save_reports(round, reports) # Save updated global model self.save_model(self.model, self.config.paths.model) # Test global model accuracy if self.config.clients.do_test: # Get average accuracy from client reports accuracy = self.accuracy_averaging(reports) else: # Test updated model on server testset = self.loader.get_testset() batch_size = self.config.fl.batch_size testloader = fl_model.get_testloader(testset, batch_size) accuracy = fl_model.test(self.model, testloader) logging.info('Average accuracy: {:.2f}%\n'.format(100 * accuracy)) return accuracy # Federated learning phases def selection(self): # Select devices to participate in round clients_per_round = self.config.clients.per_round # Select clients randomly sample_clients = [client for client in random.sample( self.clients, clients_per_round)] return sample_clients def configuration(self, sample_clients): loader_type = self.config.loader loading = self.config.data.loading if loading == 'dynamic': # Create shards if applicable if loader_type == 'shard': self.loader.create_shards() # Configure selected clients for federated learning task for client in sample_clients: if loading == 'dynamic': self.set_client_data(client) # Send data partition to client # Extract config for client config = self.config # Continue configuraion on client client.configure(config) def reporting(self, sample_clients): # Recieve reports from sample clients reports = [client.get_report() for client in sample_clients] logging.info('Reports recieved: {}'.format(len(reports))) assert 
len(reports) == len(sample_clients) return reports def aggregation(self, reports): return self.federated_averaging(reports) # Report aggregation def extract_client_updates(self, reports): import fl_model # pylint: disable=import-error # Extract baseline model weights baseline_weights = fl_model.extract_weights(self.model) # Extract weights from reports weights = [report.weights for report in reports] # Calculate updates from weights updates = [] for weight in weights: update = [] for i, (name, weight) in enumerate(weight): bl_name, baseline = baseline_weights[i] # Ensure correct weight is being updated assert name == bl_name # Calculate update delta = weight - baseline update.append((name, delta)) updates.append(update) return updates def federated_averaging(self, reports): import fl_model # pylint: disable=import-error # Extract updates from reports updates = self.extract_client_updates(reports) # Extract total number of samples total_samples = sum([report.num_samples for report in reports]) # Perform weighted averaging avg_update = [torch.zeros(x.size()) # pylint: disable=no-member for _, x in updates[0]] for i, update in enumerate(updates): num_samples = reports[i].num_samples for j, (_, delta) in enumerate(update): # Use weighted average by number of samples avg_update[j] += delta * (num_samples / total_samples) # Extract baseline model weights baseline_weights = fl_model.extract_weights(self.model) # Load updated weights into model updated_weights = [] for i, (name, weight) in enumerate(baseline_weights): updated_weights.append((name, weight + avg_update[i])) return updated_weights def accuracy_averaging(self, reports): # Get total number of samples total_samples = sum([report.num_samples for report in reports]) # Perform weighted averaging accuracy = 0 for report in reports: accuracy += report.accuracy * (report.num_samples / total_samples) return accuracy # Server operations @staticmethod def flatten_weights(weights): # Flatten weights into vectors 
weight_vecs = [] for _, weight in weights: weight_vecs.extend(weight.flatten().tolist()) return np.array(weight_vecs) def set_client_data(self, client): loader = self.config.loader # Get data partition size if loader != 'shard': if self.config.data.partition.get('size'): partition_size = self.config.data.partition.get('size') elif self.config.data.partition.get('range'): start, stop = self.config.data.partition.get('range') partition_size = random.randint(start, stop) # Extract data partition for client if loader == 'basic': data = self.loader.get_partition(partition_size) elif loader == 'bias': data = self.loader.get_partition(partition_size, client.pref) elif loader == 'shard': data = self.loader.get_partition() else: logging.critical('Unknown data loader type') # Send data to client client.set_data(data, self.config) def save_model(self, model, path): path += '/global' torch.save(model.state_dict(), path) logging.info('Saved global model: {}'.format(path)) def save_reports(self, round, reports): import fl_model # pylint: disable=import-error if reports: self.saved_reports['round{}'.format(round)] = [(report.client_id, self.flatten_weights( report.weights)) for report in reports] # Extract global weights self.saved_reports['w{}'.format(round)] = self.flatten_weights( fl_model.extract_weights(self.model)) ================================================ FILE: utils/dists.py ================================================ import numpy as np import random def uniform(N, k): """Uniform distribution of 'N' items into 'k' groups.""" dist = [] avg = N / k # Make distribution for i in range(k): dist.append(int((i + 1) * avg) - int(i * avg)) # Return shuffled distribution random.shuffle(dist) return dist def normal(N, k): """Normal distribution of 'N' items into 'k' groups.""" dist = [] # Make distribution for i in range(k): x = i - (k - 1) / 2 dist.append(int(N * (np.exp(-x) / (np.exp(-x) + 1)**2))) # Add remainders remainder = N - sum(dist) dist = list(np.add(dist, 
uniform(remainder, k))) # Return non-shuffled distribution return dist ================================================ FILE: utils/kcenter.py ================================================ import numpy as np class GreedyKCenter(object): def fit(self, points, k): centers = [] centers_index = [] # Initialize distances distances = [np.inf for u in points] # Initialize cluster labels labels = [np.inf for u in points] for cluster in range(k): # Let u be the point of P such that d[u] is maximum u_index = distances.index(max(distances)) u = points[u_index] # u is the next cluster center centers.append(u) centers_index.append(u_index) # Update distance to nearest center for i, v in enumerate(points): distance_to_u = self.distance(u, v) # Calculate from v to u if distance_to_u < distances[i]: distances[i] = distance_to_u labels[i] = cluster # Update the bottleneck distance max_distance = max(distances) # Return centers, labels, max delta, labels self.centers = centers self.centers_index = centers_index self.max_distance = max_distance self.labels = labels @staticmethod def distance(u, v): displacement = u - v return np.sqrt(displacement.dot(displacement))