Repository: google-research/rigl
Branch: master
Commit: d39fc7d46505
Files: 141
Total size: 902.2 KB
Directory structure:
gitextract_rhnta46k/
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── rigl/
│ ├── __init__.py
│ ├── cifar_resnet/
│ │ ├── data_helper.py
│ │ ├── data_helper_test.py
│ │ ├── resnet_model.py
│ │ └── resnet_train_eval.py
│ ├── experimental/
│ │ └── jax/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── analysis/
│ │ │ └── plot_summary_json.ipynb
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── cifar10.py
│ │ │ ├── cifar10_test.py
│ │ │ ├── dataset_base.py
│ │ │ ├── dataset_base_test.py
│ │ │ ├── dataset_factory.py
│ │ │ ├── dataset_factory_test.py
│ │ │ ├── mnist.py
│ │ │ └── mnist_test.py
│ │ ├── fixed_param.py
│ │ ├── fixed_param_test.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── cifar10_cnn.py
│ │ │ ├── cifar10_cnn_test.py
│ │ │ ├── mnist_cnn.py
│ │ │ ├── mnist_cnn_test.py
│ │ │ ├── mnist_fc.py
│ │ │ ├── mnist_fc_test.py
│ │ │ ├── model_factory.py
│ │ │ └── model_factory_test.py
│ │ ├── prune.py
│ │ ├── prune_test.py
│ │ ├── pruning/
│ │ │ ├── __init__.py
│ │ │ ├── init.py
│ │ │ ├── init_test.py
│ │ │ ├── mask_factory.py
│ │ │ ├── mask_factory_test.py
│ │ │ ├── masked.py
│ │ │ ├── masked_test.py
│ │ │ ├── pruning.py
│ │ │ ├── pruning_test.py
│ │ │ ├── symmetry.py
│ │ │ └── symmetry_test.py
│ │ ├── random_mask.py
│ │ ├── random_mask_test.py
│ │ ├── requirements.txt
│ │ ├── run.sh
│ │ ├── shuffled_mask.py
│ │ ├── shuffled_mask_test.py
│ │ ├── train.py
│ │ ├── train_test.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── training.py
│ │ │ └── training_test.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── imagenet_resnet/
│ │ ├── colabs/
│ │ │ ├── MobileNet_Counting.ipynb
│ │ │ └── Resnet_50_Param_Flops_Counting.ipynb
│ │ ├── imagenet_train_eval.py
│ │ ├── mobilenetv1_model.py
│ │ ├── mobilenetv2_model.py
│ │ ├── pruning_layers.py
│ │ ├── resnet_model.py
│ │ ├── train_test.py
│ │ ├── utils.py
│ │ └── vgg.py
│ ├── mnist/
│ │ ├── mnist_train_eval.py
│ │ └── visualize_mask_records.py
│ ├── requirements.txt
│ ├── rigl_tf2/
│ │ ├── README.md
│ │ ├── colabs/
│ │ │ └── MnistProp.ipynb
│ │ ├── configs/
│ │ │ ├── dense.gin
│ │ │ ├── grasp.gin
│ │ │ ├── hessian.gin
│ │ │ ├── interpolate.gin
│ │ │ ├── lottery.gin
│ │ │ ├── prune.gin
│ │ │ ├── rigl.gin
│ │ │ ├── scratch.gin
│ │ │ ├── set.gin
│ │ │ ├── small_dense.gin
│ │ │ └── snip.gin
│ │ ├── init_utils.py
│ │ ├── interpolate.py
│ │ ├── mask_updaters.py
│ │ ├── metainit.py
│ │ ├── mlp_configs/
│ │ │ ├── dense.gin
│ │ │ ├── lottery.gin
│ │ │ ├── prune.gin
│ │ │ ├── rigl.gin
│ │ │ ├── scratch.gin
│ │ │ ├── set.gin
│ │ │ └── small_dense.gin
│ │ ├── networks.py
│ │ ├── train.py
│ │ └── utils.py
│ ├── rl/
│ │ ├── README.md
│ │ ├── dqn_agents.py
│ │ ├── requirements.txt
│ │ ├── run.sh
│ │ ├── run_experiment.py
│ │ ├── sparse_utils.py
│ │ ├── sparsetrain_configs/
│ │ │ ├── dqn_atari_dense.gin
│ │ │ ├── dqn_atari_dense_impala_net.gin
│ │ │ ├── dqn_atari_prune.gin
│ │ │ ├── dqn_atari_prune_impala_net.gin
│ │ │ ├── dqn_atari_rigl.gin
│ │ │ ├── dqn_atari_rigl_impala_net.gin
│ │ │ ├── dqn_atari_set.gin
│ │ │ ├── dqn_atari_set_impala_net.gin
│ │ │ ├── dqn_atari_static.gin
│ │ │ └── dqn_atari_static_impala_net.gin
│ │ ├── tfagents/
│ │ │ ├── configs/
│ │ │ │ ├── dqn_gym_dense_config.gin
│ │ │ │ ├── dqn_gym_pruning_config.gin
│ │ │ │ ├── dqn_gym_sparse_config.gin
│ │ │ │ ├── ppo_mujoco_dense_config.gin
│ │ │ │ ├── ppo_mujoco_pruning_config.gin
│ │ │ │ ├── ppo_mujoco_sparse_config.gin
│ │ │ │ ├── sac_mujoco_dense_config.gin
│ │ │ │ ├── sac_mujoco_pruning_config.gin
│ │ │ │ └── sac_mujoco_sparse_config.gin
│ │ │ ├── dqn_train_eval.py
│ │ │ ├── ppo_train_eval.py
│ │ │ ├── sac_train_eval.py
│ │ │ ├── sparse_encoding_network.py
│ │ │ ├── sparse_ppo_actor_network.py
│ │ │ ├── sparse_ppo_discrete_actor_network.py
│ │ │ ├── sparse_ppo_discrete_actor_network_test.py
│ │ │ ├── sparse_tanh_normal_projection_network.py
│ │ │ ├── sparse_value_network.py
│ │ │ └── tf_sparse_utils.py
│ │ └── train.py
│ ├── sparse_optimizers.py
│ ├── sparse_optimizers_base.py
│ ├── sparse_optimizers_test.py
│ ├── sparse_utils.py
│ ├── sparse_utils_test.py
│ └── str_sparsities.py
└── run.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute
We'd love to accept your patches and contributions to this project.
- If you want to contribute to the library please check `Issues` tab and feel
free to take on any problem/issue you find interesting.
- If your `issue` is not reported yet, please create a new one. It is
important to discuss the problem/request before implementing the solution.
- Reach us at rigl.authors@gmail.com any time!
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Rigging the Lottery: Making All Tickets Winners
**Paper**: [https://arxiv.org/abs/1911.11134](https://arxiv.org/abs/1911.11134)
**15min Presentation** [[pml4dc](https://pml4dc.github.io/iclr2020/program/pml4dc_7.html)] [[icml](https://icml.cc/virtual/2020/paper/5808)]
**ML Reproducibility Challenge 2020** [report](https://openreview.net/forum?id=riCIeP6LzEE)
## Colabs for Calculating FLOPs of Sparse Models
[MobileNet-v1](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb)
[ResNet-50](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb)
## Best Sparse Models
Parameters are floats, so each parameter is represented with 4 bytes. The uniform
sparsity distribution keeps the first layer dense and therefore has a slightly
larger size and parameter count. ERK applies to all layers except for the 99%
sparse model, in which we set the first layer to be dense, since otherwise we
observe much worse performance.
### Extended Training Results
Performance of RigL increases significantly with extended training iterations.
In this section we extend the training of sparse models by 5x. Note that sparse
models require much less FLOPs per training iteration and therefore most of the
extended trainings cost less FLOPs than baseline dense training.
Observing improving performance we wanted to understand where the performance of sparse networks saturates. Longest training we ran had 100x training length of the original
100 epoch ImageNet training. This training costs 5.8x of the original dense training FLOPS and the resulting 99% sparse Resnet-50 achieves an impressive 68.15% test accuracy (vs 5x training accuracy of 61.86%).
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (Bytes) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| - (DENSE) | 0 | 3.2e18 | 8.2e9 | 102.122 | 76.8 | - |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 77.17 | [link](https://storage.googleapis.com/gresearch/rigl/s80erk5x.tar.gz) |
| Uniform | 0.8 | 1.14x | 0.23x | 23.685 | 76.71 | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform5x.tar.gz) |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 76.42 | [link](https://storage.googleapis.com/gresearch/rigl/s90erk5x.tar.gz) |
| Uniform | 0.9 | 0.66x | 0.13x | 13.532 | 75.73 | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform5x.tar.gz) |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.63 | [link](https://storage.googleapis.com/gresearch/rigl/s95erk5x.tar.gz) |
| Uniform | 0.95 | 0.42x | 0.08x | 8.433 | 73.22 | [link](https://storage.googleapis.com/gresearch/rigl/s95uniform5x.tar.gz) |
| ERK | 0.965 | 0.45x | 0.09x | 6.904 | 72.77 | [link](https://storage.googleapis.com/gresearch/rigl/s965erk5x.tar.gz) |
| Uniform | 0.965 | 0.34x | 0.07x | 6.904 | 71.31 | [link](https://storage.googleapis.com/gresearch/rigl/s965uniform5x.tar.gz) |
| ERK | 0.99 | 0.29x | 0.05x | 4.354 | 61.86 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk5x.tar.gz) |
| ERK | 0.99 | 0.58x | 0.05x | 4.354 | 63.89 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk10x.tar.gz) |
| ERK | 0.99 | 2.32x | 0.05x | 4.354 | 66.94 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk40x.tar.gz) |
| ERK | **0.99** | 5.8x | 0.05x | 4.354 | **68.15** | [link](https://storage.googleapis.com/gresearch/rigl/s99erk100x.tar.gz) |
We also ran extended training runs with MobileNet-v1. Again training 100x more,
we were not able to saturate the performance. Training longer consistently
achieved better results.
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (Bytes) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| - (DENSE) | 0 | 4.5e17 | 1.14e9 | 16.864 | 72.1 | - |
| ERK | 0.89 | 1.39x | 0.21x | 2.392 | 69.31 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk10x.tar.gz) |
| ERK | 0.89 | 2.79x | 0.21x | 2.392 | 70.63 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk50x.tar.gz) |
| Uniform | 0.89 | 1.25x | 0.09x | 2.392 | 69.28 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform10x.tar.gz) |
| Uniform | 0.89 | 6.25x | 0.09x | 2.392 | 70.25 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform50x.tar.gz) |
| Uniform | 0.89 | 12.5x | 0.09x | 2.392 | 70.59 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform100x.tar.gz) |
### 1x Training Results
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (Bytes) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.12 | [link](https://storage.googleapis.com/gresearch/rigl/s80erk1x.tar.gz) |
| Uniform | 0.8 | 0.23x | 0.23x | 23.685 | 74.60 | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform1x.tar.gz) |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.07 | [link](https://storage.googleapis.com/gresearch/rigl/s90erk1x.tar.gz) |
| Uniform | 0.9 | 0.13x | 0.13x | 13.532 | 72.02 | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform1x.tar.gz) |
### Results w/o label smoothing
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (Bytes) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.02 | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 76.17 | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_5x.tar.gz) |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.4 | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 75.9 | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_5x.tar.gz) |
| ERK | 0.95 | 0.13x | 0.12x | 8.399 | 70.39 | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.36 | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_5x.tar.gz) |
### Evaluating checkpoints
Download the checkpoints and run the evaluation on ERK checkpoints with the
following:
```python
python imagenet_train_eval.py --mode=eval_once --output_dir=path/to/ckpt/folder \
--eval_once_ckpt_prefix=model.ckpt-3200000 --use_folder_stub=False \
--training_method=rigl --mask_init_method=erdos_renyi_kernel \
--first_layer_sparsity=-1
```
When running checkpoints with uniform sparsity distribution use `--mask_init_method=random` and `--first_layer_sparsity=0`. Set
`--model_architecture=mobilenet_v1` when evaluating mobilenet checkpoints.
## Sparse Training Algorithms
In this repository we implement following dynamic sparsity strategies:
1. [SET](https://www.nature.com/articles/s41467-018-04316-3): Implements Sparse
Evolutionary Training (SET), which corresponds to replacing low-magnitude
connections randomly with new ones.
2. [SNFS](https://arxiv.org/abs/1907.04840): Implements momentum-based training
*without* sparsity re-distribution.
3. [RigL](https://arxiv.org/abs/1911.11134): Our method, RigL, removes a
fraction of connections based on weight magnitudes and activates new ones
using instantaneous gradient information.
And the following one-shot pruning algorithm:
1. [SNIP](https://arxiv.org/abs/1810.02340): Single-shot Network Pruning based
on connection sensitivity prunes the least salient connections before training.
We have code for following settings:
- [Imagenet2012](https://github.com/google-research/rigl/tree/master/rigl/imagenet_resnet):
TPU compatible code with Resnet-50 and MobileNet-v1/v2.
- [CIFAR-10](https://github.com/google-research/rigl/tree/master/rigl/cifar_resnet)
with WideResNets.
- [MNIST](https://github.com/google-research/rigl/tree/master/rigl/mnist) with
2 layer fully connected network.
## Setup
First clone this repo.
```bash
git clone https://github.com/google-research/rigl.git
cd rigl
```
We use [Neurips 2019 MicroNet Challenge](https://micronet-challenge.github.io/)
code for counting operations and size of our networks. Let's clone the
google_research repo and add current folder to the python path.
```bash
git clone https://github.com/google-research/google-research.git
mv google-research/ google_research/
export PYTHONPATH=$PYTHONPATH:$PWD
```
Now we can run some tests. The following script creates a virtual environment,
installs the necessary libraries, and finally runs a few tests.
```bash
bash run.sh
```
We need to activate the virtual environment before running an experiment. With
that, we are ready to run some trivial MNIST experiments.
```bash
source env/bin/activate
python rigl/mnist/mnist_train_eval.py
```
You can load and verify the performance of the Resnet-50 checkpoints
like following.
```bash
python rigl/imagenet_resnet/imagenet_train_eval.py --mode=eval_once --training_method=baseline --eval_batch_size=100 --output_dir=/path/to/folder --eval_once_ckpt_prefix=s80_model.ckpt-1280000 --use_folder_stub=False
```
We use the [Official TPU Code](https://github.com/tensorflow/tpu/tree/master/models/official/resnet)
for loading ImageNet data. First clone the
tensorflow/tpu repo and then add models/ folder to the python path.
```bash
git clone https://github.com/tensorflow/tpu.git
export PYTHONPATH=$PYTHONPATH:$PWD/tpu/models/
```
## Other Implementations
- [Graphcore-TF-MNIST](https://github.com/graphcore/examples/tree/master/applications/tensorflow/dynamic_sparsity/mnist_rigl): with sparse matrix ops!
- [Pytorch implementation](https://github.com/McCrearyD/rigl-torch) by Dyllan McCreary.
- [Micrograd-Pure Python](https://evcu.github.io/ml/sparse-micrograd/): This is
a toy example with pure python sparse implementation. Caution, very slow but fun.
## Citation
```
@incollection{rigl,
author = {Evci, Utku and Gale, Trevor and Menick, Jacob and Castro, Pablo Samuel and Elsen, Erich},
booktitle = {Proceedings of Machine Learning and Systems 2020},
pages = {471--481},
title = {Rigging the Lottery: Making All Tickets Winners},
year = {2020}
}
```
## Disclaimer
This is not an official Google product.
================================================
FILE: rigl/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This repo involves the code for training sparse neural networks."""

# Package identifier for this repository.
name = 'rigl'
================================================
FILE: rigl/cifar_resnet/data_helper.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for CIFAR10 data input pipeline.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
IMG_SIZE = 32
def pad_input(x, crop_dim=4):
  """Reflection-pads an image by mirroring `crop_dim` border pixels.

  The `crop_dim` rows/columns nearest each edge are mirrored and attached
  to that edge, growing each spatial dimension by 2 * crop_dim.

  Args:
    x: Input image float32 tensor of shape [height, width, channels].
    crop_dim: Number of border pixels to mirror onto each edge.

  Returns:
    Input image float32 tensor, padded on all four sides with reflected
    border pixels.
  """
  # Mirror the top and bottom borders and attach them along the height axis.
  top_mirror = x[:crop_dim, :, :][::-1]
  bottom_mirror = x[-crop_dim:, :, :][::-1]
  x = tf.concat([top_mirror, x, bottom_mirror], axis=0)
  # Mirror the left and right borders and attach them along the width axis.
  left_mirror = x[:, :crop_dim, :][:, ::-1]
  right_mirror = x[:, -crop_dim:, :][:, ::-1]
  return tf.concat([left_mirror, x, right_mirror], axis=1)
def preprocess_train(x, width, height):
  """Pre-processing applied to the training data set.

  The image is reflection-padded, randomly cropped back to
  `width` x `height`, and randomly flipped left/right.

  Args:
    x: Input image float32 tensor.
    width: int specifying intended width in pixels of image after
      preprocessing.
    height: int specifying intended height in pixels of image after
      preprocessing.

  Returns:
    x: transformed input with random crops, flips and reflection padding.
  """
  x = pad_input(x, crop_dim=4)
  # `tf.random_crop` is deprecated and removed from the top-level namespace
  # in TF2; `tf.image.random_crop` is the equivalent, supported API.
  x = tf.image.random_crop(x, [width, height, 3])
  x = tf.image.random_flip_left_right(x)
  return x
def input_fn(params):
  """Provides batches of CIFAR10 data.

  Args:
    params: A dictionary with a set of arguments, namely:
      * batch_size (int32), specifies data points in a batch
      * data_split (string), designates train or eval
      * data_directory (string), may be present in params but is not read
        here; tfds.load uses its default data directory.

  Returns:
    images: A float32 `Tensor` of size [batch_size, 32, 32, 3].
    labels: An int32 `Tensor` of size [batch_size].
  """

  def parse_serialized_example(record):
    """Parses a CIFAR10 example into a standardized (image, label) pair."""
    image = record['image']
    label = tf.cast(record['label'], tf.int32)
    image = tf.cast(image, tf.float32)
    # Per-image standardization: zero mean, unit variance per image.
    image = tf.image.per_image_standardization(image)
    # Random crop/flip augmentation is applied only to the training split.
    if data_split == 'train':
      image = preprocess_train(image, IMG_SIZE, IMG_SIZE)
    return image, label

  data_split = params['data_split']
  batch_size = params['batch_size']
  # tfds uses 'test' as the name of the held-out split.
  if data_split == 'eval':
    data_split = 'test'
  dataset = tfds.load('cifar10:3.*.*', split=data_split)
  # we only repeat an example and shuffle inputs during training
  if data_split == 'train':
    dataset = dataset.repeat().shuffle(buffer_size=50000)
  # deserialize record into tensors and apply pre-processing.
  dataset = dataset.map(parse_serialized_example).prefetch(batch_size)
  # NOTE(review): batch() keeps the final partial batch (drop_remainder
  # defaults to False), and the reshape below assumes full batches — so
  # batch_size should evenly divide the eval set size; confirm with callers.
  dataset = dataset.batch(batch_size)
  images_batch, labels_batch = tf.data.make_one_shot_iterator(
      dataset).get_next()
  # Reshaping pins static shapes, which downstream model code relies on.
  return (tf.reshape(images_batch, [batch_size, IMG_SIZE, IMG_SIZE, 3]),
          tf.reshape(labels_batch, [batch_size]))
================================================
FILE: rigl/cifar_resnet/data_helper_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tests for the data_helper input pipeline and the training process.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from absl import logging
import absl.testing.parameterized as parameterized
from rigl.cifar_resnet import resnet_train_eval
from rigl.cifar_resnet.data_helper import input_fn
import tensorflow.compat.v1 as tf
from tensorflow.contrib.model_pruning.python import pruning
FLAGS = flags.FLAGS
# Keep the test fixtures tiny so a single pipeline/training step runs fast.
BATCH_SIZE = 1
NUM_IMAGES = 1
JITTER_MULTIPLIER = 2
class DataHelperTest(tf.test.TestCase, parameterized.TestCase):
  """Tests the CIFAR-10 input pipeline and a single training step."""

  def get_next(self):
    """Builds the eval input pipeline; returns (images, labels) tensors."""
    data_directory = FLAGS.data_directory
    # we pass the updated eval and train string to the params dictionary.
    params = {
        'mode': 'test',
        'data_split': 'eval',
        'batch_size': BATCH_SIZE,
        'data_directory': data_directory
    }
    test_inputs, test_labels = input_fn(params)
    return test_inputs, test_labels

  def testInputPipeline(self):
    """Checks that the pipeline yields batches of the expected shapes."""
    tf.reset_default_graph()
    g = tf.Graph()
    with g.as_default():
      test_inputs, test_labels = self.get_next()
      with self.test_session() as sess:
        test_images_out, test_labels_out = sess.run([test_inputs, test_labels])
        self.assertAllEqual(test_images_out.shape, [BATCH_SIZE, 32, 32, 3])
        self.assertAllEqual(test_labels_out.shape, [BATCH_SIZE])

  @parameterized.parameters(
      {
          'training_method': 'baseline',
      },
      {
          'training_method': 'threshold',
      },
      {
          'training_method': 'rigl',
      },
  )
  def testTrainingStep(self, training_method):
    """Builds the model and verifies one optimizer step runs per method."""
    tf.reset_default_graph()
    g = tf.Graph()
    with g.as_default():
      images, labels = self.get_next()
      global_step, _, _, logits = resnet_train_eval.build_model(
          mode='train',
          images=images,
          labels=labels,
          training_method=training_method,
          num_classes=FLAGS.num_classes,
          depth=FLAGS.resnet_depth,
          width=FLAGS.resnet_width)
      tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
      total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
      learning_rate = 0.1
      opt = tf.train.MomentumOptimizer(
          learning_rate, momentum=FLAGS.momentum, use_nesterov=True)
      if training_method in ['threshold']:
        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning()
        logging.info('starting mask update op')
        mask_update_op = pruning_obj.conditional_mask_update_op()
      # Create the training op
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = opt.minimize(total_loss, global_step)
      init_op = tf.global_variables_initializer()
      with self.test_session() as sess:
        # test that we can train successfully for 1 step
        sess.run(init_op)
        for _ in range(1):
          sess.run(train_op)
          # Magnitude pruning additionally needs its mask-update op run.
          if training_method in ['threshold']:
            sess.run(mask_update_op)
# Run all test cases in this module when executed directly.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: rigl/cifar_resnet/resnet_model.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Model implementation of wide resnet model.
Implements masking layer if pruning method is selected.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from rigl.imagenet_resnet.pruning_layers import sparse_conv2d
from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
_BN_EPS = 1e-5
_BN_MOMENTUM = 0.9
class WideResNetModel(object):
  """Implements WideResNet model.

  Wide residual network as described in https://arxiv.org/pdf/1605.07146.pdf.
  Implements masking layers (via sparse_conv2d / sparse_fully_connected) when
  a pruning method other than 'baseline' is selected.
  """

  def __init__(self,
               is_training,
               regularizer=None,
               data_format='channels_last',
               pruning_method='baseline',
               droprate=0.3,
               prune_first_layer=True,
               prune_last_layer=True):
    """WideResnet as described in https://arxiv.org/pdf/1605.07146.pdf.

    Args:
      is_training: Boolean, True during model training,
        false for evaluation/inference.
      regularizer: A regularization function (mapping variables to
        regularization losses), or None.
      data_format: A string that indicates whether the channels are the second
        or last index in the matrix. 'channels_first' or 'channels_last'.
      pruning_method: str, 'threshold' or 'baseline'.
      droprate: float, dropout rate to apply activations.
      prune_first_layer: bool, if True first layer is pruned.
      prune_last_layer: bool, if True last layer is pruned.

    Raises:
      ValueError: if data_format is not 'channels_last' or 'channels_first'.
    """
    self._training = is_training
    self._regularizer = regularizer
    self._data_format = data_format
    self._pruning_method = pruning_method
    self._droprate = droprate
    self._prune_first_layer = prune_first_layer
    self._prune_last_layer = prune_last_layer
    if data_format == 'channels_last':
      self._channel_axis = -1
    elif data_format == 'channels_first':
      self._channel_axis = 1
    else:
      # Previously an unknown data_format silently left _channel_axis
      # undefined and failed later with AttributeError; fail fast instead.
      raise ValueError('Unknown data_format: %s' % data_format)

  def build(self, inputs, depth, width, num_classes, name=None):
    """Model architecture to train the model.

    The configuration of the resnet blocks requires that depth should be
    6n+4 where n is the number of resnet blocks desired.

    Args:
      inputs: A 4D float tensor containing the model inputs.
      depth: Number of convolutional layers in the network.
      width: Size of the convolutional filters in the residual blocks.
      num_classes: Positive integer number of possible classes.
      name: Optional string, the name of the resulting op in the TF graph.

    Returns:
      A 2D float logits tensor of shape (batch_size, num_classes).

    Raises:
      ValueError: if depth is not the minimum amount required to build the
        model.
    """
    if (depth - 4) % 6 != 0:
      raise ValueError('Depth of ResNet specified not sufficient.')
    resnet_blocks = (depth - 4) // 6
    with tf.variable_scope(name, 'resnet_model'):
      # Optionally keep the first layer dense even when pruning the rest.
      first_layer_technique = self._pruning_method
      if not self._prune_first_layer:
        first_layer_technique = 'baseline'
      net = self._conv(
          inputs,
          'conv_1',
          output_size=16,
          sparsity_technique=first_layer_technique)
      net = self._residual_block(
          net, 'conv_2', 16 * width, subsample=False, blocks=resnet_blocks)
      net = self._residual_block(
          net, 'conv_3', 32 * width, subsample=True, blocks=resnet_blocks)
      net = self._residual_block(
          net, 'conv_4', 64 * width, subsample=True, blocks=resnet_blocks)
      # Put the final BN, relu before the max pooling.
      with tf.name_scope('Pooling'):
        net = self._batch_norm(net)
        net = tf.nn.relu(net)
        net = tf.layers.average_pooling2d(
            net, pool_size=8, strides=1, data_format=self._data_format)
        net = contrib_layers.flatten(net)
      # Optionally keep the classifier layer dense as well.
      last_layer_technique = self._pruning_method
      if not self._prune_last_layer:
        last_layer_technique = 'baseline'
      net = self._dense(
          net, num_classes, 'logits', sparsity_technique=last_layer_technique)
    return net

  def _batch_norm(self, net, name=None):
    """Adds batchnorm to the model.

    Input gradients cannot be computed with fused batch norm; causes recursive
    loop of tf.gradient call. If regularizer is specified, fused batchnorm must
    be set to False (default setting).

    Args:
      net: Pre-batch norm tensor activations.
      name: Specified name for batch normalization layer.

    Returns:
      batch norm layer: Activations from the batch normalization layer.
    """
    return tf.layers.batch_normalization(
        inputs=net,
        fused=False,
        training=self._training,
        axis=self._channel_axis,
        momentum=_BN_MOMENTUM,
        epsilon=_BN_EPS,
        name=name)

  def _dense(self, net, num_units, name=None, sparsity_technique='baseline'):
    """Returns a (possibly masked) fully connected layer."""
    return sparse_fully_connected(
        x=net,
        units=num_units,
        sparsity_technique=sparsity_technique,
        kernel_regularizer=self._regularizer,
        name=name)

  def _conv(self,
            net,
            name,
            output_size,
            strides=(1, 1),
            padding='SAME',
            sparsity_technique='baseline'):
    """Returns a (possibly masked) 3x3 conv layer without bias/activation."""
    return sparse_conv2d(
        x=net,
        units=output_size,
        activation=None,
        kernel_size=[3, 3],
        use_bias=False,
        kernel_initializer=None,
        kernel_regularizer=self._regularizer,
        bias_initializer=None,
        biases_regularizer=None,
        sparsity_technique=sparsity_technique,
        normalizer_fn=None,
        strides=strides,
        padding=padding,
        data_format=self._data_format,
        name=name)

  def _residual_block(self, net, name, output_size, subsample, blocks):
    """Adds a residual block (pre-activation WRN style) to the model."""
    with tf.name_scope(name):
      for n in range(blocks):
        with tf.name_scope('res_%d' % n):
          # when subsample is true + first block a larger stride is used.
          if subsample and n == 0:
            strides = [2, 2]
          else:
            strides = [1, 1]
          # Create the skip connection
          skip = net
          end_point = 'skip_%s' % name
          net = self._batch_norm(net)
          net = tf.nn.relu(net)
          # Project the skip connection with a 1x1 conv when the channel
          # count changes. Bug fix: the channel index was hard-coded to 3,
          # which is only valid for channels_last; use the configured axis.
          if net.get_shape()[self._channel_axis].value != output_size:
            skip = sparse_conv2d(
                x=net,
                units=output_size,
                activation=None,
                kernel_size=[1, 1],
                use_bias=False,
                kernel_initializer=None,
                kernel_regularizer=self._regularizer,
                bias_initializer=None,
                biases_regularizer=None,
                sparsity_technique=self._pruning_method,
                normalizer_fn=None,
                strides=strides,
                padding='VALID',
                data_format=self._data_format,
                name=end_point)
          # Create residual
          net = self._conv(
              net,
              '%s_%d_1' % (name, n),
              output_size,
              strides,
              sparsity_technique=self._pruning_method)
          net = self._batch_norm(net)
          net = tf.nn.relu(net)
          net = tf.keras.layers.Dropout(self._droprate)(net, self._training)
          net = self._conv(
              net,
              '%s_%d_2' % (name, n),
              output_size,
              sparsity_technique=self._pruning_method)
          # Combine the residual and the skip connection
          net += skip
    return net
================================================
FILE: rigl/cifar_resnet/resnet_train_eval.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""This script trains a ResNet model that implements various pruning methods.
Implement pruning method during training:
Specify the pruning method to use using FLAGS.training_method
- To train a model with no pruning, specify FLAGS.training_method='baseline'
Specify desired end sparsity using FLAGS.end_sparsity
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from rigl import sparse_optimizers
from rigl import sparse_utils
from rigl.cifar_resnet.data_helper import input_fn
from rigl.cifar_resnet.resnet_model import WideResNetModel
from rigl.imagenet_resnet import utils
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import training as contrib_training
from tensorflow.contrib.model_pruning.python import pruning
# --- Infrastructure / checkpointing flags. ---
flags.DEFINE_string('master', 'local',
                    'BNS name of the TensorFlow runtime to use.')
flags.DEFINE_integer('ps_task', 0,
                     'Task id of the replica running the training.')
flags.DEFINE_integer('keep_checkpoint_max', 5,
                     'Number of checkpoints to save, set 0 for all.')
flags.DEFINE_string('pruning_hparams', '',
                    'Comma separated list of pruning-related hyperparameters')
flags.DEFINE_string('train_dir', '/tmp/cifar10/',
                    'Directory where to write event logs and checkpoint.')
flags.DEFINE_string(
    'load_mask_dir', '',
    'Directory of a trained model from which to load only the mask')
flags.DEFINE_string(
    'initial_value_checkpoint', '',
    'Directory of a model from which to load only the parameters')
flags.DEFINE_integer(
    'seed', default=0, help=('Sets the random seed.'))
# --- Core training hyperparameters. ---
flags.DEFINE_float('momentum', 0.9, 'The momentum value.')
# 250 Epochs
flags.DEFINE_integer('max_steps', 97656, 'Number of steps to run.')
flags.DEFINE_float('l2', 5e-4, 'Scale factor for L2 weight decay.')
flags.DEFINE_integer('resnet_depth', 16, 'Number of core convolutional layers'
                     'in the network.')
flags.DEFINE_integer('resnet_width', 4, 'Width of the residual blocks.')
flags.DEFINE_string(
    'data_directory', '', 'data directory where cifar10 records are stored')
flags.DEFINE_integer('num_classes', 10, 'Number of classes.')
flags.DEFINE_integer('dataset_size', 50000, 'Size of training dataset.')
flags.DEFINE_integer('batch_size', 128, 'Batch size.')
flags.DEFINE_integer('checkpoint_steps', 5000, 'Specifies step interval for'
                     'saving model checkpoints.')
flags.DEFINE_integer(
    'summaries_steps', 300, 'Specifies interval in steps for'
    'saving model summaries.')
flags.DEFINE_bool('per_class_metrics', True, 'Whether to add per-class'
                  'performance summaries.')
flags.DEFINE_enum('mode', 'train', ('train_and_eval', 'train', 'eval'),
                  'String that specifies either inference or training')
# pruning flags
flags.DEFINE_integer('sparsity_begin_step', 20000, 'Step to begin pruning at.')
flags.DEFINE_integer('sparsity_end_step', 75000, 'Step to end pruning at.')
flags.DEFINE_integer('pruning_frequency', 1000,
                     'Step interval between pruning steps.')
flags.DEFINE_float('end_sparsity', 0.9,
                   'Target sparsity desired by end of training.')
flags.DEFINE_enum(
    'training_method', 'baseline',
    ('scratch', 'set', 'baseline', 'momentum', 'rigl', 'static', 'snip',
     'prune'),
    'Method used for training sparse network. `scratch` means initial mask is '
    'kept during training. `set` is for sparse evalutionary training and '
    '`baseline` is for dense baseline.')
flags.DEFINE_bool('prune_first_layer', False,
                  'Whether or not to apply sparsification to the first layer')
flags.DEFINE_bool('prune_last_layer', True,
                  'Whether or not to apply sparsification to the last layer')
flags.DEFINE_float('drop_fraction', 0.3,
                   'When changing mask dynamically, this fraction decides how '
                   'much of the ')
flags.DEFINE_string('drop_fraction_anneal', 'constant',
                    'If not empty the drop fraction is annealed during sparse'
                    ' training. One of the following: `constant`, `cosine` or '
                    '`exponential_(\\d*\\.?\\d*)$`. For example: '
                    '`exponential_3`, `exponential_.3`, `exponential_0.3`. '
                    'The number after `exponential` defines the exponent.')
flags.DEFINE_string('grow_init', 'zeros',
                    'Passed to the SparseInitializer, one of: zeros, '
                    'initial_value, random_normal, random_uniform.')
flags.DEFINE_float('s_momentum', 0.9,
                   'Momentum values for exponential moving average of '
                   'gradients. Used when training_method="momentum".')
flags.DEFINE_float('rigl_acc_scale', 0.,
                   'Used to scale initial accumulated gradients for new '
                   'connections.')
flags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.')
flags.DEFINE_integer('maskupdate_end_step', 75000, 'Step to end mask updates.')
flags.DEFINE_integer('maskupdate_frequency', 100,
                     'Step interval between mask updates.')
flags.DEFINE_string(
    'mask_init_method',
    default='random',
    help='If not empty string and mask is not loaded from a checkpoint, '
    'indicates the method used for mask initialization. One of the following: '
    '`random`, `erdos_renyi`.')
flags.DEFINE_float('training_steps_multiplier', 1.0,
                   'Training schedule is shortened or extended with the '
                   'multiplier, if it is not 1.')
FLAGS = flags.FLAGS
# Variable-name suffixes used to select which variables are restored from a
# checkpoint (parameters vs. masks).
PARAM_SUFFIXES = ('gamma', 'beta', 'weights', 'biases')
MASK_SUFFIX = 'mask'
# CIFAR-10 class names, indexed by label id; used for per-class metrics.
CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
    'ship', 'truck'
]
def create_eval_metrics(labels, logits):
  """Creates the evaluation metrics for the model.

  Args:
    labels: int tensor of shape (batch_size,) with ground-truth class ids.
    logits: float tensor of shape (batch_size, num_classes).

  Returns:
    Dict mapping metric names to tf.metrics results (value_op, update_op),
    containing overall accuracy and, when FLAGS.per_class_metrics is set,
    per-class precision/recall at k=1.
  """
  eval_metrics = {}
  label_keys = CLASSES
  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
  eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
      labels=labels, predictions=predictions)
  if FLAGS.per_class_metrics:
    with tf.name_scope('class_level_summaries') as scope:
      # tf.metrics.*_at_k require int64 labels; cast once here instead of
      # re-casting on every loop iteration as the original did.
      labels = tf.cast(labels, tf.int64)
      for i, label_key in enumerate(label_keys):
        # NOTE(review): `scope` already ends with '/', so this yields a
        # double slash; kept as-is to preserve the existing op names.
        name = scope + '/' + label_key
        eval_metrics[('class_level_summaries/precision/' +
                      label_key)] = tf.metrics.precision_at_k(
                          labels=labels,
                          predictions=logits,
                          class_id=i,
                          k=1,
                          name=name)
        eval_metrics[('class_level_summaries/recall/' +
                      label_key)] = tf.metrics.recall_at_k(
                          labels=labels,
                          predictions=logits,
                          class_id=i,
                          k=1,
                          name=name)
  return eval_metrics
def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
             top_5_accuracy):
  """Training script for resnet model.

  Builds the learning-rate schedule and momentum optimizer, optionally wraps
  the optimizer with a sparse-training optimizer (set/static/momentum/rigl/
  snip), creates the train op, and registers training summaries.

  Args:
    training_method: specifies the method used to sparsify networks.
    global_step: the current step of training/eval.
    total_loss: tensor float32 of the cross entropy + regularization losses.
    train_dir: string specifying where directory where summaries are saved.
    accuracy: tensor float32 batch classification accuracy.
    top_5_accuracy: tensor float32 batch classification accuracy (top_5
      classes).

  Returns:
    hooks: summary tensors to be computed at each training step.
    eval_metrics: set to None during training.
    train_op: the optimization term.

  Raises:
    ValueError: if training_method is not one of the supported methods.
  """
  # Roughly drops at every 30k steps.
  boundaries = [30000, 60000, 90000]
  if FLAGS.training_steps_multiplier != 1.0:
    # Stretch/shrink the LR schedule together with the training length.
    multiplier = FLAGS.training_steps_multiplier
    boundaries = [int(x * multiplier) for x in boundaries]
    tf.logging.info(
        'Learning Rate boundaries are updated with multiplier:%.2f', multiplier)
  # Piecewise-constant schedule: starts at 0.1, divided by 5 at each boundary.
  learning_rate = tf.train.piecewise_constant(
      global_step,
      boundaries,
      values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],
      name='lr_schedule')
  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=FLAGS.momentum, use_nesterov=True)
  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer, mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity, use_tpu=False)
  elif training_method in ('scratch', 'baseline', 'prune'):
    # No optimizer wrapping needed; 'prune' is handled below via the
    # model_pruning library instead of a sparse optimizer.
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)
  # Create the training op
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss, global_step)
  if training_method == 'prune':
    # construct the necessary hparams string from the FLAGS
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          False,
                      ))
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
    tf.logging.info('starting mask update op')
    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()
  # Sparsity summaries for every mask in the graph.
  masks = pruning.get_masks()
  mask_metrics = utils.mask_summaries(masks)
  for name, tensor in mask_metrics.items():
    tf.summary.scalar(name, tensor)
  tf.summary.scalar('learning_rate', learning_rate)
  tf.summary.scalar('accuracy', accuracy)
  tf.summary.scalar('total_loss', total_loss)
  tf.summary.scalar('top_5_accuracy', top_5_accuracy)
  # Logging drop_fraction if dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    tf.summary.scalar('drop_fraction', optimizer.drop_fraction)
  summary_op = tf.summary.merge_all()
  summary_hook = tf.train.SummarySaverHook(
      save_secs=300, output_dir=train_dir, summary_op=summary_op)
  hooks = [summary_hook]
  eval_metrics = None
  return hooks, eval_metrics, train_op
def build_model(mode,
                images,
                labels,
                training_method='baseline',
                num_classes=10,
                depth=10,
                width=4):
  """Build the wide ResNet model for training or eval.

  If regularizer is specified, a regularizer term is added to the loss
  function. The regularizer term is computed using either the pre-softmax
  activation or an auxiliary network logits layer based upon activations
  earlier in the network after the first resnet block.

  Args:
    mode: String for whether training or evaluation is taking place.
    images: A 4D float32 tensor containing the model input images.
    labels: A int32 tensor of shape (batch_size,) containing the model
      labels as sparse class ids.
    training_method: The method used to sparsify the network weights.
    num_classes: The number of distinct labels in the dataset.
    depth: Number of core convolutional layers in the network.
    width: The width of the convolutional filters in the resnet block.

  Returns:
    global_step: The global step tensor (created if necessary).
    accuracy: A scalar float32 top-1 accuracy tensor.
    top_5_accuracy: A scalar float32 top-5 accuracy tensor.
    logits: A 2D float32 logits tensor of shape (batch_size, num_classes).

  Raises:
    ValueError: if depth is not the minimum amount required to build the
      model.
  """
  regularizer_term = tf.constant(FLAGS.l2, tf.float32)
  kernel_regularizer = contrib_layers.l2_regularizer(scale=regularizer_term)
  # depth should be 6n+4 where n is the desired number of resnet blocks
  # if n=2,depth=10 n=3,depth=22, n=5,depth=34 n=7,depth=46
  if (depth - 4) % 6 != 0:
    raise ValueError('Depth of ResNet specified not sufficient.')
  is_training = mode == 'train'
  # 'threshold' would create layers with mask.
  pruning_method = 'baseline' if training_method == 'baseline' else 'threshold'
  model = WideResNetModel(
      is_training=is_training,
      regularizer=kernel_regularizer,
      data_format='channels_last',
      pruning_method=pruning_method,
      prune_first_layer=FLAGS.prune_first_layer,
      prune_last_layer=FLAGS.prune_last_layer)
  logits = model.build(
      images, depth=depth, width=width, num_classes=num_classes)
  global_step = tf.train.get_or_create_global_step()
  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
  accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32))
  in_top_5 = tf.cast(
      tf.nn.in_top_k(predictions=logits, targets=labels, k=5), tf.float32)
  top_5_accuracy = tf.cast(tf.reduce_mean(in_top_5), tf.float32)
  return global_step, accuracy, top_5_accuracy, logits
def wide_resnet_w_pruning(features, labels, mode, params):
  """The model_fn for ResNet wide with pruning.

  Args:
    features: A float32 batch of images.
    labels: A int32 batch of labels.
    mode: Specifies whether training or evaluation.
    params: Dictionary of parameters passed to the model; must contain
      'train_dir' and 'training_method'.

  Returns:
    A EstimatorSpec for the model

  Raises:
    ValueError: if mode is not recognized as train or eval.
  """
  if isinstance(features, dict):
    features = features['feature']
  train_dir = params['train_dir']
  training_method = params['training_method']
  global_step, accuracy, top_5_accuracy, logits = build_model(
      mode=mode,
      images=features,
      labels=labels,
      training_method=training_method,
      num_classes=FLAGS.num_classes,
      depth=FLAGS.resnet_depth,
      width=FLAGS.resnet_width)
  if mode == tf_estimator.ModeKeys.PREDICT:
    # Prediction mode returns early: no loss, hooks, or mask handling.
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf_estimator.export.PredictOutput(predictions)
        })
  with tf.name_scope('computing_cross_entropy_loss'):
    entropy_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    tf.summary.scalar('cross_entropy_loss', entropy_loss)
  with tf.name_scope('computing_total_loss'):
    # Includes the L2 regularization losses added by the model layers.
    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
  if mode == tf_estimator.ModeKeys.TRAIN:
    hooks, eval_metrics, train_op = train_fn(training_method, global_step,
                                             total_loss, train_dir, accuracy,
                                             top_5_accuracy)
  elif mode == tf_estimator.ModeKeys.EVAL:
    hooks = None
    train_op = None
    with tf.name_scope('summaries'):
      eval_metrics = create_eval_metrics(labels, logits)
  else:
    raise ValueError('mode not recognized as training or eval.')
  # If given load parameter values.
  if FLAGS.initial_value_checkpoint:
    tf.logging.info('Loading inital values from: %s',
                    FLAGS.initial_value_checkpoint)
    utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,
                                          FLAGS.train_dir, PARAM_SUFFIXES)
  # Load or randomly initialize masks.
  if (FLAGS.load_mask_dir and
      FLAGS.training_method not in ('snip', 'baseline', 'prune')):
    # Init masks.
    tf.logging.info('Loading masks from %s', FLAGS.load_mask_dir)
    utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir, FLAGS.train_dir,
                                          MASK_SUFFIX)
    scaffold = tf.train.Scaffold()
  elif (FLAGS.mask_init_method and
        FLAGS.training_method not in ('snip', 'baseline', 'scratch', 'prune')):
    # Randomly initialize masks according to the chosen method/sparsity.
    tf.logging.info('Initializing masks using method: %s',
                    FLAGS.mask_init_method)
    all_masks = pruning.get_masks()
    assigner = sparse_utils.get_mask_init_fn(
        all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, {})

    def init_fn(scaffold, session):
      """A callable for restoring variable from a checkpoint."""
      del scaffold  # Unused.
      session.run(assigner)

    scaffold = tf.train.Scaffold(init_fn=init_fn)
  else:
    # Methods that start dense ('snip' sets its mask inside its optimizer;
    # 'prune' sparsifies gradually during training).
    assert FLAGS.training_method in ('snip', 'baseline', 'prune')
    scaffold = None
    tf.logging.info('No mask is set, starting dense.')
  return tf_estimator.EstimatorSpec(
      mode=mode,
      training_hooks=hooks,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metrics,
      scaffold=scaffold)
def main(argv):
  """Driver: scales the schedule, builds the estimator, runs train or eval."""
  del argv  # Unused.
  tf.set_random_seed(FLAGS.seed)

  # Optionally stretch/compress the whole training schedule by one factor.
  if FLAGS.training_steps_multiplier != 1.0:
    multiplier = FLAGS.training_steps_multiplier
    for schedule_flag in ('max_steps', 'maskupdate_begin_step',
                          'maskupdate_end_step', 'sparsity_begin_step',
                          'sparsity_end_step'):
      setattr(FLAGS, schedule_flag,
              int(getattr(FLAGS, schedule_flag) * multiplier))
    tf.logging.info(
        'Training schedule is updated with multiplier: %.2f', multiplier)

  # Encode the method-specific hyperparameters into the train directory path.
  method = FLAGS.training_method
  if method == 'prune':
    schedule = (str(FLAGS.end_sparsity), str(FLAGS.sparsity_begin_step),
                str(FLAGS.sparsity_end_step), str(FLAGS.pruning_frequency))
  elif method in ('set', 'momentum', 'rigl', 'static'):
    schedule = (str(FLAGS.end_sparsity), str(FLAGS.maskupdate_begin_step),
                str(FLAGS.maskupdate_end_step),
                str(FLAGS.maskupdate_frequency))
  elif method in ('baseline', 'snip', 'scratch'):
    schedule = (str(0.0), str(0.0), str(0.0), str(0.0))
  else:
    raise ValueError('Training method is not known %s' % method)
  train_dir = os.path.join(FLAGS.train_dir, method, *schedule)

  # The updated train directory and data settings are passed to the model_fn
  # through the estimator params dictionary.
  params = {
      'train_dir': train_dir,
      'data_split': FLAGS.mode,
      'batch_size': FLAGS.batch_size,
      'data_directory': FLAGS.data_directory,
      'mode': FLAGS.mode,
      'training_method': FLAGS.training_method,
  }

  run_config = tf_estimator.RunConfig(
      model_dir=train_dir,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      save_summary_steps=FLAGS.summaries_steps,
      save_checkpoints_steps=FLAGS.checkpoint_steps,
      log_step_count_steps=100)

  classifier = tf_estimator.Estimator(
      model_fn=wide_resnet_w_pruning,
      model_dir=train_dir,
      config=run_config,
      params=params)

  if FLAGS.mode == 'eval':
    # Assumes a 10000-example eval split (the CIFAR10 test set size) —
    # NOTE(review): confirm if this driver is reused with another dataset.
    eval_steps = 10000 // FLAGS.batch_size
    # Evaluate every time a new checkpoint appears in the train directory.
    for ckpt in contrib_training.checkpoints_iterator(train_dir):
      print('Starting to evaluate.')
      try:
        classifier.evaluate(
            input_fn=input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt,
            name='eval')
        # Stop the eval job once the final checkpoint has been evaluated.
        step = int(os.path.basename(ckpt).split('-')[1])
        if step >= FLAGS.max_steps:
          print('Evaluation finished after training step %d' % step)
          break
      except tf.errors.NotFoundError:
        # Checkpoint may have been garbage-collected between listing and read.
        print('Checkpoint no longer exists,skipping checkpoint.')
  else:
    print('Starting training...')
    if FLAGS.mode == 'train':
      classifier.train(input_fn=input_fn, max_steps=FLAGS.max_steps)
if __name__ == '__main__':
  # tf.app.run parses command-line flags and then calls main(argv).
  tf.app.run(main)
================================================
FILE: rigl/experimental/jax/README.md
================================================
# Weight Symmetry Research Code
This code is mostly written by Yani Ioannou.
## Experiment Summary
There are a number of experiment drivers defined in the base directory:
### Experiment Types {#experiment-types}
random_mask
: Random Variable Sparsity Masks
: This experiment generates random masks of a given type (see
[Mask Types](#mask-types)) within the *given a sparsity range*, and trains
the models, tracking mask statistics and training details. Masks are
generated with a random number of connections and randomly shuffled.
shuffled_mask
: Random Fixed Sparsity Masks
: This experiment generates random masks of a given type (see
[Mask Types](#mask-types)) *of a fixed sparsity*, and trains the models,
tracking mask statistics and training details. Masks are generated with a
fixed number of connections and simply shuffled.
fixed_param
: Train models with (approximately) fixed number of parameters, but varying
depth/width.
: Train models with (approximately) fixed number of parameters, but varying
depth/width, with shuffled mask (as in shuffled_mask driver), and only the
MNIST_FC model type.
prune
: Simple Pruning/Training Driver
: This experiment trains a dense model pruning either iteratively or one-shot,
tracking mask statistics and training details.
train
: Simple Training Driver (Without Masking/Pruning)
: This experiment simply trains a dense model, tracking mask statistics and
training details.
### Mask Types {#mask-types}
symmetric
: Structured Mask.
: The mask is a structured mask.
random
: Unstructured Mask.
: The mask as a whole is a random mask of a given sparsity, with some neurons
having fewer/more connections than others.
per-neuron
: Unstructured Mask.
: Each neuron has the same sparsity (# of masked connections), but is shuffled
randomly.
per-neuron-no-input-ablation
: Unstructured Mask.
: As with per-neuron, each neuron has the same sparsity, but randomly shuffled
connections. Also at least one connection is maintained to each of the input
neurons (i.e. the input neurons are not effectively ablated), although these
connections are also randomly shuffled amongst the neurons of a given layer.
### Model Types {#model-types}
MNIST_FC
: A small fully-connected model, accepting number of neurons and depth as
parameters. No batch normalization, configurable drop-out rate (default: 0).
MNIST_CNN
: A small convolutional model designed for MNIST, accepting number of filters
for each layer and depth as parameters. Uses batch normalization and
configurable drop-out rate (default: 0).
CIFAR10_CNN
: A larger convolutional model designed for CIFAR10, accepting number of
filters for each layer and depth as parameters. No batch normalization,
configurable drop-out rate (default: 0).
### Dataset Types {#dataset-types}
MNIST
: Wrapper of the Tensorflow Datasets (TFDS) MNIST dataset.
CIFAR10
: Wrapper of the Tensorflow Datasets (TFDS) CIFAR10 dataset.
## Running Experiments
### Running on a Workstation
Train:
```shell
python -m weight_symmetry.${EXPERIMENT_TYPE}
```
## Result Processing/Analysis
### Plotting Results from a JSON Summary File
You can convert the results to a Pandas dataframe from a JSON summary file for
plotting/analysis using the example colab in `analysis/plot_summary_json.ipynb`.
================================================
FILE: rigl/experimental/jax/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains code for weight symmetry experiments."""
name = 'weight_symmetry'
================================================
FILE: rigl/experimental/jax/analysis/plot_summary_json.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6iEEw5OwSlnz"
},
"source": [
"# Plot Results from an Experiment Summary JSON File",
"Licensed under the Apache License, Version 2.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Eg6FmoCaTCHM"
},
"source": [
"## Parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ML0hUJMzYF0W"
},
"outputs": [],
"source": [
"from google.colab import files\n",
"\n",
"# Experiment summary filenames (one per experiment)\n",
"SUMMARY_FILES = files.upload()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "MHubbscQSLGm"
},
"outputs": [],
"source": [
"# Labels to use for each of the summaries listed above (in the same order!)\n",
"XID_LABELS=['structured', 'unstructured'] #@param"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "x0jDBWKdU_2A"
},
"source": [
"## Loading of JSON Summary/Conversion to Pandas Dataframe"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Lz-HwS1tU-ie"
},
"outputs": [],
"source": [
"import json\n",
"import pandas as pd\n",
"import os\n",
"\n",
"from colabtools.interactive_widgets import ProgressIter\n",
"\n",
"dfs = []\n",
"for i, summary_file in enumerate(SUMMARY_FILES):\n",
" with open(summary_file) as summary_file:\n",
" data = json.load(summary_file)\n",
" dataframe = pd.DataFrame.from_dict(data, orient='index')\n",
" dataframe['experiment_label'] = XID_LABELS[i]\n",
" dfs.append(dataframe)\n",
"\n",
"df=pd.concat(dfs)\n",
"\n",
"print('Loaded {} rows for experiment'.format(len(data)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DhO6oT1nVpTV"
},
"source": [
"## Measurements and Labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XFRR3XrXVopB"
},
"outputs": [],
"source": [
"DATA_LABELS={\n",
" 'best_train_loss/test_accuracy': 'Test Accuracy (of best train loss)',\n",
" 'best_train_loss/train_accuracy': 'Train Accuracy (of best train loss)',\n",
" 'best_train_loss/test_avg_loss': 'Test Loss (of best train loss)',\n",
" 'best_train_loss/train_avg_loss': 'Train Loss (of best train loss)',\n",
" 'best_train_loss/step': 'Training Iterations (of best train loss)',\n",
" 'best_train_loss/cumulative_gradient_norm': 'Cumulative Gradient Norm. (of best train loss)',\n",
" 'best_train_loss/vector_difference_norm': 'Vector Difference Norm. (of best train loss)',\n",
" 'best_train_loss/cosine_distance': 'Cosine Similarity (of best train loss)',\n",
" 'best_test_acc/test_accuracy': 'Test Accuracy (of best test acc.)',\n",
" 'best_test_acc/train_accuracy': 'Train Accuracy (of best test acc.)',\n",
" 'best_test_acc/test_avg_loss': 'Test Loss (of best test acc.)',\n",
" 'best_test_acc/train_avg_loss': 'Train Loss (of best test acc.)',\n",
" 'best_test_acc/step': 'Training Iterations (of best test acc.)',\n",
" 'best_test_acc/cumulative_gradient_norm': 'Cumulative Gradient Norm. (of best Test Acc.)',\n",
" 'best_test_acc/cosine_distance': 'Cosine Similarity (of best Test Acc.)',\n",
" 'best_test_acc/vector_difference_norm': 'Vector Difference Norm. (of best Test Acc.)',\n",
" 'mask/sparsity': 'Sparsity',\n",
" 'mask/unique_neurons': '# Unique Neurons',\n",
" 'mask/zeroed_neurons': '# Zeroed Neurons',\n",
" 'mask/permutation_log10': 'log10(1 + Permutations)',\n",
" 'mask/permutation_num_digits': 'Permutation # of Digits',\n",
" 'mask/permutations': 'Permutation',\n",
" 'mask/total_neurons': 'Total # of Neurons',\n",
" 'propagated_mask/sparsity': 'Mask Sparsity',\n",
" 'propagated_mask/unique_neurons': '# Unique Neurons (prop.)',\n",
" 'propagated_mask/zeroed_neurons': '# Zeroed Neurons (prop.)',\n",
" 'propagated_mask/permutation_log10': 'log10(1 + Permutations) (prop.)',\n",
" 'propagated_mask/permutation_num_digits': 'Permutation # of Digits (prop.)',\n",
" 'propagated_mask/permutations': 'Mask Permutations',\n",
" 'propagated_mask/total_neurons': 'Total # of Neurons (prop.)',\n",
" 'training/train_avg_loss': 'Train Loss',\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HAVkz8ZzV0Hd"
},
"source": [
"# Seaborn Plot Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "eoxoJH4gWHbb"
},
"outputs": [],
"source": [
"# Choose the X/Y/Z labels from the parameter list above.\n",
"X_LABEL='propagated_mask/sparsity' #@param {type:\"string\"}\n",
"Y_LABEL='best_train_loss/cumulative_gradient_norm' #@param {type:\"string\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "pudAXLl1VzFl"
},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Seaborn style - remove outer plot ticks, white plot background.\n",
"np.set_printoptions(linewidth=128, precision=3, edgeitems=5)\n",
"sns.set_style(\"whitegrid\")\n",
"sns.color_palette(\"muted\")\n",
"sns.set_context(\"paper\", font_scale=1, rc={\n",
" \"lines.linewidth\": 1.2,\n",
" \"xtick.major.size\": 0,\n",
" \"xtick.minor.size\": 0,\n",
" \"ytick.major.size\": 0,\n",
" \"ytick.minor.size\": 0\n",
"})\n",
"\n",
"# Higher resolution plots\n",
"%config InlineBackend.figure_format = 'retina'"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lYUK9xi_aym3"
},
"source": [
"### Plot Raw Data Points"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uWcT76L6Wbv6"
},
"outputs": [],
"source": [
"\n",
"plt.figure(figsize=(16,8))\n",
"axis = sns.scatterplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', s=50, alpha=.5)\n",
"axis.set_ylabel(DATA_LABELS[Y_LABEL])\n",
"axis.set_xlabel(DATA_LABELS[X_LABEL])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Kws6tjfTa7h0"
},
"source": [
"### Plot Mean/StdDev"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jR04tmMnaxjG"
},
"outputs": [],
"source": [
"plt.figure(figsize=(16,8))\n",
"axis = sns.lineplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', alpha=.5, ci=\"sd\", markers=True)\n",
"axis.set_ylabel(DATA_LABELS[Y_LABEL])\n",
"axis.set_xlabel(DATA_LABELS[X_LABEL])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jyNFtKQfajiq"
},
"outputs": [],
"source": [
"# Code to save output files for publication.\n",
"PARAM_STR=X_LABEL.replace('/', '-')+'_'+Y_LABEL.replace('/', '-')\n",
"\n",
"OUT_FILE_PDF=f'/tmp/{PARAM_STR}.pdf'\n",
"OUT_FILE_SVG=f'/tmp/{PARAM_STR}.svg'\n",
"OUT_FILE_PNG=f'/tmp/{PARAM_STR}.png'\n",
"\n",
"plt.savefig(OUT_FILE_PDF, pi=600)\n",
"files.download(OUT_FILE_PDF)\n",
"\n",
"plt.savefig(OUT_FILE_SVG)\n",
"files.download(OUT_FILE_SVG)\n",
"\n",
"plt.savefig(OUT_FILE_PNG)\n",
"files.download(OUT_FILE_PNG)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "plot_summary_json",
"provenance": [
{
"file_id": "1g2aTwv76XMrLfEwryfj_tGzNnvZWjIVl",
"timestamp": 1600990155741
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: rigl/experimental/jax/datasets/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/datasets/cifar10.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CIFAR10 Dataset.
Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
In this case, the CIFAR10 dataset.
"""
from typing import MutableMapping, Sequence
from rigl.experimental.jax.datasets import dataset_base
import tensorflow.compat.v2 as tf
class CIFAR10Dataset(dataset_base.ImageDataset):
  """CIFAR10 dataset wrapped for use with JAX/FLAX.

  Attributes:
    NAME: The Tensorflow Dataset's dataset name.
  """
  NAME: str = 'cifar10'

  # Computed from the training set by taking the per-channel mean/std-dev
  # over sample, height and width axes of all training samples.
  MEAN_RGB: Sequence[float] = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255]
  STDDEV_RGB: Sequence[float] = [0.2470 * 255, 0.2435 * 255, 0.2616 * 255]

  def __init__(self,
               batch_size,
               batch_size_test,
               shuffle_buffer_size=1024,
               seed=42):
    """Creates the CIFAR10 dataset.

    Args:
      batch_size: The batch size to use for the training datasets.
      batch_size_test: The batch size used for the test dataset.
      shuffle_buffer_size: The buffer size to use for dataset shuffling.
      seed: The random seed used to shuffle.

    Raises:
      ValueError: If the test dataset is not evenly divisible by the
        test batch size.
    """
    super().__init__(self.NAME, batch_size, batch_size_test,
                     shuffle_buffer_size, seed)
    test_len = self.get_test_len()
    if test_len % batch_size_test != 0:
      raise ValueError(
          'Test data not evenly divisible by batch size: {} % {} != 0.'.format(
              test_len, batch_size_test))

  def preprocess(
      self, data):
    """Normalizes CIFAR10 images: `uint8` -> `float32`.

    Args:
      data: Data sample.

    Returns:
      Data after being augmented/normalized/transformed.
    """
    data = super().preprocess(data)
    channel_mean = tf.constant(self.MEAN_RGB, shape=[1, 1, 3], dtype=tf.float32)
    channel_std = tf.constant(self.STDDEV_RGB, shape=[1, 1, 3], dtype=tf.float32)
    image = tf.cast(data['image'], tf.float32)
    data['image'] = (image - channel_mean) / channel_std
    return data
================================================
FILE: rigl/experimental/jax/datasets/cifar10_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.cifar10."""
from absl.testing import absltest
import numpy as np
from rigl.experimental.jax.datasets import cifar10
class CIFAR10DatasetTest(absltest.TestCase):
  """Test cases for CIFAR10 Dataset."""

  def setUp(self):
    """Common setup routines/variables for test cases."""
    super().setUp()
    self._batch_size = 16
    # Chosen so the 10000-sample CIFAR10 test split divides evenly.
    self._batch_size_test = 10
    self._shuffle_buffer_size = 8
    self._dataset = cifar10.CIFAR10Dataset(
        self._batch_size,
        batch_size_test=self._batch_size_test,
        shuffle_buffer_size=self._shuffle_buffer_size)

  def test_create_dataset(self):
    """Tests creation of dataset."""
    self.assertIsInstance(self._dataset, cifar10.CIFAR10Dataset)

  def test_train_image_dims_content(self):
    """Tests dimensions and contents of training data."""
    iterator = self._dataset.get_train()
    sample = next(iterator)
    image, label = sample['image'], sample['label']
    with self.subTest(name='DataShape'):
      self.assertTupleEqual(image.shape, (self._batch_size, 32, 32, 3))
    with self.subTest(name='DataType'):
      self.assertTrue(np.issubdtype(image.dtype, float))
    with self.subTest(name='DataValues'):
      # Normalized by stddev., expect nothing to fall outside 3 stddev.
      self.assertTrue((image >= -3.).all() and (image <= 3.).all())
    with self.subTest(name='LabelShape'):
      self.assertLen(label, self._batch_size)
    with self.subTest(name='LabelType'):
      self.assertTrue(np.issubdtype(label.dtype, int))
    with self.subTest(name='LabelValues'):
      self.assertTrue((label >= 0).all() and
                      (label <= self._dataset.num_classes).all())

  def test_test_image_dims_content(self):
    """Tests dimensions and contents of test data."""
    iterator = self._dataset.get_test()
    sample = next(iterator)
    image, label = sample['image'], sample['label']
    with self.subTest(name='DataShape'):
      self.assertTupleEqual(image.shape, (self._batch_size_test, 32, 32, 3))
    with self.subTest(name='DataType'):
      self.assertTrue(np.issubdtype(image.dtype, float))
    with self.subTest(name='DataValues'):
      # Normalized by stddev., expect nothing to fall outside 3 stddev.
      self.assertTrue((image >= -3.).all() and (image <= 3.).all())
    with self.subTest(name='LabelShape'):
      self.assertLen(label, self._batch_size_test)
    with self.subTest(name='LabelType'):
      self.assertTrue(np.issubdtype(label.dtype, int))
    with self.subTest(name='LabelValues'):
      self.assertTrue((label >= 0).all() and
                      (label <= self._dataset.num_classes).all())

  def test_train_data_length(self):
    """Tests length of training dataset."""
    total_count = 0
    # Sum batch sizes over one full epoch of the training split.
    for batch in self._dataset.get_train():
      total_count += len(batch['label'])
    self.assertEqual(total_count, self._dataset.get_train_len())

  def test_test_data_length(self):
    """Tests length of test dataset."""
    total_count = 0
    for batch in self._dataset.get_test():
      total_count += len(batch['label'])
    self.assertEqual(total_count, self._dataset.get_test_len())

  def test_dataset_nonevenly_divisible_batch_size(self):
    """Tests non-evenly divisible test batch size."""
    # 10000 % 101 != 0, so the constructor must reject this batch size.
    with self.assertRaisesRegex(
        ValueError, 'Test data not evenly divisible by batch size: .*'):
      self._dataset = cifar10.CIFAR10Dataset(
          self._batch_size, batch_size_test=101)
if __name__ == '__main__':
  absltest.main()  # Runs all test cases in this module.
================================================
FILE: rigl/experimental/jax/datasets/dataset_base.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset Classes.
Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
"""
import abc
from typing import MutableMapping, Optional
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
class Dataset(metaclass=abc.ABCMeta):
  """Base class for datasets.

  Attributes:
    DATAKEY: The key used for the data component of a Tensorflow Dataset
      (TFDS) sample, e.g. 'image' for image datasets.
    LABELKEY: The key used for the label component of a Tensorflow Dataset
      sample, i.e. 'label'.
    name: The TFDS name of the dataset.
    batch_size: The batch size to use for the training dataset.
    batch_size_test: The batch size to use for the test dataset.
    num_classes: the number of supervised classes in the dataset.
    shape: the shape of an input data array.
  """
  DATAKEY: Optional[str] = None
  LABELKEY: str = 'label'

  def __init__(self,
               name: str,
               batch_size: int,
               batch_size_test: int,
               shuffle_buffer_size: int,
               prefetch_size: int = 1,
               seed: Optional[int] = None):  # pytype: disable=annotation-type-mismatch
    """Base class for datasets.

    Args:
      name: The TFDS name of the dataset.
      batch_size: The batch size to use for the training dataset.
      batch_size_test: The batch size to use for the test dataset.
      shuffle_buffer_size: The buffer size to use for dataset shuffling.
      prefetch_size: The number of mini-batches to prefetch.
      seed: The random seed used to shuffle.

    Returns:
      A Dataset object.
    """
    super().__init__()
    self.name = name
    self.batch_size = batch_size
    self.batch_size_test = batch_size_test
    self._shuffle_buffer_size = shuffle_buffer_size
    self._prefetch_size = prefetch_size
    self._train_ds, self._train_info = tfds.load(
        self.name,
        split=tfds.Split.TRAIN,
        data_dir=self._dataset_dir(),
        with_info=True)
    # Pipeline order matters: `preprocess` results are cached, while `augment`
    # runs after the cache so it is re-applied on every pass over the data.
    # Training batches drop the remainder so all batches are full-sized.
    self._train_ds = self._train_ds.shuffle(
        self._shuffle_buffer_size,
        seed).map(self.preprocess).cache().map(self.augment).batch(
            self.batch_size, drop_remainder=True).prefetch(self._prefetch_size)
    self._test_ds, self._test_info = tfds.load(
        self.name,
        split=tfds.Split.TEST,
        data_dir=self._dataset_dir(),
        with_info=True)
    # Test pipeline: no shuffling/augmentation, and the remainder is kept.
    self._test_ds = self._test_ds.map(self.preprocess).cache().batch(
        self.batch_size_test).prefetch(self._prefetch_size)
    self.num_classes = self._train_info.features['label'].num_classes
    self.shape = self._train_info.features['image'].shape

  def _dataset_dir(self):
    """Returns the dataset path for the TFDS data."""
    # None makes tfds.load fall back to its default data directory.
    return None

  def get_train(self):
    """Returns the training dataset."""
    return iter(tfds.as_numpy(self._train_ds))

  def get_train_len(self) -> int:
    """Returns the length of the training dataset."""
    return self._train_info.splits['train'].num_examples

  def get_test(self):
    """Returns the test dataset."""
    return iter(tfds.as_numpy(self._test_ds))

  def get_test_len(self) -> int:
    """Returns the length of the test dataset."""
    return self._test_info.splits['test'].num_examples

  def preprocess(
      self, data):
    """Preprocessing fn used by TFDS map for normalization.

    This function is for transformations that can be cached, e.g.
    normalization/whitening.

    Args:
      data: Data sample.

    Returns:
      Data after being normalized/transformed.
    """
    return data

  def augment(
      self, data):
    """Preprocessing fn used by TFDS map for augmentation at training time.

    This function is for transformations that should not be cached, e.g. random
    augmentation that should change for every sample, and are only applied at
    training time.

    Args:
      data: Data sample.

    Returns:
      Data after being augmented/transformed.
    """
    return data
class ImageDataset(Dataset):
  """Base class for image datasets."""

  # Key under which the image tensor is stored in a TFDS sample.
  DATAKEY = 'image'

  def preprocess(
      self, data):
    """Preprocessing function used by TFDS map for normalization.

    This function is for transformations that can be cached, e.g.
    normalization/whitening.

    Args:
      data: Data sample.

    Returns:
      Data after being normalized/transformed.
    """
    data = super().preprocess(data)
    # Ensure we only provide the image and label, stripping out other keys.
    # (Dict comprehension instead of dict() around a generator expression.)
    return {
        key: val
        for key, val in data.items()
        if key in (self.LABELKEY, self.DATAKEY)
    }
================================================
FILE: rigl/experimental/jax/datasets/dataset_base_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.dataset_base."""
from absl.testing import absltest
from rigl.experimental.jax.datasets import dataset_base
class DummyDataset(dataset_base.ImageDataset):
  """Minimal concrete ImageDataset used to exercise the abstract base class.

  Attributes:
    NAME: The Tensorflow Dataset's dataset name.
  """
  NAME: str = 'mnist'

  def __init__(self,
               batch_size,
               batch_size_test,
               shuffle_buffer_size=1024,
               seed=42):
    """Creates the dummy MNIST-backed dataset.

    Args:
      batch_size: Batch size for the training split.
      batch_size_test: Batch size for the test split.
      shuffle_buffer_size: Buffer size used when shuffling the training split.
      seed: Random seed used for shuffling.

    Returns:
      Dataset: A dataset object.
    """
    super().__init__(self.NAME, batch_size, batch_size_test,
                     shuffle_buffer_size, seed)
class DummyDatasetTest(absltest.TestCase):
  """Test cases for dummy dataset."""

  def setUp(self):
    """Common setup routines/variables for test cases."""
    super().setUp()
    self._batch_size = 16
    self._batch_size_test = 10
    self._shuffle_buffer_size = 8
    self._dataset = DummyDataset(
        self._batch_size,
        batch_size_test=self._batch_size_test,
        shuffle_buffer_size=self._shuffle_buffer_size)

  def test_create_dataset(self):
    """Tests creation of dataset."""
    self.assertIsInstance(self._dataset, DummyDataset)

  def test_train_image_dims_content(self):
    """Tests dimensions and contents of training data."""
    iterator = iter(self._dataset.get_train())
    sample = next(iterator)
    image, label = sample['image'], sample['label']
    with self.subTest(name='data_shape'):
      self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1))
    with self.subTest(name='data_values'):
      # Fix: `assertBetween(image.all(), 0, 256)` was vacuous, since
      # `ndarray.all()` returns a bool (0 or 1), which always lies in
      # [0, 256]. Check the actual per-element bounds instead.
      self.assertTrue(((image >= 0) & (image <= 255)).all())
    with self.subTest(name='label_shape'):
      self.assertLen(label, self._batch_size)
    with self.subTest(name='label_values'):
      # Labels are class indices in [0, num_classes).
      self.assertTrue(
          ((label >= 0) & (label < self._dataset.num_classes)).all())

  def test_test_image_dims_content(self):
    """Tests dimensions and contents of test data."""
    iterator = iter(self._dataset.get_test())
    sample = next(iterator)
    image, label = sample['image'], sample['label']
    with self.subTest(name='data_shape'):
      self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1))
    with self.subTest(name='data_values'):
      # Same fix as above: assert per-element bounds, not `.all()` truthiness.
      self.assertTrue(((image >= 0) & (image <= 255)).all())
    with self.subTest(name='label_shape'):
      self.assertLen(label, self._batch_size_test)
    with self.subTest(name='label_values'):
      self.assertTrue(
          ((label >= 0) & (label < self._dataset.num_classes)).all())

  def test_train_data_length(self):
    """Tests length of training dataset."""
    total_count = 0
    for batch in self._dataset.get_train():
      total_count += len(batch['label'])
    self.assertEqual(total_count, self._dataset.get_train_len())

  def test_test_data_length(self):
    """Tests length of test dataset."""
    total_count = 0
    for batch in self._dataset.get_test():
      total_count += len(batch['label'])
    # Check image size/content.
    self.assertEqual(total_count, self._dataset.get_test_len())
if __name__ == '__main__':
  absltest.main()  # Runs all test cases in this module.
================================================
FILE: rigl/experimental/jax/datasets/dataset_factory.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset Factory.
Dataset factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
Attributes:
DATASETS: A list of the datasets that can be created.
"""
from typing import Any, Mapping, Type
from rigl.experimental.jax.datasets import cifar10
from rigl.experimental.jax.datasets import dataset_base
from rigl.experimental.jax.datasets import mnist
import tensorflow.compat.v2 as tf
# Registry mapping public dataset names to their Dataset implementations;
# consumed by create_dataset() below.
DATASETS: Mapping[str, Type[dataset_base.Dataset]] = {
    'MNIST': mnist.MNISTDataset,
    'CIFAR10': cifar10.CIFAR10Dataset,
}
def create_dataset(name, *args, **kwargs):
  """Creates a Tensorflow datasets (TFDS) dataset.

  Looks the name up in the DATASETS registry and forwards all remaining
  arguments to the matching dataset class.

  Args:
    name: The TFDS name of the dataset.
    *args: Dataset arguments.
    **kwargs: Dataset keyword arguments.

  Returns:
    Dataset: An abstracted dataset object.

  Raises:
    ValueError: If a dataset with the given name does not exist.
  """
  dataset_cls = DATASETS.get(name)
  if dataset_cls is None:
    raise ValueError(f'No such dataset: {name}')
  return dataset_cls(*args, **kwargs)
================================================
FILE: rigl/experimental/jax/datasets/dataset_factory_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.dataset_common."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from rigl.experimental.jax.datasets import dataset_base
from rigl.experimental.jax.datasets import dataset_factory
class DatasetCommonTest(parameterized.TestCase):
  """Tests the dataset factory against every registered dataset."""

  def setUp(self):
    super().setUp()
    self._batch_size = 32
    self._batch_size_test = 10
    self._shuffle_buffer_size = 128

  def _create_dataset(self, dataset_name):
    """Instantiates a dataset via the factory with the test batch sizes."""
    return dataset_factory.create_dataset(
        dataset_name,
        self._batch_size,
        self._batch_size_test,
        shuffle_buffer_size=self._shuffle_buffer_size)

  def _check_first_batch(self, dataset_name, sample, batch_size):
    """Asserts a batch has well-formed image/label arrays of `batch_size`."""
    with self.subTest(name=f'{dataset_name}_sample'):
      self.assertNotEmpty(sample)
    with self.subTest(name=f'{dataset_name}_label_type'):
      self.assertIsInstance(sample['label'], np.ndarray)
    with self.subTest(name=f'{dataset_name}_label_batch_size'):
      self.assertLen(sample['label'], batch_size)
    with self.subTest(name=f'{dataset_name}_image_type'):
      self.assertIsInstance(sample['image'], np.ndarray)
    with self.subTest(name=f'{dataset_name}_image_shape'):
      self.assertLen(sample['image'].shape, 4)
    with self.subTest(name=f'{dataset_name}_image_batch_size'):
      self.assertEqual(sample['image'].shape[0], batch_size)
    with self.subTest(name=f'{dataset_name}_non_zero_image_dimensions'):
      self.assertGreater(sample['image'].shape[1], 1)

  def test_dataset_supported(self):
    """Tests supported datasets."""
    for dataset_name in dataset_factory.DATASETS:
      self.assertIsInstance(
          self._create_dataset(dataset_name), dataset_base.Dataset)

  @parameterized.parameters(*dataset_factory.DATASETS.keys())
  def test_dataset_train_iterators(self, dataset_name):
    """Tests dataset's train iterator."""
    dataset = self._create_dataset(dataset_name)
    self._check_first_batch(dataset_name, next(dataset.get_train()),
                            self._batch_size)

  @parameterized.parameters(*dataset_factory.DATASETS.keys())
  def test_dataset_test_iterators(self, dataset_name):
    """Tests dataset's test iterator."""
    dataset = self._create_dataset(dataset_name)
    self._check_first_batch(dataset_name, next(dataset.get_test()),
                            self._batch_size_test)

  def test_dataset_unsupported(self):
    """Tests unsupported datasets."""
    with self.assertRaisesRegex(ValueError, 'No such dataset: unsupported'):
      self._create_dataset('unsupported')
# Allow running this test file directly.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/datasets/mnist.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST Dataset.
Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
In this case, the MNIST dataset.
"""
from typing import MutableMapping
from rigl.experimental.jax.datasets import dataset_base
import tensorflow.compat.v2 as tf
class MNISTDataset(dataset_base.ImageDataset):
  """MNIST dataset.

  Attributes:
    NAME: The Tensorflow Dataset's dataset name.
  """

  # TFDS registry name; forwarded to the base class, which loads the data.
  NAME: str = 'mnist'

  def __init__(self,
               batch_size: int,
               batch_size_test: int,
               shuffle_buffer_size: int = 1024,
               seed: int = 42):
    """Creates an MNIST dataset.

    Args:
      batch_size: The batch size to use for the training dataset.
      batch_size_test: The batch size to use for the test dataset.
      shuffle_buffer_size: The buffer size to use for dataset shuffling.
      seed: The random seed used to shuffle.
    """
    super().__init__(MNISTDataset.NAME, batch_size, batch_size_test,
                     shuffle_buffer_size, seed)

  def preprocess(
      self, data: MutableMapping[str, tf.Tensor]
  ) -> MutableMapping[str, tf.Tensor]:
    """Normalizes MNIST images: `uint8` -> `float32`.

    Args:
      data: Data sample containing an 'image' tensor.

    Returns:
      Data after being augmented/normalized/transformed.
    """
    data = super().preprocess(data)
    # Scale pixel values to [0, 1], then center them into [-0.5, 0.5].
    data['image'] = (tf.cast(data['image'], tf.float32) / 255.) - 0.5
    return data
================================================
FILE: rigl/experimental/jax/datasets/mnist_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.mnist."""
from absl.testing import absltest
import numpy as np
from rigl.experimental.jax.datasets import mnist
class MNISTDatasetTest(absltest.TestCase):
  """Test cases for MNIST Dataset."""

  def setUp(self):
    """Common setup routines/variables for test cases."""
    super().setUp()
    self._batch_size = 16
    self._batch_size_test = 10
    self._shuffle_buffer_size = 8
    self._dataset = mnist.MNISTDataset(
        self._batch_size,
        batch_size_test=self._batch_size_test,
        shuffle_buffer_size=self._shuffle_buffer_size)

  def test_create_dataset(self):
    """Tests creation of dataset."""
    self.assertIsInstance(self._dataset, mnist.MNISTDataset)

  def test_train_image_dims_content(self):
    """Tests dimensions and contents of train data."""
    iterator = self._dataset.get_train()
    sample = next(iterator)
    image, label = sample['image'], sample['label']
    with self.subTest(name='data_shape'):
      self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1))
    with self.subTest(name='data_values'):
      self.assertTrue((image >= -1.).all() and (image <= 1.).all())
    with self.subTest(name='data_type'):
      self.assertTrue(np.issubdtype(image.dtype, float))
    with self.subTest(name='label_shape'):
      self.assertLen(label, self._batch_size)
    with self.subTest(name='label_type'):
      self.assertTrue(np.issubdtype(label.dtype, int))
    with self.subTest(name='label_values'):
      # Labels are class indices, so they must lie in [0, num_classes).
      self.assertTrue((label >= 0).all() and
                      (label < self._dataset.num_classes).all())

  def test_test_image_dims_content(self):
    """Tests dimensions and contents of test data."""
    iterator = self._dataset.get_test()
    sample = next(iterator)
    image, label = sample['image'], sample['label']
    with self.subTest(name='data_shape'):
      self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1))
    with self.subTest(name='data_type'):
      self.assertTrue(np.issubdtype(image.dtype, float))
    # TODO: Find a better approach to testing with JAX arrays.
    with self.subTest(name='data_values'):
      self.assertTrue((image >= -1.).all() and (image <= 1.).all())
    with self.subTest(name='label_shape'):
      self.assertLen(label, self._batch_size_test)
    with self.subTest(name='label_type'):
      self.assertTrue(np.issubdtype(label.dtype, int))
    with self.subTest(name='label_values'):
      # Labels are class indices, so they must lie in [0, num_classes).
      self.assertTrue((label >= 0).all() and
                      (label < self._dataset.num_classes).all())

  def test_train_data_length(self):
    """Tests length of training dataset."""
    total_count = 0
    for batch in self._dataset.get_train():
      total_count += len(batch['label'])
    self.assertEqual(total_count, self._dataset.get_train_len())

  def test_test_data_length(self):
    """Tests length of test dataset."""
    total_count = 0
    for batch in self._dataset.get_test():
      total_count += len(batch['label'])
    self.assertEqual(total_count, self._dataset.get_test_len())
# Allow running this test file directly.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/fixed_param.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Symmetry: Train models with fixed param, but diff. depth and width."""
import ast
import functools
import operator
from os import path
from typing import List, Sequence
import uuid
from absl import app
from absl import flags
from absl import logging
import flax
from flax.metrics import tensorboard
from flax.training import lr_schedule
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import mnist_fc
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.pruning import symmetry
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils
experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
logging.info('Saving experimental results to %s', experiment_dir)
host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
jax.device_count(), host_count, local_device_count)
if jax.host_id() == 0:
summary_writer = tensorboard.SummaryWriter(experiment_dir)
dataset = dataset_factory.create_dataset(
FLAGS.dataset,
FLAGS.batch_size,
FLAGS.batch_size_test,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)
logging.info('Training %s on the %s dataset...', MODEL, FLAGS.dataset)
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
input_len = functools.reduce(operator.mul, dataset.shape)
features = mnist_fc.feature_dim_for_param(
input_len,
FLAGS.param_count,
FLAGS.depth)
logging.info('Model Configuration: %s', str(features))
base_model, _ = model_factory.create_model(
MODEL,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
features=features)
model_param_count = utils.count_param(base_model, ('kernel',))
logging.info(
'Model Config: param.: %d, depth: %d. max width: %d, min width: %d',
model_param_count, len(features), max(features), min(features))
logging.info('Generating random mask based on model')
# Re-initialize the RNG to maintain same training pattern (as in prune code).
mask_rng = jax.random.PRNGKey(FLAGS.random_seed)
mask = masked.shuffled_mask(
base_model,
rng=mask_rng,
sparsity=FLAGS.mask_sparsity)
if jax.host_id() == 0:
mask_stats = symmetry.get_mask_stats(mask)
logging.info('Mask stats: %s', str(mask_stats))
for label, value in mask_stats.items():
try:
summary_writer.scalar(f'mask/{label}', value, 0)
# This is needed because permutations (long int) can't be cast to float32.
except (OverflowError, ValueError):
summary_writer.text(f'mask/{label}', str(value), 0)
logging.error('Could not write mask/%s to tensorflow summary as float32'
', writing as string instead.', label)
if FLAGS.dump_json:
mask_stats['permutations'] = str(mask_stats['permutations'])
utils.dump_dict_json(
mask_stats, path.join(experiment_dir, 'mask_stats.json'))
if FLAGS.dump_json:
mask_stats['permutations'] = str(mask_stats['permutations'])
utils.dump_dict_json(mask_stats,
path.join(experiment_dir, 'mask_stats.json'))
model_stats = {
'depth': len(features),
'max_width': max(features),
'min_width': min(features),
}
model_stats.update(
{'feature_{}'.format(i): value for i, value in enumerate(features)})
if FLAGS.dump_json:
utils.dump_dict_json(model_stats,
path.join(experiment_dir, 'model_stats.json'))
model, initial_state = model_factory.create_model(
'MNIST_FC',
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
features=features, masks=mask)
if FLAGS.opt == 'Adam':
optimizer = flax.optim.Adam(
learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
elif FLAGS.opt == 'Momentum':
optimizer = flax.optim.Momentum(
learning_rate=FLAGS.lr,
beta=FLAGS.momentum,
weight_decay=FLAGS.weight_decay,
nesterov=False)
else:
raise ValueError('Unknown Optimizer: {}'.format(FLAGS.opt))
steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
if FLAGS.lr_schedule == 'constant':
lr_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.lr, steps_per_epoch)
elif FLAGS.lr_schedule == 'stepped':
lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, lr_schedule_steps)
elif FLAGS.lr_schedule == 'cosine':
lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, FLAGS.epochs)
else:
raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))
if jax.host_id() == 0:
trainer = training.Trainer(
optimizer,
model,
initial_state,
dataset,
rng,
summary_writer=summary_writer,
)
else:
trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)
_, best_metrics = trainer.train(
FLAGS.epochs,
lr_fn=lr_fn,
update_iter=FLAGS.update_iterations,
update_epoch=FLAGS.update_epoch,
)
logging.info('Best metrics: %s', str(best_metrics))
if jax.host_id() == 0:
if FLAGS.dump_json:
utils.dump_dict_json(best_metrics,
path.join(experiment_dir, 'best_metrics.json'))
for label, value in best_metrics.items():
summary_writer.scalar('best/{}'.format(label), value,
FLAGS.epochs * steps_per_epoch)
summary_writer.close()
def main(argv: List[str]):
  """CLI entry point: rejects extra positional arguments, then trains."""
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')
  run_training()
# Script entry point: absl parses command-line flags before invoking main().
if __name__ == '__main__':
  app.run(main)
================================================
FILE: rigl/experimental/jax/fixed_param_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.fixed_param."""
import glob
from os import path
import tempfile
from absl.testing import absltest
from absl.testing import flagsaver
from rigl.experimental.jax import fixed_param
class FixedParamTest(absltest.TestCase):
  """End-to-end smoke test for the fixed_param training driver."""

  def test_run(self):
    """Tests if the driver for fixed-parameter training runs correctly."""
    experiment_dir = tempfile.mkdtemp()
    eval_flags = dict(
        epochs=1,
        experiment_dir=experiment_dir,
    )
    with flagsaver.flagsaver(**eval_flags):
      fixed_param.main([])
    with self.subTest(name='tf_summary_file_exists'):
      # The driver writes summaries under a per-work-unit subdirectory.
      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
      files = glob.glob(outfile)
      # Separate assertions give clearer failure messages than one combined
      # boolean (assertLen reports the actual file list on mismatch).
      self.assertLen(files, 1)
      self.assertTrue(path.exists(files[0]))
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/models/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/models/cifar10_cnn.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CIFAR10 CNN.
A small CNN for the CIFAR10 dataset, consists of a number of convolutional
layers (determined by length of filters parameter), followed by a
fully-connected layer.
"""
from typing import Callable, Mapping, Optional, Sequence
from absl import logging
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
class CIFAR10CNN(flax.deprecated.nn.Module):
  """Small CIFAR10 CNN."""

  def apply(self,
            inputs,
            num_classes,
            filter_shape = (3, 3),
            filters = (32, 32, 64, 64, 128, 128),
            init_fn=flax.deprecated.nn.initializers.kaiming_normal,
            train=True,
            activation_fn = flax.deprecated.nn.relu,
            masks = None,
            masked_layer_indices = None):
    """Applies the CNN to the inputs, returning per-class log-probabilities.

    Args:
      inputs: Input data with dimensions (batch, spatial_dims..., features).
      num_classes: Number of classes in the dataset.
      filter_shape: Shape of the convolutional filters.
      filters: Number of filters in each convolutional layer, and number of
        conv layers (given by length of sequence).
      init_fn: Initialization function used for convolutional layers.
      train: If model is being evaluated in training mode or not.
      activation_fn: Activation function to be used for convolutional layers.
      masks: Masks of the layers in this model, in the same form as
        module params, or None.
      masked_layer_indices: The layer indices of layers in model to be masked.

    Returns:
      A tensor of shape (batch, num_classes), containing the log-softmax
      output.

    Raises:
      ValueError: If the number of pooling layers is too many for the given
        input size, or if the provided mask is not of the correct depth for
        the model.
    """
    # Note: First dim is batch, last dim is channels, other dims are "spatial".
    # Pooling happens after every second conv layer (see the loop below), so
    # there are len(filters)//2 pools, each halving the spatial size.
    # NOTE(review): shape[1:-2] omits the last spatial dim (for NHWC input
    # only height is checked) -- confirm whether shape[1:-1] was intended.
    if not all([(dim >= 2**(len(filters)//2)) for dim in inputs.shape[1:-2]]):
      raise ValueError(
          'Input spatial size, {}, does not allow {} pooling layers.'.format(
              str(inputs.shape[1:-2]), len(filters))
      )

    # Total layer count: the conv layers plus the final conv-as-dense layer.
    depth = 1 + len(filters)
    # Presumably expands/validates the given masks into one entry per layer,
    # keyed 'MaskedModule_{i}' -- see masked.generate_model_masks.
    masks = masked.generate_model_masks(depth, masks,
                                        masked_layer_indices)

    # BatchNorm uses running averages (inference mode) when train=False.
    batch_norm = flax.deprecated.nn.BatchNorm.partial(
        use_running_average=not train, momentum=0.99, epsilon=1e-5)

    for i, filter_num in enumerate(filters):
      if f'MaskedModule_{i}' in masks:
        logging.info('Layer %d is masked in model', i)
        mask = masks[f'MaskedModule_{i}']
        # Masked conv: the mask gates the weights and drives sparse init.
        inputs = masked.masked(flax.deprecated.nn.Conv, mask)(
            inputs,
            features=filter_num,
            kernel_size=filter_shape,
            kernel_init=init.sparse_init(
                init_fn(), mask['kernel'] if mask is not None else None))
      else:
        inputs = flax.deprecated.nn.Conv(
            inputs,
            features=filter_num,
            kernel_size=filter_shape,
            kernel_init=init_fn())
      inputs = batch_norm(inputs, name='bn_conv_{}'.format(i))
      inputs = activation_fn(inputs)
      # 2x2 max-pool after every second conv layer.
      if i % 2 == 1:
        inputs = flax.deprecated.nn.max_pool(
            inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID')

    # Global average pooling if we have spatial dimensions left.
    inputs = flax.deprecated.nn.avg_pool(
        inputs, window_shape=(inputs.shape[1:-1]), padding='VALID')
    inputs = inputs.reshape((inputs.shape[0], -1))

    # This is effectively a Dense layer, but we cast it as a convolution layer
    # to allow us to easily propagate masks, avoiding b/156135283.
    inputs = flax.deprecated.nn.Conv(
        inputs,
        features=num_classes,
        kernel_size=inputs.shape[1:-1],
        kernel_init=flax.deprecated.nn.initializers.xavier_normal())
    inputs = batch_norm(inputs, name='bn_dense_1')
    # Drop singleton dims left by the conv-as-dense layer.
    inputs = jnp.squeeze(inputs)
    return flax.deprecated.nn.log_softmax(inputs)
================================================
FILE: rigl/experimental/jax/models/cifar10_cnn_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.cifar10_cnn."""
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import cifar10_cnn
class CIFAR10CNNTest(absltest.TestCase):
  """Tests the CIFAR10CNN model."""

  def setUp(self):
    super().setUp()
    self._rng = jax.random.PRNGKey(42)
    self._num_classes = 10
    self._batch_size = 2
    self._input_shape = ((self._batch_size, 32, 32, 3), jnp.float32)
    self._input = jnp.zeros(*self._input_shape)

  def test_output_shapes(self):
    """Checks the model emits one row of class logits per example."""
    with flax.deprecated.nn.stateful() as initial_state:
      _, params = cifar10_cnn.CIFAR10CNN.init_by_shape(
          self._rng, (self._input_shape,), num_classes=self._num_classes)
    model = flax.deprecated.nn.Model(cifar10_cnn.CIFAR10CNN, params)
    with flax.deprecated.nn.stateful(initial_state, mutable=False):
      outputs = model(self._input, num_classes=self._num_classes, train=False)
    self.assertTupleEqual(outputs.shape,
                          (self._batch_size, self._num_classes))

  def test_invalid_spatial_dimensions(self):
    """Checks too many conv/pool layers for the input size is rejected."""
    with self.assertRaisesRegex(ValueError, 'Input spatial size, '):
      cifar10_cnn.CIFAR10CNN.init_by_shape(
          self._rng, (self._input_shape,),
          num_classes=self._num_classes,
          filters=20 * (32,))

  def test_invalid_masks_depth(self):
    """Checks a mask with the wrong depth for the model is rejected."""
    bad_masks = {
        'MaskedModule_0': {
            'kernel': jnp.zeros((self._batch_size, 3, 3, 32)),
        },
    }
    with self.assertRaisesRegex(ValueError, 'Mask is invalid for model.'):
      cifar10_cnn.CIFAR10CNN.init_by_shape(
          self._rng, (self._input_shape,),
          num_classes=self._num_classes,
          masks=bad_masks)
# Allow running this test file directly.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/models/mnist_cnn.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST CNN.
A small CNN for the MNIST dataset, consists of a number of convolutional layers
(determined by length of filters parameter), followed by a fully-connected
layer.
"""
from typing import Callable, Mapping, Optional, Sequence
from absl import logging
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
class MNISTCNN(flax.deprecated.nn.Module):
  """Small MNIST CNN."""

  def apply(self,
            inputs,
            num_classes,
            filter_shape = (5, 5),
            filters = (16, 32),
            dense_size = 64,
            train=True,
            init_fn = flax.deprecated.nn.initializers.kaiming_normal,
            activation_fn = flax.deprecated.nn.relu,
            masks = None,
            masked_layer_indices = None):
    """Applies the CNN to the inputs, returning per-class log-probabilities.

    Args:
      inputs: Input data with dimensions (batch, spatial_dims..., features).
      num_classes: Number of classes in the dataset.
      filter_shape: Shape of the convolutional filters.
      filters: Number of filters in each convolutional layer, and number of
        conv layers (given by length of sequence).
      dense_size: Number of units in the penultimate dense (conv-as-dense)
        layer.
      train: If model is being evaluated in training mode or not.
      init_fn: Initialization function used for convolutional layers.
      activation_fn: Activation function to be used for convolutional layers.
      masks: Masks of the layers in this model, in the same form as
        module params, or None.
      masked_layer_indices: The layer indices of layers in model to be masked.

    Returns:
      A tensor of shape (batch, num_classes), containing the log-softmax
      output.

    Raises:
      ValueError: If the number of pooling layers is too many for the given
        input size.
    """
    # Note: First dim is batch, last dim is channels, other dims are "spatial".
    # NOTE(review): shape[1:-2] omits the last spatial dim (for NHWC input
    # only height is checked) -- confirm whether shape[1:-1] was intended.
    if not all([(dim >= 2**len(filters)) for dim in inputs.shape[1:-2]]):
      raise ValueError(
          'Input spatial size, {}, does not allow {} pooling layers.'.format(
              str(inputs.shape[1:-2]), len(filters))
      )

    # Total layer count: convs, the conv-as-dense layer, and the final Dense.
    depth = 2 + len(filters)
    # Presumably expands/validates the given masks into one entry per layer,
    # keyed 'MaskedModule_{i}' -- see masked.generate_model_masks.
    masks = masked.generate_model_masks(depth, masks,
                                        masked_layer_indices)

    # BatchNorm uses running averages (inference mode) when train=False.
    batch_norm = flax.deprecated.nn.BatchNorm.partial(
        use_running_average=not train, momentum=0.99, epsilon=1e-5)

    for i, filter_num in enumerate(filters):
      if f'MaskedModule_{i}' in masks:
        logging.info('Layer %d is masked in model', i)
        mask = masks[f'MaskedModule_{i}']
        # Masked conv: the mask gates the weights and drives sparse init.
        inputs = masked.masked(flax.deprecated.nn.Conv, mask)(
            inputs,
            features=filter_num,
            kernel_size=filter_shape,
            kernel_init=init.sparse_init(
                init_fn(), mask['kernel'] if mask is not None else None))
      else:
        inputs = flax.deprecated.nn.Conv(
            inputs,
            features=filter_num,
            kernel_size=filter_shape,
            kernel_init=init_fn())
      inputs = batch_norm(inputs, name='bn_conv_{}'.format(i))
      inputs = activation_fn(inputs)
      # 2x2 max-pool between conv layers (not after the last one).
      if i < len(filters) - 1:
        inputs = flax.deprecated.nn.max_pool(
            inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID')

    # Global average pool at end of convolutional layers.
    inputs = flax.deprecated.nn.avg_pool(
        inputs, window_shape=inputs.shape[1:-1], padding='VALID')

    # This is effectively a Dense layer, but we cast it as a convolution layer
    # to allow us to easily propagate masks, avoiding b/156135283.
    # Its mask key is depth - 2; the final Dense below is never masked.
    if f'MaskedModule_{depth - 2}' in masks:
      mask_dense_1 = masks[f'MaskedModule_{depth - 2}']
      inputs = masked.masked(flax.deprecated.nn.Conv, mask_dense_1)(
          inputs,
          features=dense_size,
          kernel_size=inputs.shape[1:-1],
          kernel_init=init.sparse_init(
              init_fn(),
              mask_dense_1['kernel'] if mask_dense_1 is not None else None))
    else:
      inputs = flax.deprecated.nn.Conv(
          inputs,
          features=dense_size,
          kernel_size=inputs.shape[1:-1],
          kernel_init=init_fn())
    inputs = batch_norm(inputs, name='bn_dense_1')
    inputs = activation_fn(inputs)

    # Final classifier layer.
    inputs = flax.deprecated.nn.Dense(
        inputs,
        features=num_classes,
        kernel_init=flax.deprecated.nn.initializers.xavier_normal())
    inputs = batch_norm(inputs, name='bn_dense_2')
    # Drop singleton dims left by the conv-as-dense layer.
    inputs = jnp.squeeze(inputs)
    return flax.deprecated.nn.log_softmax(inputs)
================================================
FILE: rigl/experimental/jax/models/mnist_cnn_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.mnist_cnn."""
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import mnist_cnn
class MNISTCNNTest(absltest.TestCase):
  """Tests the MNISTCNN model."""

  def setUp(self):
    super().setUp()
    self._rng = jax.random.PRNGKey(42)
    self._num_classes = 10
    self._batch_size = 2
    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
    self._input = jnp.zeros(*self._input_shape)

  def test_output_shapes(self):
    """Checks the model emits one row of class logits per example."""
    with flax.deprecated.nn.stateful() as initial_state:
      _, params = mnist_cnn.MNISTCNN.init_by_shape(
          self._rng, (self._input_shape,), num_classes=self._num_classes)
    model = flax.deprecated.nn.Model(mnist_cnn.MNISTCNN, params)
    with flax.deprecated.nn.stateful(initial_state, mutable=False):
      outputs = model(self._input, num_classes=self._num_classes, train=False)
    self.assertTupleEqual(outputs.shape,
                          (self._batch_size, self._num_classes))

  def test_invalid_depth(self):
    """Checks too many conv/pool layers for the input size is rejected."""
    with self.assertRaisesRegex(ValueError, 'Input spatial size, '):
      mnist_cnn.MNISTCNN.init_by_shape(
          self._rng, (self._input_shape,),
          num_classes=self._num_classes,
          filters=10 * (32,))
# Allow running this test file directly.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/models/mnist_fc.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST Fully-Connected Neural Network.
A fully-connected model for the MNIST dataset, consists of a number of
dense layers (determined by length of features parameter).
"""
import math
from typing import Callable, Mapping, Optional, Sequence, Tuple
from absl import logging
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
def feature_dim_for_param(input_len,
                          param_count,
                          depth,
                          depth_mult = 2.):
  """Calculates feature dimensions for a fixed parameter count and depth.

  Sizes a fully-connected network where layer i has l * a**i neurons: a is the
  per-layer width multiplier and l the first layer's width. With x the input
  size and d the depth, the total parameter count n satisfies

  $$n = x*l + l^2 * sum_{i=2}^d a^{2i-3}$$,

  which is solved for l (positive root of the quadratic).

  Args:
    input_len: Input size.
    param_count: Number of parameters model should maintain.
    depth: Depth of the model.
    depth_mult: The layer width multiplier w.r.t. depth.

  Returns:
    The feature specification for a fully-connected model, as a tuple of layer
    widths.

  Raises:
    ValueError: If the given number of parameters is too low for the given
      depth and input size.
  """
  if depth == 1:
    # Single layer: n = x * l, so l = n / x.
    first_width = param_count / input_len
  else:
    # l = ((x^2 + 4cn)^{1/2} - x)/(2c) with c = sum_{i=2}^d a^{2i-3}.
    coeff = sum(depth_mult**(2 * i - 3) for i in range(2, depth + 1))
    root = math.sqrt(input_len**2 + 4 * coeff * param_count)
    first_width = (root - input_len) / (2 * coeff)
  if first_width < 1:
    raise ValueError(
        'Expected parameter count too low for given depth and input size.')
  base_width = int(first_width)
  return tuple(int(base_width * depth_mult**layer) for layer in range(depth))
class MNISTFC(flax.deprecated.nn.Module):
  """MNIST Fully-Connected Neural Network."""

  def apply(self,
            inputs,
            num_classes,
            features = (32, 32),
            train=True,
            init_fn = flax.deprecated.nn.initializers.kaiming_normal,
            activation_fn = flax.deprecated.nn.relu,
            masks = None,
            masked_layer_indices = None,
            dropout_rate = 0.):
    """Applies the fully-connected network to a batch of inputs.

    Args:
      inputs: Input data of shape (batch, features); if features has more than
        one dimension it is flattened.
      num_classes: Number of classes in the dataset.
      features: Width of each hidden layer; its length sets the number of
        hidden layers (a final softmax layer is added on top).
      train: If model is being evaluated in training mode or not.
      init_fn: Initialization function used for dense layers.
      activation_fn: Activation function to be used for dense layers.
      masks: Masks of the layers in this model, in the same form as module
        params, or None.
      masked_layer_indices: The layer indices of layers in model to be masked.
      dropout_rate: Dropout rate, if 0 then dropout is not used (default).

    Returns:
      A tensor of shape (batch, num_classes), containing the logit output.
    """
    batch_norm = flax.deprecated.nn.BatchNorm.partial(
        use_running_average=not train, momentum=0.99, epsilon=1e-5)
    # Hidden layers plus the final classification layer.
    num_layers = 1 + len(features)
    masks = masked.generate_model_masks(num_layers, masks,
                                        masked_layer_indices)
    # Flatten image-shaped inputs to (batch, features).
    hidden = inputs.reshape(inputs.shape[0], -1)
    for layer_idx, layer_width in enumerate(features):
      mask_key = f'MaskedModule_{layer_idx}'
      if mask_key in masks:
        logging.info('Layer %d is masked in model', layer_idx)
        layer_mask = masks[mask_key]
        kernel_mask = layer_mask['kernel'] if layer_mask is not None else None
        hidden = masked.masked(flax.deprecated.nn.Dense, layer_mask)(
            hidden,
            features=layer_width,
            kernel_init=init.sparse_init(init_fn(), kernel_mask))
      else:
        hidden = flax.deprecated.nn.Dense(
            hidden, features=layer_width, kernel_init=init_fn())
      hidden = batch_norm(hidden, name=f'bn_conv_{layer_idx}')
      hidden = activation_fn(hidden)
      if dropout_rate > 0.0:
        hidden = flax.deprecated.nn.dropout(
            hidden, dropout_rate, deterministic=not train)
    # Final dense layer to class logits, then log-softmax.
    hidden = flax.deprecated.nn.Dense(
        hidden,
        features=num_classes,
        kernel_init=flax.deprecated.nn.initializers.xavier_normal())
    return flax.deprecated.nn.log_softmax(hidden)
================================================
FILE: rigl/experimental/jax/models/mnist_fc_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.mnist_fc."""
from typing import Sequence
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import mnist_fc
from rigl.experimental.jax.utils import utils
PARAM_COUNT_PARAM: Sequence[str] = ('kernel',)
class MNISTFCTest(parameterized.TestCase):
  """Tests the MNISTFC model."""

  def setUp(self):
    super().setUp()
    self._rng = jax.random.PRNGKey(42)
    self._num_classes = 10
    self._batch_size = 2
    # Flattened 28x28 grayscale MNIST image.
    self._input_len = 28*28*1
    self._input_shape = ((self._batch_size, self._input_len), jnp.float32)
    self._input = jnp.zeros((self._batch_size, self._input_len), jnp.float32)
    # Target parameter budget for the feature_dim_for_param tests.
    self._param_count = 1e7

  def test_output_shapes(self):
    """Tests that the model emits (batch, num_classes) logits."""
    # Initialize model parameters (and batch-norm state) from shapes only.
    with flax.deprecated.nn.stateful() as initial_state:
      _, initial_params = mnist_fc.MNISTFC.init_by_shape(
          self._rng, (self._input_shape,), num_classes=self._num_classes)
      model = flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params)
    # Forward pass in eval mode with frozen batch-norm statistics.
    with flax.deprecated.nn.stateful(initial_state, mutable=False):
      logits = model(self._input, num_classes=self._num_classes, train=False)
    self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes))

  def test_invalid_masks_depth(self):
    """Tests that init fails for a mask whose shape does not match the model."""
    # Kernel shape deliberately mismatches every MNISTFC layer.
    invalid_masks = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros((self._batch_size, 5 * 5 * 16))
        }
    }
    with self.assertRaisesRegex(
        ValueError, 'Mask is invalid for model.'):
      mnist_fc.MNISTFC.init_by_shape(
          self._rng,
          (self._input_shape,),
          num_classes=self._num_classes,
          masks=invalid_masks)

  def _create_model(self, features):
    """Creates a FLAX MNISTFC model with the given hidden-layer widths."""
    _, initial_params = mnist_fc.MNISTFC.init_by_shape(
        self._rng,
        (self._input_shape,),
        num_classes=self._num_classes,
        features=features)
    return flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params)

  @parameterized.parameters(*range(1, 6))
  def test_feature_dim_for_param_depth(self, depth):
    """Tests feature_dim_for_param at depths 1 through 5."""
    features = mnist_fc.feature_dim_for_param(self._input_len,
                                              self._param_count, depth)
    model = self._create_model(features)
    total_size = utils.count_param(model, PARAM_COUNT_PARAM)
    with self.subTest(name='FeatureDimLen'):
      self.assertLen(features, depth)
    # Realized parameter count should be within 5% of the requested budget.
    with self.subTest(name='FeatureDimParamCount'):
      self.assertBetween(total_size, self._param_count * 0.95,
                         self._param_count * 1.05)
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/models/model_factory.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for neural network models.
Attributes:
MODELS: A list of the models that can be created.
"""
from typing import Any, Callable, Mapping, Sequence, Tuple, Type
import flax
import jax.numpy as jnp
from rigl.experimental.jax.models import cifar10_cnn
from rigl.experimental.jax.models import mnist_cnn
from rigl.experimental.jax.models import mnist_fc
MODELS: Mapping[str, Type[flax.deprecated.nn.Model]] = {
'MNIST_CNN': mnist_cnn.MNISTCNN,
'MNIST_FC': mnist_fc.MNISTFC,
'CIFAR10_CNN': cifar10_cnn.CIFAR10CNN,
}
def create_model(
    name, rng,
    input_specs, **kwargs
):
  """Creates a Model by name.

  Args:
    name: The name of the model to instantiate (a key of MODELS).
    rng: The random number generator to use for init.
    input_specs: An iterable of (shape, dtype) pairs specifying the inputs.
    **kwargs: Model-specific keyword arguments.

  Returns:
    A tuple of FLAX model (flax.deprecated.nn.Model), and initial model state.

  Raises:
    ValueError: If a model with the given name does not exist.
  """
  if name not in MODELS:
    raise ValueError('No such model: {}'.format(name))
  with flax.deprecated.nn.stateful() as init_state:
    with flax.deprecated.nn.stochastic(rng):
      model_def = MODELS[name].partial(**kwargs)
      _, initial_params = model_def.init_by_shape(rng, input_specs)
  return flax.deprecated.nn.Model(model_def, initial_params), init_state
def update_model(model,
                 **kwargs):
  """Rebinds a model's module arguments while keeping its parameters.

  Args:
    model: The model to update.
    **kwargs: Model-specific keyword arguments to apply to the module.

  Returns:
    A FLAX model pairing the re-parameterized module with the original params.
  """
  updated_module = model.module.partial(**kwargs)
  return flax.deprecated.nn.Model(updated_module, model.params)
================================================
FILE: rigl/experimental/jax/models/model_factory_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.model_factory."""
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import model_factory
class ModelCommonTest(parameterized.TestCase):
  """Tests the model factory."""

  def setUp(self):
    super().setUp()
    self._num_classes = 10
    self._input_shape = ((1, 28, 28, 1), jnp.float32)
    self._rng = jax.random.PRNGKey(42)

  def _create_model(self, model_name):
    """Builds the named model with the shared test configuration."""
    model_and_state = model_factory.create_model(
        model_name,
        self._rng, (self._input_shape,),
        num_classes=self._num_classes)
    return model_and_state

  @parameterized.parameters(*model_factory.MODELS.keys())
  def test_model_supported(self, model_name):
    """Every registered model can be constructed by the factory."""
    model, state = self._create_model(model_name)
    with self.subTest(name='test_model_supported_model_instance'):
      self.assertIsInstance(model, flax.deprecated.nn.Model)
    with self.subTest(name='test_model_supported_collection_instance'):
      self.assertIsInstance(state, flax.deprecated.nn.Collection)

  def test_model_unsupported(self):
    """An unknown model name raises a descriptive ValueError."""
    with self.assertRaisesRegex(ValueError, 'No such model: unsupported'):
      self._create_model('unsupported')
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/prune.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Symmetry: Iteratively Prune Model during Training.
Command for training and pruning an MNIST fully-connected model for 10 epochs
with a fixed pruning rate of 0.95:
prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10
--pruning_rate=0.95
Command for training and pruning an MNIST fully-connected model for 10
epochs, with pruning rates 0.3, 0.6 and 0.95 at epochs 2, 5, and 8 respectively
for all layers:
prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10
--pruning_schedule='[(2, 0.3), (5, 0.6), (8, 0.95)]'
Command for doing the same, but performing pruning only on the second layer:
prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10
--pruning_schedule="{'1': [(2, 0.3), (5, 0.6), (8, 0.95)]}"
"""
import ast
from collections import abc
import functools
from os import path
from typing import List
import uuid
from absl import app
from absl import flags
from absl import logging
import flax
from flax.metrics import tensorboard
from flax.training import lr_schedule
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils
experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
logging.info('Saving experimental results to %s', experiment_dir)
host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
jax.device_count(), host_count, local_device_count)
if jax.host_id() == 0:
summary_writer = tensorboard.SummaryWriter(experiment_dir)
dataset = dataset_factory.create_dataset(
FLAGS.dataset,
FLAGS.batch_size,
FLAGS.batch_size_test,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)
logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
base_model, _ = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes)
initial_model, initial_state = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
masked_layer_indices=FLAGS.masked_layer_indices)
if FLAGS.optimizer == 'Adam':
optimizer = flax.optim.Adam(
learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
elif FLAGS.optimizer == 'Momentum':
optimizer = flax.optim.Momentum(
learning_rate=FLAGS.lr,
beta=FLAGS.momentum,
weight_decay=FLAGS.weight_decay,
nesterov=False)
steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
if FLAGS.lr_schedule == LR_SCHEDULE_CONSTANT:
lr_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.lr, steps_per_epoch)
elif FLAGS.lr_schedule == LR_SCHEDULE_STEPPED:
lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, lr_schedule_steps)
elif FLAGS.lr_schedule == LR_SCHEDULE_COSINE:
lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, FLAGS.epochs)
else:
raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}')
# Reuses the FLAX learning rate schedule framework for pruning rate schedule.
pruning_fn_p = functools.partial(
lr_schedule.create_stepped_learning_rate_schedule, FLAGS.pruning_rate,
steps_per_epoch)
if FLAGS.pruning_schedule:
pruning_schedule = ast.literal_eval(FLAGS.pruning_schedule)
if isinstance(pruning_schedule, abc.Mapping):
pruning_rate_fn = {
f'MaskedModule_{layer_num}': pruning_fn_p(schedule)
for layer_num, schedule in pruning_schedule.items()
}
else:
pruning_rate_fn = pruning_fn_p(pruning_schedule)
else:
pruning_rate_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.pruning_rate, steps_per_epoch)
if jax.host_id() == 0:
trainer = training.Trainer(
optimizer,
initial_model,
initial_state,
dataset,
rng,
summary_writer=summary_writer,
)
else:
trainer = training.Trainer(
optimizer, initial_model, initial_state, dataset, rng)
_, best_metrics = trainer.train(
FLAGS.epochs,
lr_fn=lr_fn,
pruning_rate_fn=pruning_rate_fn,
update_iter=FLAGS.update_iterations,
update_epoch=FLAGS.update_epoch,
)
logging.info('Best metrics: %s', str(best_metrics))
if jax.host_id() == 0:
if FLAGS.dump_json:
utils.dump_dict_json(best_metrics,
path.join(experiment_dir, 'best_metrics.json'))
for label, value in best_metrics.items():
summary_writer.scalar(f'best/{label}', value,
FLAGS.epochs * steps_per_epoch)
summary_writer.close()
def main(argv: List[str]):
  """CLI entry point: validates arguments and launches training/pruning.

  Args:
    argv: Remaining command-line arguments after absl flag parsing.

  Raises:
    app.UsageError: If positional command-line arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  run_training()
if __name__ == '__main__':
app.run(main)
================================================
FILE: rigl/experimental/jax/prune_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.prune."""
import glob
from os import path
from absl.testing import absltest
from absl.testing import flagsaver
from rigl.experimental.jax import prune
class PruneTest(absltest.TestCase):
  """End-to-end tests for the prune training driver."""

  def _run_prune(self, **eval_flags):
    """Runs prune.main under the given flag overrides.

    Args:
      **eval_flags: Flag name/value overrides for the run.

    Returns:
      The temporary experiment directory the run wrote to.
    """
    experiment_dir = self.create_tempdir().full_path
    eval_flags['experiment_dir'] = experiment_dir
    with flagsaver.flagsaver(**eval_flags):
      prune.main([])
    return experiment_dir

  def _assert_single_event_file(self, experiment_dir):
    """Asserts exactly one TensorBoard event file exists under the dir."""
    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
    files = glob.glob(outfile)
    # assertLen reports the actual matches on failure, unlike a combined
    # boolean assertTrue(len(...) == 1 and ...).
    self.assertLen(files, 1)
    self.assertTrue(path.exists(files[0]))

  def test_prune_fixed_schedule(self):
    """Tests training/pruning driver with a fixed global sparsity."""
    experiment_dir = self._run_prune(epochs=1, pruning_rate=0.95)
    self._assert_single_event_file(experiment_dir)

  def test_prune_global_pruning_schedule(self):
    """Tests training/pruning driver with a global sparsity schedule."""
    experiment_dir = self._run_prune(
        epochs=10, pruning_schedule='[(5, 0.33), (7, 0.66), (9, 0.95)]')
    self._assert_single_event_file(experiment_dir)

  def test_prune_local_pruning_schedule(self):
    """Tests training/pruning driver with a single layer sparsity schedule."""
    experiment_dir = self._run_prune(
        epochs=10, pruning_schedule='{1:[(5, 0.33), (7, 0.66), (9, 0.95)]}')
    self._assert_single_event_file(experiment_dir)
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/pruning/init.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for initialization of masked models."""
import functools
from typing import Callable, Sequence, Optional
import flax
import jax
import jax.numpy as jnp
def sparse_init(
    base_init,
    mask,
    dtype=jnp.float32):
  """Weight initializer with correct fan in/fan out for a masked model.

  The weight initializer uses any dense initializer to correctly initialize a
  masked weight matrix by calling the given initialization method with the
  correct fan in/fan out for every neuron in the layer. If the mask is None,
  it reverts to the original initialization method.

  Args:
    base_init: The base (dense) initialization method to use.
    mask: The layer's mask, or None.
    dtype: The weight array jnp.dtype.

  Returns:
    An initialization method that is mask aware for the given layer and mask.
  """
  def init(rng, shape, dtype=dtype):
    # No mask: defer entirely to the dense initializer.
    if mask is None:
      return base_init(rng, shape, dtype)
    # Per-output-unit count of unmasked weights (each unit's fan_in).
    neuron_weight_count = jnp.sum(
        jnp.reshape(mask, (-1, mask.shape[-1])), axis=0)
    # Number of units with at least one unmasked weight (effective fan_out).
    # Note: this is already a scalar; no further reduction is needed.
    non_zero_neurons = jnp.sum(neuron_weight_count != 0)
    # Special case of completely ablated weight matrix/layer.
    if non_zero_neurons == 0:
      # NOTE(review): library code; consider logging instead of print.
      print('Empty weight mask!')
      return jnp.zeros(shape, dtype)
    # Neurons have different fan_in w/mask, build up initialization per-unit.
    init_cols = []
    rng, *split_rngs = jax.random.split(rng, mask.shape[-1] + 1)
    for i in range(mask.shape[-1]):
      # Special case of ablated neuron: all-zero column.
      if neuron_weight_count[i] == 0:
        init_cols.append(jnp.zeros(shape[:-1] + (1,), dtype))
        continue
      # Fake shape of weight matrix with correct fan_in, and fan_out.
      sparse_shape = (int(neuron_weight_count[i]), int(non_zero_neurons))
      # Use only the first column of init from initializer, since faked
      # fan_out. Renamed from `init` to avoid shadowing the closure itself.
      col_init = base_init(split_rngs[i], sparse_shape, dtype)[Ellipsis, 0]
      # Scatter the dense column back into the unmasked positions.
      expanded_init = jnp.zeros(
          mask[Ellipsis, i].shape,
          dtype).flatten().at[jnp.where(
              mask[Ellipsis, i].flatten() == 1)].set(col_init)
      expanded_init = jnp.reshape(expanded_init, mask[Ellipsis, i].shape)
      init_cols.append(expanded_init[Ellipsis, jnp.newaxis])
    return jnp.concatenate(init_cols, axis=-1)
  return init
# Convenience partials: sparse-aware variants of the standard FLAX normal
# initializers (Xavier/Glorot and Kaiming/He). Call with a mask (or None) to
# obtain a mask-aware initializer.
xavier_sparse_normal = glorot_sparse_normal = functools.partial(
    sparse_init, flax.deprecated.nn.initializers.xavier_normal())
kaiming_sparse_normal = he_sparse_normal = functools.partial(
    sparse_init, flax.deprecated.nn.initializers.kaiming_normal())
================================================
FILE: rigl/experimental/jax/pruning/init_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.pruning.init."""
from typing import Any, Mapping, Optional
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
class MaskedDense(flax.deprecated.nn.Module):
  """Single-layer Dense Masked Network."""

  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            mask = None):
    """Applies one masked dense layer to the flattened inputs."""
    flat_inputs = inputs.reshape(inputs.shape[0], -1)
    if mask:
      layer_mask = mask['MaskedModule_0']
    else:
      layer_mask = None
    return masked.MaskedModule(
        flat_inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Dense,
        mask=layer_mask,
        kernel_init=flax.deprecated.nn.initializers.kaiming_normal())
class MaskedDenseSparseInit(flax.deprecated.nn.Module):
  """Single-layer Dense Masked Network."""

  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            *args,
            mask = None,
            **kwargs):
    """Applies one masked dense layer using the sparse-aware initializer."""
    flat_inputs = inputs.reshape(inputs.shape[0], -1)
    if mask:
      layer_mask = mask['MaskedModule_0']
    else:
      layer_mask = None
    kernel_mask = layer_mask['kernel'] if layer_mask is not None else None
    return masked.MaskedModule(
        flat_inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Dense,
        mask=layer_mask,
        kernel_init=init.kaiming_sparse_normal(kernel_mask),
        **kwargs)
class MaskedCNN(flax.deprecated.nn.Module):
  """Single-layer CNN Masked Network."""

  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            mask = None):
    """Applies one masked 3x3 convolution to the inputs."""
    if mask:
      layer_mask = mask['MaskedModule_0']
    else:
      layer_mask = None
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Conv,
        kernel_size=(3, 3),
        mask=layer_mask,
        kernel_init=flax.deprecated.nn.initializers.kaiming_normal())
class MaskedCNNSparseInit(flax.deprecated.nn.Module):
  """Single-layer CNN Masked Network."""

  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            *args,
            mask = None,
            **kwargs):
    """Applies one masked 3x3 convolution using the sparse-aware initializer."""
    if mask:
      layer_mask = mask['MaskedModule_0']
    else:
      layer_mask = None
    kernel_mask = layer_mask['kernel'] if layer_mask is not None else None
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Conv,
        kernel_size=(3, 3),
        mask=layer_mask,
        kernel_init=init.kaiming_sparse_normal(kernel_mask),
        **kwargs)
class InitTest(absltest.TestCase):
  """Tests for the sparse (mask-aware) weight initializers."""

  def setUp(self):
    super().setUp()
    self._rng = jax.random.PRNGKey(42)
    self._batch_size = 2
    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
    self._input = jnp.ones(*self._input_shape)

  def test_init_kaiming_sparse_normal_output(self):
    """Tests the output shape/type of kaiming normal sparse initialization."""
    input_array = jnp.ones((64, 16), jnp.float32)
    mask = jax.random.bernoulli(self._rng, shape=(64, 16))
    base_init = flax.deprecated.nn.initializers.kaiming_normal()(
        self._rng, input_array.shape, input_array.dtype)
    sparse_init = init.kaiming_sparse_normal(mask)(self._rng, input_array.shape,
                                                   input_array.dtype)
    with self.subTest(name='test_sparse_init_output_shape'):
      self.assertSequenceEqual(sparse_init.shape, base_init.shape)
    with self.subTest(name='test_sparse_init_output_dtype'):
      self.assertEqual(sparse_init.dtype, base_init.dtype)
    with self.subTest(name='test_sparse_init_output_notallzero'):
      self.assertTrue((sparse_init != 0).any())

  def test_dense_no_mask(self):
    """Checks that in the special case of no mask, init is same as base_init."""
    _, initial_params = MaskedDense.init_by_shape(self._rng,
                                                  (self._input_shape,))
    self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)
    _, initial_params = MaskedDenseSparseInit.init_by_shape(
        jax.random.PRNGKey(42), (self._input_shape,), mask=None)
    self._masked_model_sparse_init = flax.deprecated.nn.Model(
        MaskedDenseSparseInit, initial_params)
    # Same seed and no mask: the sparse initializer must fall back to the
    # base initializer and reproduce identical kernels.
    self.assertTrue(
        jnp.isclose(
            self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
            ['kernel'], self._unmasked_model.params['MaskedModule_0']
            ['unmasked']['kernel']).all())

  def _assert_sparse_init_stats(self, unmasked_model, sparse_model, prefix):
    """Asserts sparse-init statistics roughly match the dense baseline."""
    baseline = unmasked_model.params['MaskedModule_0']['unmasked']['kernel']
    candidate = sparse_model.params['MaskedModule_0']['unmasked']['kernel']
    mean_init = jnp.mean(baseline)
    stddev_init = jnp.std(baseline)
    mean_sparse_init = jnp.mean(candidate)
    stddev_sparse_init = jnp.std(candidate)
    with self.subTest(name=f'{prefix}_mean'):
      self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,
                         mean_init + 2 * stddev_init)
    with self.subTest(name=f'{prefix}_stddev'):
      self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,
                         1.5 * stddev_init)

  def test_dense_sparse_init_kaiming(self):
    """Checks kaiming normal sparse initialization for dense layer."""
    _, initial_params = MaskedDense.init_by_shape(self._rng,
                                                  (self._input_shape,))
    self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)
    mask = masked.simple_mask(self._unmasked_model, jnp.ones,
                              masked.WEIGHT_PARAM_NAMES)
    _, initial_params = MaskedDenseSparseInit.init_by_shape(
        jax.random.PRNGKey(42), (self._input_shape,), mask=mask)
    self._masked_model_sparse_init = flax.deprecated.nn.Model(
        MaskedDenseSparseInit, initial_params)
    # Fixed: subTest labels previously said 'cnn' here (copy-paste error),
    # which mislabeled dense-layer failures in test reports.
    self._assert_sparse_init_stats(self._unmasked_model,
                                   self._masked_model_sparse_init,
                                   'test_dense_sparse_init')

  def test_cnn_sparse_init_kaiming(self):
    """Checks kaiming normal sparse initialization for convolutional layer."""
    _, initial_params = MaskedCNN.init_by_shape(self._rng, (self._input_shape,))
    self._unmasked_model = flax.deprecated.nn.Model(MaskedCNN, initial_params)
    mask = masked.simple_mask(self._unmasked_model, jnp.ones,
                              masked.WEIGHT_PARAM_NAMES)
    _, initial_params = MaskedCNNSparseInit.init_by_shape(
        jax.random.PRNGKey(42), (self._input_shape,), mask=mask)
    self._masked_model_sparse_init = flax.deprecated.nn.Model(
        MaskedCNNSparseInit, initial_params)
    self._assert_sparse_init_stats(self._unmasked_model,
                                   self._masked_model_sparse_init,
                                   'test_cnn_sparse_init')
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/mask_factory.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pruning mask factory.
Attributes:
MaskFnType: A type alias for functions to create sparse masks.
MASK_TYPES: Masks types that can be created.
"""
from typing import Any, Callable, Mapping
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import masked
# A function to create a mask, takes as arguments: a flax model, JAX PRNG Key,
# sparsity level as a float in [0, 1].
MaskFnType = Callable[
[flax.deprecated.nn.Model, Callable[[int],
jnp.array], float], masked.MaskType]
MASK_TYPES: Mapping[str, MaskFnType] = {
'random':
masked.shuffled_mask,
'per_neuron':
masked.shuffled_neuron_mask,
'per_neuron_no_input_ablation':
masked.shuffled_neuron_no_input_ablation_mask,
'symmetric':
masked.symmetric_mask,
}
def create_mask(mask_type: str, base_model,
                rng, sparsity: float,
                **kwargs) -> masked.MaskType:
  """Creates a Mask of the given type.

  Args:
    mask_type: the name of the type of mask to instantiate; must be one of
      the keys of MASK_TYPES.
    base_model: the model to create a mask for.
    rng : the random number generator to use for init.
    sparsity: the mask sparsity.
    **kwargs: list of model specific keyword arguments.

  Returns:
    A mask for a FLAX model.

  Raises:
    ValueError: if a mask type with the given name does not exist.
  """
  if mask_type not in MASK_TYPES:
    raise ValueError(f'Unknown mask type: {mask_type}')
  # Dispatch to the registered mask-construction function.
  return MASK_TYPES[mask_type](base_model, rng, sparsity, **kwargs)
================================================
FILE: rigl/experimental/jax/pruning/mask_factory_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.model_factory."""
from typing import Mapping, Optional
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import mask_factory
from rigl.experimental.jax.pruning import masked
class MaskedDense(flax.deprecated.nn.Module):
  """Single-layer Dense Masked Network."""

  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            mask = None):
    # Collapse all non-batch dimensions before the dense layer.
    flat_inputs = inputs.reshape(inputs.shape[0], -1)
    layer_mask = mask['MaskedModule_0'] if mask else None
    return masked.MaskedModule(
        flat_inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Dense,
        mask=layer_mask)
class MaskFactoryTest(parameterized.TestCase):
  """Tests mask_factory.create_mask across all registered mask types."""

  def setUp(self):
    super().setUp()
    # Fixed PRNG seed so mask creation is deterministic across runs.
    self._rng = jax.random.PRNGKey(42)
    self._input_shape = ((1, 28, 28, 1), jnp.float32)
    self._num_classes = 10
    self._sparsity = 0.9
    _, initial_params = MaskedDense.init_by_shape(self._rng,
                                                  (self._input_shape,))
    # Use the same initialization for both masked/unmasked models.
    self._model = flax.deprecated.nn.Model(MaskedDense, initial_params)

  def _create_mask(self, mask_type):
    """Builds a mask of the given type for the fixture model."""
    return mask_factory.create_mask(
        mask_type, self._model,
        self._rng, self._sparsity)

  @parameterized.parameters(*mask_factory.MASK_TYPES.keys())
  def test_mask_supported(self, mask_type):
    """Tests supported mask types."""
    mask = self._create_mask(mask_type)
    with self.subTest(name='test_mask_type'):
      # Every factory-produced mask is a params-like nested dict.
      self.assertIsInstance(mask, dict)

  def test_mask_unsupported(self):
    """Tests unsupported mask types."""
    with self.assertRaisesRegex(ValueError,
                                'Unknown mask type: unsupported'):
      self._create_mask('unsupported')
# Standard absltest entry point: run this file's tests when executed directly.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/masked.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Masked wrapped for FLAX modules.
Attributes:
WEIGHT_PARAM_NAMES: The name of the weight parameters to use.
MaskType: Model mask type for static type checking.
MaskLayerType: Mask layer type for static type checking.
MutableMaskType: Mutable model mask type for static type checking.
MutableMaskLayerType: Mutable mask layer type for static type checking.
"""
import functools
import operator
from typing import Any, Callable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple, Type
from absl import logging
import flax
import jax
import jax.numpy as jnp
import jax.ops
# Model weight param names, e.g. 'kernel', (as opposed batch norm param, etc).
# Only parameters with these leaf names are treated as maskable.
WEIGHT_PARAM_NAMES = ('kernel',)  # Note: Bias is not typically masked.
# Immutable mask layer type (param name -> mask array or None) for static
# type checking.
MaskLayerType = Mapping[str, Optional[jnp.array]]
# Immutable model mask type (layer name -> layer mask or None) for static
# type checking.
MaskType = Mapping[str, Optional[MaskLayerType]]
# Mutable mask layer type for static type checking.
MutableMaskLayerType = MutableMapping[str, Optional[jnp.array]]
# Mutable model mask type for static type checking.
MutableMaskType = MutableMapping[str, MutableMaskLayerType]
class MaskedModule(flax.deprecated.nn.Module):
  """Generic FLAX Masking Module.

  Masks a FLAX module, given a mask for params of each layer.

  Attributes:
    UNMASKED: The key to use for the unmasked parameter dictionary.
  """
  # Key under which the wrapped module's raw (unmasked) params live.
  UNMASKED = 'unmasked'

  def apply(self,
            *args,
            wrapped_module,
            mask = None,
            **kwargs):
    """Apply the wrapped module, while applying the given masks to its params.

    Args:
      *args: The positional arguments for the wrapped module.
      wrapped_module: The module class to be wrapped.
      mask: The mask nested dictionary containing masks for the wrapped
        module's params, in the same format/with the same keys as the module
        param dict (or None if not to mask).
      **kwargs: The keyword arguments for the wrapped module.

    Returns:
      The output of the wrapped module, computed with masked parameters (or
      the raw parameters when no mask is given).

    Raises:
      ValueError: If the given mask is not valid for the wrapped module, i.e.
        the pytrees do not match.
    """
    # Explicitly create the parameters of the wrapped module.
    def init_fn(rng, input_shape):
      del input_shape  # Unused.
      # Call init to get the params of the wrapped module.
      _, params = wrapped_module.init(rng, *args, **kwargs)
      return params

    unmasked_params = self.param(self.UNMASKED, None, init_fn)
    if mask is not None:
      try:
        # Elementwise-multiply each param by its mask entry; a None mask leaf
        # leaves the corresponding param untouched.
        masked_params = jax.tree_util.tree_map(
            lambda x, *xs: x
            if xs[0] is None else x * xs[0], unmasked_params, mask)
      except ValueError as err:
        # tree_map raises ValueError when the mask pytree does not match the
        # param pytree; re-raise with a clearer message, keeping the cause.
        raise ValueError('Mask is invalid for model.') from err
      # Call the wrapped module with the masked params.
      return wrapped_module.call(masked_params, *args, **kwargs)
    else:
      logging.warning('Using masked module without mask!')
      # Call the wrapped module with the unmasked params.
      return wrapped_module.call(unmasked_params, *args, **kwargs)
def masked(module, mask):
  """Convenience function for masking a FLAX module with MaskedModule.

  Args:
    module: The FLAX module class to wrap.
    mask: The mask to apply to the wrapped module's params, or None.

  Returns:
    A partially-applied MaskedModule wrapping `module` with `mask`.
  """
  return MaskedModule.partial(wrapped_module=module, mask=mask)
def generate_model_masks(
    depth,
    mask = None,
    masked_layer_indices = None):
  """Creates empty masks for a model, or filters an existing model mask.

  Args:
    depth: Number of layers in the model.
    mask: Existing model mask for layers in this model; when omitted every
      module mask starts out as None.
    masked_layer_indices: The layer indices of layers in the model to be
      masked; when None all layers are kept.

  Returns:
    A model mask keyed by 'MaskedModule_<i>', with None for layers that have
    no mask. When masked_layer_indices is given, only entries for those
    layers are present in the result.

  Raises:
    ValueError: If depth is not positive, or any index in
      masked_layer_indices lies outside [0, depth).
  """
  if depth <= 0:
    raise ValueError(f'Invalid model depth: {depth}')

  if mask is None:
    mask = {f'MaskedModule_{layer}': None for layer in range(depth)}

  # Explicit None check: an empty index list is a valid (empty) selection and
  # must not be confused with "no filtering".
  if masked_layer_indices is not None:
    out_of_bounds = [
        idx for idx in masked_layer_indices if idx < 0 or idx >= depth
    ]
    if out_of_bounds:
      raise ValueError(
          f'Invalid indices for given depth ({depth}): {masked_layer_indices}')
    # Keep only the selected layers' entries.
    mask = {
        f'MaskedModule_{idx}': mask[f'MaskedModule_{idx}']
        for idx in masked_layer_indices
    }
  return mask
def _filter_param(param_names,
                  invert = False):
  """Convenience function for filtering maskable parameters from paths.

  Args:
    param_names: Names of parameters we are looking for.
    invert: Inverts filter to exclude, rather than include, given parameters.

  Returns:
    A function to use with flax.deprecated.nn.optim.ModelParamTraversal for
    filtering.
  """
  def filter_fn(path, value):
    del value  # Unused.
    # A maskable parameter path contains 'unmasked/<param_name>'. Use a
    # generator expression so any() can short-circuit (the original built a
    # full list first).
    parameter_found = any(
        f'{MaskedModule.UNMASKED}/{param_name}' in path
        for param_name in param_names)
    return not parameter_found if invert else parameter_found

  return filter_fn
def mask_map(model,
             fn):
  """Convenience function to create a mask for a model.

  Args:
    model: The Flax model, with at least one MaskedModule layer.
    fn: The function to call on each masked parameter, to create the mask for
      that parameter, takes the parameter name, and parameter value as
      arguments and returns the new parameter value.

  Returns:
    A model parameter dictionary, with all masked parameters set by the given
    function, and all other parameters set to None.

  Raises:
    ValueError: If the given model does not support masking, i.e. none of the
      layers are wrapped by a MaskedModule.
  """
  # A layer is maskable iff its params contain the 'unmasked' sub-dict that
  # MaskedModule creates.
  maskable = False
  for layer_key, layer in model.params.items():
    if MaskedModule.UNMASKED not in layer:
      logging.warning(
          'Layer \'%s\' does not support masking, i.e. it is not '
          'wrapped by a MaskedModule', layer_key)
    else:
      maskable = True
  if not maskable:
    raise ValueError('Model does not support masking, i.e. no layers are '
                     'wrapped by a MaskedModule.')

  # First set all non-masked params to None in copy of model pytree.
  filter_non_masked = _filter_param(WEIGHT_PARAM_NAMES, invert=True)
  nonmasked_traversal = flax.optim.ModelParamTraversal(filter_non_masked)  # pytype: disable=module-attr
  mask_model = nonmasked_traversal.update(lambda _: None, model)

  # Then find params to mask, and set to array. Each parameter name gets its
  # own traversal so fn is invoked with the name of the parameter it is
  # actually masking. (Previously the filter was built from the whole
  # WEIGHT_PARAM_NAMES tuple on every iteration, which only behaved correctly
  # because the tuple has a single entry.)
  for param_name in WEIGHT_PARAM_NAMES:
    filter_masked = _filter_param((param_name,))
    mask_traversal = flax.optim.ModelParamTraversal(filter_masked)  # pytype: disable=module-attr
    mask_model = mask_traversal.update(
        functools.partial(fn, param_name), mask_model)
  mask = mask_model.params

  # Remove unneeded unmasked param for mask.
  for layer_key, layer in mask.items():
    if MaskedModule.UNMASKED in layer:
      mask[layer_key] = layer[MaskedModule.UNMASKED]
  return mask
def iterate_mask(
    mask,
    param_names = None
    ):
  """Iterate over the parameters in a mask.

  Args:
    mask: The model mask.
    param_names: The parameter names to iterate over in each layer, if None
      iterates over all parameters of all layers.

  Yields:
    An iterator of tuples containing the parameter path (as a '/'-joined
    string) and parameter value for layer parameters matching param_names
    (or all parameters if None).
  """
  # Flatten the nested mask dict into {(layer, ..., param): value}.
  flat_mask = flax.traverse_util.flatten_dict(mask)
  for key, value in flat_mask.items():
    # NOTE(review): `key` is a tuple path, so `key in param_names` only
    # matches if callers pass full tuple paths; a bare name like 'kernel'
    # will never match here — confirm intended usage before relying on this
    # filter (all visible callers pass param_names=None).
    if param_names is None or key in param_names:
      path = '/' + '/'.join(key)
      yield path, value
def shuffled_mask(model, rng,
                  sparsity):
  """Returns a randomly shuffled mask with a given sparsity for all layers.

  For every maskable parameter, builds a flat vector containing a fixed
  number of ones and zeros (determined by the sparsity) and randomly permutes
  it before reshaping to the parameter's shape.

  Args:
    model: Flax model that contains masked modules.
    rng: Random number generator, i.e. jax.random.PRNGKey.
    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
      mask all weights, while 0 will mask none.

  Returns:
    A randomly shuffled weight mask, in the same form as flax.Module.params.

  Raises:
    ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are
      maskable, i.e. is wrapped by MaskedModule.
  """
  if sparsity < 0 or sparsity > 1:
    raise ValueError(
        'Given sparsity, {}, is not in range [0, 1]'.format(sparsity))

  def _shuffle_param(param_name, param):
    del param_name  # Unused.
    flat_positions = jnp.arange(param.size)
    # One for each kept weight, zero for each masked weight.
    flat_mask = jnp.where(flat_positions >= sparsity * param.size,
                          jnp.ones_like(flat_positions),
                          jnp.zeros_like(flat_positions))
    shuffled = jax.random.permutation(rng, flat_mask)
    return shuffled.reshape(param.shape)

  return mask_map(model, _shuffle_param)
def random_mask(model,
                rng,
                mean_sparsity = 0.5):
  """Returns a random weight mask for a masked model.

  Each mask entry is drawn independently from a Bernoulli distribution whose
  keep-probability is (1 - mean_sparsity).

  Args:
    model: Flax model that contains masked modules.
    rng: Random number generator, i.e. jax.random.PRNGKey.
    mean_sparsity: The mean number of 0's in the mask, i.e. mean = (1 -
      mean_sparsity) for the Bernoulli distribution to sample from.

  Returns:
    A random weight mask, in the same form as flax.Module.params

  Raises:
    ValueError: If the sparsity is beyond the bounds [0, 1], or if a layer to
      mask is not maskable, i.e. is not wrapped by MaskedModule.
  """
  if mean_sparsity < 0 or mean_sparsity > 1:
    raise ValueError(
        'Given sparsity, {}, is not in range [0, 1]'.format(mean_sparsity))

  # Invert mean_sparsity to get the keep-probability for the Bernoulli draw.
  keep_probability = 1. - mean_sparsity

  def _bernoulli_mask(param_name, param):
    del param_name  # Unused.
    sample = jax.random.bernoulli(rng, p=keep_probability, shape=param.shape)
    return sample.astype(jnp.int32)  # TPU doesn't support uint8.

  return mask_map(model, _bernoulli_mask)
def simple_mask(model,
                init_fn,
                masked_param):
  """Creates a mask given a model and numpy initialization function.

  Args:
    model: The model to create a mask for.
    init_fn: The numpy initialization function, e.g. numpy.ones.
    masked_param: The list of parameters to mask.

  Returns:
    A mask for the model, with init_fn-generated arrays for the listed
    parameters and None everywhere else.
  """
  def _init_or_none(param_name, param):
    # Only parameters explicitly listed get a mask; all others stay None.
    return init_fn(param.shape) if param_name in masked_param else None

  return mask_map(model, _init_or_none)
def symmetric_mask(model,
                   rng,
                   sparsity = 0.5):
  """Generates a random weight mask that's symmetric, i.e. structurally pruned.

  Args:
    model: Flax model that contains masked modules.
    rng: Random number generator, i.e. jax.random.PRNGKey.
    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), in the
      range [0, 1]: 1.0 will mask all weights, while 0 will mask none.

  Returns:
    A symmetric random weight mask, in the same form as flax.Module.params.

  Raises:
    ValueError: If the sparsity is outside the bounds [0, 1].
  """
  if sparsity > 1 or sparsity < 0:
    raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')

  def create_neuron_symmetric_mask(param_name, param):
    del param_name  # Unused.
    # Number of weights feeding each output neuron (all leading dims).
    neuron_length = functools.reduce(operator.mul, param.shape[:-1])
    neuron_mask = jnp.arange(neuron_length)
    # Fixed count of ones/zeros determined by the sparsity.
    neuron_mask = jnp.where(neuron_mask >= sparsity * neuron_mask.size,
                            jnp.ones_like(neuron_mask),
                            jnp.zeros_like(neuron_mask))
    # NOTE(review): jax.random.shuffle is deprecated in newer JAX releases;
    # jax.random.permutation is the modern replacement.
    neuron_mask = jax.random.shuffle(rng, neuron_mask)
    # The same shuffled pattern is repeated across the last (output) axis, so
    # every output neuron shares one mask — the "symmetric" structure.
    mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1)
    return mask.reshape(param.shape)

  return mask_map(model, create_neuron_symmetric_mask)
class _PerNeuronShuffle:
  """This class is needed to get around the fact that JAX RNG is stateless."""

  def __init__(self, init_rng, sparsity):
    """Creates the per-neuron shuffle class, with initial RNG state.

    Args:
      init_rng: The initial random number generator state to use.
      sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0
        will mask all weights, while 0 will mask none.
    """
    # Current PRNG state; split on every __call__ so each parameter gets an
    # independent random stream.
    self._rng = init_rng
    self._sparsity = sparsity

  def __call__(self, param_name, param):
    """Shuffles the weight matrix/mask for a given parameter, per-neuron.

    This is to be used with mask_map, and accepts the standard mask_map
    function parameters.

    Args:
      param_name: The parameter's name.
      param: The parameter's weight or mask matrix.

    Returns:
      A shuffled weight/mask matrix, with each neuron shuffled independently.
    """
    del param_name  # Unused.
    # Number of weights feeding each output neuron (all leading dims).
    neuron_length = functools.reduce(operator.mul, param.shape[:-1])
    neuron_mask = jnp.arange(neuron_length)
    # Fixed number of ones/zeros per neuron, determined by sparsity.
    neuron_mask = jnp.where(neuron_mask >= self._sparsity * neuron_mask.size,
                            jnp.ones_like(neuron_mask),
                            jnp.zeros_like(neuron_mask))
    # Replicate the base pattern across the output (last) axis before
    # shuffling along axis 0.
    mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1)
    self._rng, rng_input = jax.random.split(self._rng)
    # NOTE(review): jax.random.shuffle is deprecated in newer JAX releases
    # (jax.random.permutation replaces it) — confirm the installed JAX still
    # provides it before upgrading dependencies.
    mask = jax.random.shuffle(rng_input, mask, axis=0)
    return mask.reshape(param.shape)
def shuffled_neuron_mask(model,
                         rng,
                         sparsity):
  """Returns a shuffled mask with a given fixed sparsity for all neurons/layers.

  Builds a randomly shuffled weight mask for each model parameter by fixing
  the number of ones/zeros in every neuron's weight vector (from the
  sparsity) and shuffling each neuron's mask, so no neuron is fully ablated
  for a non-zero sparsity.

  Note: This is much more complicated for convolutional layers due to the
  receptive field being different for every pixel! We only take into account
  channel-wise masks and not spatial ablations in propagation in that case.

  Args:
    model: Flax model that contains masked modules.
    rng: Random number generator, i.e. jax.random.PRNGKey.
    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
      mask all weights, while 0 will mask none.

  Returns:
    A randomly shuffled weight mask, in the same form as flax.Module.params.

  Raises:
    ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are
      maskable, i.e. is wrapped by MaskedModule.
  """
  if sparsity < 0 or sparsity > 1:
    raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')
  # The shuffler carries the RNG state across parameters.
  shuffler = _PerNeuronShuffle(rng, sparsity)
  return mask_map(model, shuffler)
def _fill_diagonal_wrap(shape,
                        value,
                        dtype = jnp.uint8):
  """Fills the diagonal of a 2D array, while also wrapping tall arrays.

  For a matrix of dimensions (N x M),:
    if N <= M, i.e. the array is wide rectangular, the array's diagonal is
    filled, for example:
      _fill_diagonal_wrap(jnp.zeroes((2, 3), dtype=uint8), 1)
      > [[1, 0, 0],
         [0, 1, 0]]
    if N > M, i.e. the array is tall rectangular, the array's diagonal, and
    offset diagonals are filled. This differs from
    numpy.fill_diagonal(..., wrap=True), in that it does not include a single
    row gap between the diagonals, and it is not in-place but returns a copy
    of the given array. For example,
      _fill_diagonal_wrap(jnp.zeroes((3, 2), dtype=uint8), 1)
      > [[1, 0],
         [0, 1],
         [1, 0]]

  Args:
    shape: The shape of the 2D array to return with the diagonal filled.
    value: The value to fill in for the diagonal, and offset diagonals.
    dtype: The datatype of the jax numpy array to return.

  Returns:
    A zero array of the given shape with the main diagonal filled, and offset
    diagonals filled if the array is tall.
  """
  if len(shape) != 2:
    raise ValueError(
        f'Expected an 2D array, however array has dimensions: {shape}')
  rows, cols = shape

  # Work on the flattened array: the (vertically offset) diagonal starting at
  # row `offset` occupies flat indices start:stop:cols+1, where stepping by
  # cols+1 moves one row down and one column right. Diagonals start every
  # `cols` rows, so they tile tall matrices without overlapping; slices are
  # clipped automatically at the end of the array.
  flat = jnp.zeros(shape, dtype=dtype).ravel()
  for offset in range(0, rows, cols):
    start = cols * offset
    stop = cols * (offset + cols)
    flat = flat.at[start:stop:cols + 1].set(value)
  return flat.reshape(shape)
def _random_neuron_mask(neuron_length,
                        unmasked_count,
                        rng,
                        dtype = jnp.uint32):
  """Generates a random mask for a neuron.

  Args:
    neuron_length: The length of the neuron's weight vector.
    unmasked_count: The number of elements that should be unmasked.
    rng: A jax.random.PRNGKey random seed.
    dtype: Type of array to create.

  Returns:
    A random neuron weight vector mask with exactly unmasked_count ones.

  Raises:
    ValueError: If unmasked_count exceeds neuron_length.
  """
  if unmasked_count > neuron_length:
    raise ValueError('unmasked_count cannot be greater than neuron_length: '
                     f'{unmasked_count} > {neuron_length}')
  # Fixed number of ones followed by zeros, then randomly permuted.
  neuron_mask = jnp.concatenate(
      (jnp.ones(unmasked_count), jnp.zeros(neuron_length - unmasked_count)),
      axis=0)
  # jax.random.shuffle is deprecated (and removed in newer JAX releases);
  # jax.random.permutation is its drop-in replacement for 1D arrays.
  neuron_mask = jax.random.permutation(rng, neuron_mask)
  return neuron_mask.astype(dtype)
class _PerNeuronNoInputAblationShuffle:
  """This class is needed to get around the fact that JAX RNG is stateless."""

  def __init__(self, init_rng, sparsity):
    """Creates the per-neuron shuffle class, with initial RNG state.

    Args:
      init_rng: The initial random number generator state to use.
      sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0
        will mask all weights, while 0 will mask none.
    """
    self._rng = init_rng
    self._sparsity = sparsity

  def _get_rng(self):
    """Creates a new JAX RNG, while updating RNG state."""
    self._rng, rng_input = jax.random.split(self._rng)
    return rng_input

  def __call__(self, param_name, param):
    """Shuffles the weight matrix/mask for a given parameter, per-neuron.

    This is to be used with mask_map, and accepts the standard mask_map
    function parameters.

    Args:
      param_name: The parameter's name.
      param: The parameter's weight or mask matrix.

    Returns:
      A shuffled weight/mask matrix, with each neuron shuffled independently.
    """
    del param_name  # Unused.
    # All leading dims are incoming connections; last dim indexes neurons.
    incoming_connections = jnp.prod(jnp.array(param.shape[:-1]))
    num_neurons = param.shape[-1]
    # Ensure each input neuron has at least one connection unmasked.
    mask = _fill_diagonal_wrap((incoming_connections, num_neurons), 1,
                               dtype=jnp.uint8)
    # Randomly shuffle which of the neurons have these connections.
    mask = jax.random.shuffle(self._get_rng(), mask, axis=0)
    # Add extra required random connections to mask to satisfy sparsity.
    mask_cols = []
    for col in range(mask.shape[-1]):
      neuron_mask = mask[:, col]
      # Connections still needed beyond the guaranteed ones (never negative).
      off_diagonal_count = max(
          round((1 - self._sparsity) * incoming_connections)
          - jnp.count_nonzero(neuron_mask), 0)
      # Scatter the extra connections uniformly over the currently-zero slots.
      zero_indices = jnp.flatnonzero(neuron_mask == 0)
      random_entries = _random_neuron_mask(
          len(zero_indices), off_diagonal_count, self._get_rng())
      neuron_mask = neuron_mask.at[zero_indices].set(random_entries)
      mask_cols.append(neuron_mask)
    return jnp.column_stack(mask_cols).reshape(param.shape)
def shuffled_neuron_no_input_ablation_mask(model,
                                           rng,
                                           sparsity):
  """Returns a shuffled mask with a given fixed sparsity for all neurons/layers.

  Builds a randomly shuffled weight mask for each model parameter by fixing
  the number of ones/zeros in every neuron's weight vector (from the
  sparsity) and shuffling each neuron's mask, so no neuron is fully ablated
  for a non-zero sparsity.

  This function also ensures that no neurons in the previous layer are
  effectively ablated, by ensuring that each neuron has at least one
  connection.

  Note: This is much more complicated for convolutional layers due to the
  receptive field being different for every pixel! We only take into account
  channel-wise masks and not spatial ablations in propagation in that case.

  Args:
    model: Flax model that contains masked modules.
    rng: Random number generator, i.e. jax.random.PRNGKey.
    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
      mask all weights, except for the minimum number required to maintain,
      connectivity with the input layer, while 0 will mask none.

  Returns:
    A randomly shuffled weight mask, in the same form as flax.Module.params.

  Raises:
    ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are
      maskable, i.e. is wrapped by MaskedModule.
  """
  if sparsity < 0.0 or sparsity > 1.0:
    raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')
  # The shuffler guarantees every input neuron keeps at least one outgoing
  # connection before topping up random connections to reach the sparsity.
  shuffler = _PerNeuronNoInputAblationShuffle(rng, sparsity)
  return mask_map(model, shuffler)
def propagate_masks(
    mask,
    param_names = WEIGHT_PARAM_NAMES
    ):
  """Accounts for implicitly pruned neurons in a model's weight masks.

  When neurons are randomly ablated in one layer, they can effectively ablate
  neurons in the next layer if in effect all incoming weights of a neuron are
  zero. This method accounts for this by propagating forward mask information
  through the entire model.

  Args:
    mask: Model masks to check, in same pytree structure as Model.params.
    param_names: List of param keys in mask to count.

  Returns:
    A refined model mask with weights that are effectively ablated in the
    original mask set to zero.

  Raises:
    ValueError: If a dense layer mask follows a convolutional layer mask
      (spatial input into dense is unsupported).
  """
  flat_mask = flax.traverse_util.flatten_dict(mask)
  mask_layer_list = list(flat_mask.values())
  mask_layer_keys = list(flat_mask.keys())
  # Leaf parameter name (e.g. 'kernel') for each flattened path tuple.
  mask_layer_param_names = [layer_param[-1] for layer_param in mask_layer_keys]
  for param_name in param_names:
    # Find which of the param arrays correspond to leaf nodes with this name.
    param_indices = [
        i for i, names in enumerate(mask_layer_param_names)
        if param_name in names
    ]
    # Walk consecutive layer pairs in order, propagating ablations forward.
    for i in range(1, len(param_indices)):
      last_weight_mask = mask_layer_list[param_indices[i - 1]]
      weight_mask = mask_layer_list[param_indices[i]]
      # A None mask means the layer is unmasked; nothing to propagate.
      if last_weight_mask is None or weight_mask is None:
        continue
      # Collapse leading dims so each column holds one output neuron's mask.
      last_weight_mask_reshaped = jnp.reshape(last_weight_mask,
                                              (-1, last_weight_mask.shape[-1]))
      # Neurons with any outgoing weights from previous layer.
      alive_incoming = jnp.sum(last_weight_mask_reshaped, axis=0) != 0
      # Combine effective mask of previous layer with neuron's current mask.
      if len(weight_mask.shape) > 2:
        # Convolutional layer, only consider channel-wise masks, if any spatial
        # weight is non-zero that channel is considered non-masked.
        spatial_dim = len(weight_mask.shape) - 2
        new_weight_mask = alive_incoming[:, jnp.newaxis] * jnp.amax(
            weight_mask, axis=tuple(range(spatial_dim)))
        # Broadcast the channel-wise result back over the spatial dims.
        new_weight_mask = jnp.tile(new_weight_mask,
                                   weight_mask.shape[:-2] + (1, 1))
      else:
        # Check for case of dense following convolution, i.e. spatial input
        # into dense, to prevent b/156135283. Must use convolution for these
        # layers.
        if len(last_weight_mask.shape) > 2:
          raise ValueError(
              'propagate_masks requires knowledge of the spatial '
              'dimensions of the previous layer. Use a functionally equivalent '
              'conv. layer in place of a dense layer in a model with a mixed '
              'conv/dense setting.')
        new_weight_mask = alive_incoming[:, jnp.newaxis] * weight_mask
      mask_layer_list[param_indices[i]] = jnp.reshape(
          new_weight_mask, mask_layer_list[param_indices[i]].shape)
  return flax.traverse_util.unflatten_dict(
      dict(zip(mask_layer_keys, mask_layer_list)))
def mask_layer_sparsity(mask_layer):
  """Calculates the sparsity of a single layer's mask.

  Args:
    mask_layer: mask layer to calculate the sparsity of.

  Returns:
    The sparsity of the mask (fraction of zero entries over the layer's
    maskable weight parameters), or 0. when the layer has none.
  """
  total_weights = 0
  unmasked_weights = 0
  # Only count maskable weight params (e.g. 'kernel') with an actual mask.
  for name, layer_mask in mask_layer.items():
    if name in WEIGHT_PARAM_NAMES and layer_mask is not None:
      total_weights += layer_mask.size
      unmasked_weights += jnp.sum(layer_mask)
  if total_weights == 0:
    return 0.
  return 1. - unmasked_weights / total_weights
def mask_sparsity(
    mask,
    param_names = None):
  """Calculates the sparsity of the given parameters over a model mask.

  Args:
    mask: Model mask to calculate sparsity over.
    param_names: List of param keys in mask to count; defaults to
      WEIGHT_PARAM_NAMES when None.

  Returns:
    The overall sparsity of the mask as a float, or 0. when no matching
    parameters are found.
  """
  names = WEIGHT_PARAM_NAMES if param_names is None else param_names
  total_weights = 0
  unmasked_weights = 0
  for path, value in iterate_mask(mask):
    if value is None:
      continue
    # Count only parameters whose path mentions one of the requested names.
    if not any(name in path for name in names):
      continue
    total_weights += value.size
    unmasked_weights += jnp.sum(value.flatten())
  if total_weights == 0:
    return 0.
  return 1.0 - float(unmasked_weights / total_weights)
================================================
FILE: rigl/experimental/jax/pruning/masked_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.pruning.masked."""
from typing import Mapping, Optional, Sequence
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
import numpy as np
from rigl.experimental.jax.pruning import masked
class Dense(flax.deprecated.nn.Module):
  """Single-layer Dense Non-Masked Network."""

  NUM_FEATURES: int = 32

  def apply(self, inputs):
    # Collapse all non-batch dimensions before the dense layer.
    flat_inputs = inputs.reshape(inputs.shape[0], -1)
    return flax.deprecated.nn.Dense(flat_inputs, features=self.NUM_FEATURES)
class MaskedDense(flax.deprecated.nn.Module):
  """Single-layer Dense Masked Network."""

  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            mask = None):
    # Collapse all non-batch dimensions before the dense layer.
    flat_inputs = inputs.reshape(inputs.shape[0], -1)
    layer_mask = mask['MaskedModule_0'] if mask else None
    return masked.MaskedModule(
        flat_inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Dense,
        mask=layer_mask)
class DenseTwoLayer(flax.deprecated.nn.Module):
  """Two-layer Dense Non-Masked Network."""

  NUM_FEATURES: Sequence[int] = (32, 64)

  def apply(self, inputs):
    # Collapse all non-batch dimensions, then stack two dense layers.
    hidden = inputs.reshape(inputs.shape[0], -1)
    hidden = flax.deprecated.nn.Dense(hidden, features=self.NUM_FEATURES[0])
    return flax.deprecated.nn.Dense(hidden, features=self.NUM_FEATURES[1])
class MaskedTwoLayerDense(flax.deprecated.nn.Module):
  """Two-layer Dense Masked Network."""

  NUM_FEATURES: Sequence[int] = (32, 64)

  def apply(self,
            inputs,
            mask = None):
    # Collapse all non-batch dimensions before the dense stack.
    flat_inputs = inputs.reshape(inputs.shape[0], -1)
    first_mask = mask['MaskedModule_0'] if mask else None
    second_mask = mask['MaskedModule_1'] if mask else None
    hidden = masked.MaskedModule(
        flat_inputs,
        features=self.NUM_FEATURES[0],
        wrapped_module=flax.deprecated.nn.Dense,
        mask=first_mask)
    return masked.MaskedModule(
        hidden,
        features=self.NUM_FEATURES[1],
        wrapped_module=flax.deprecated.nn.Dense,
        mask=second_mask)
class MaskedConv(flax.deprecated.nn.Module):
  """Single-layer Conv Masked Network."""

  NUM_FEATURES: int = 16

  def apply(self,
            inputs,
            mask = None):
    layer_mask = mask['MaskedModule_0'] if mask is not None else None
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES,
        kernel_size=(3, 3),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=layer_mask)
class MaskedTwoLayerConv(flax.deprecated.nn.Module):
  """Two-layer Conv Masked Network."""

  NUM_FEATURES: Sequence[int] = (16, 32)

  def apply(self,
            inputs,
            mask = None):
    first_mask = mask['MaskedModule_0'] if mask is not None else None
    second_mask = mask['MaskedModule_1'] if mask is not None else None
    hidden = masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES[0],
        kernel_size=(5, 5),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=first_mask)
    return masked.MaskedModule(
        hidden,
        features=self.NUM_FEATURES[1],
        kernel_size=(3, 3),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=second_mask)
class MaskedThreeLayerConvDense(flax.deprecated.nn.Module):
  """Three-layer Conv Masked Network with Dense layer."""

  NUM_FEATURES: Sequence[int] = (16, 32, 64)

  def apply(self,
            inputs,
            mask = None):
    def layer_mask(key):
      # Look up the per-layer mask, or None when masking is disabled.
      return mask[key] if mask is not None else None

    hidden = masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES[0],
        kernel_size=(5, 5),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=layer_mask('MaskedModule_0'))
    hidden = masked.MaskedModule(
        hidden,
        features=self.NUM_FEATURES[1],
        kernel_size=(3, 3),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=layer_mask('MaskedModule_1'))
    # Final layer: a conv whose kernel spans the full remaining spatial
    # extent, making it act like a dense layer.
    return masked.MaskedModule(
        hidden,
        features=self.NUM_FEATURES[2],
        kernel_size=hidden.shape[1:-1],
        wrapped_module=flax.deprecated.nn.Conv,
        mask=layer_mask('MaskedModule_2'))
class MaskedTwoLayerMixedConvDense(flax.deprecated.nn.Module):
  """Two-layer Mixed Conv/Dense Masked Network."""

  NUM_FEATURES: Sequence[int] = (16, 32)

  def apply(self,
            inputs,
            mask = None):
    conv_mask = mask['MaskedModule_0'] if mask is not None else None
    dense_mask = mask['MaskedModule_1'] if mask is not None else None
    hidden = masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES[0],
        kernel_size=(5, 5),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=conv_mask)
    # The dense layer consumes the conv output directly (no flattening).
    return masked.MaskedModule(
        hidden,
        features=self.NUM_FEATURES[1],
        wrapped_module=flax.deprecated.nn.Dense,
        mask=dense_mask)
class MaskedTest(parameterized.TestCase):
  """Tests the flax layer mask."""

  def setUp(self):
    super().setUp()
    self._rng = jax.random.PRNGKey(42)
    self._batch_size = 2
    self._input_dimensions = (28, 28, 1)
    self._input_shape = ((self._batch_size,) + self._input_dimensions,
                         jnp.float32)
    self._input = jnp.ones(*self._input_shape)

    _, initial_params = Dense.init_by_shape(self._rng, (self._input_shape,))
    self._unmasked_model = flax.deprecated.nn.Model(Dense, initial_params)
    self._unmasked_output = self._unmasked_model(self._input)

    # Use the same initialization for both masked/unmasked models.
    masked_initial_params = {
        'MaskedModule_0': {
            'unmasked': initial_params['Dense_0']
        }
    }
    self._masked_model = flax.deprecated.nn.Model(MaskedDense,
                                                  masked_initial_params)

    _, initial_params = DenseTwoLayer.init_by_shape(self._rng,
                                                    (self._input_shape,))
    self._unmasked_model_twolayer = flax.deprecated.nn.Model(
        DenseTwoLayer, initial_params)
    self._unmasked_output_twolayer = self._unmasked_model_twolayer(self._input)

    # Use the same initialization for both masked/unmasked models.
    masked_initial_params = {
        'MaskedModule_0': {
            'unmasked': initial_params['Dense_0']
        },
        'MaskedModule_1': {
            'unmasked': initial_params['Dense_1']
        },
    }
    _, initial_params = MaskedTwoLayerDense.init_by_shape(
        self._rng, (self._input_shape,))
    self._masked_model_twolayer = flax.deprecated.nn.Model(
        MaskedTwoLayerDense, masked_initial_params)

    _, initial_params = MaskedConv.init_by_shape(self._rng,
                                                 (self._input_shape,))
    self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv,
                                                       initial_params)

    _, initial_params = MaskedTwoLayerConv.init_by_shape(
        self._rng, (self._input_shape,))
    self._masked_conv_model_twolayer = flax.deprecated.nn.Model(
        MaskedTwoLayerConv, initial_params)

    _, initial_params = MaskedTwoLayerMixedConvDense.init_by_shape(
        self._rng, (self._input_shape,))
    self._masked_mixed_model_twolayer = flax.deprecated.nn.Model(
        MaskedTwoLayerMixedConvDense, initial_params)

    _, initial_params = MaskedThreeLayerConvDense.init_by_shape(
        self._rng, (self._input_shape,))
    self._masked_conv_fc_model_threelayer = flax.deprecated.nn.Model(
        MaskedThreeLayerConvDense, initial_params)

  def test_fully_masked_layer(self):
    """Tests masked module with full-sparsity mask."""
    full_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
    masked_output = self._masked_model(self._input, mask=full_mask)
    with self.subTest(name='fully_masked_dense_values'):
      self.assertTrue((masked_output == 0).all())
    with self.subTest(name='fully_masked_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_no_mask_masked_layer(self):
    """Tests masked module with no mask."""
    masked_output = self._masked_model(self._input, mask=None)
    with self.subTest(name='no_mask_masked_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
    with self.subTest(name='no_mask_masked_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_empty_mask_masked_layer(self):
    """Tests masked module with an empty (not sparse) mask."""
    empty_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
    masked_output = self._masked_model(self._input, mask=empty_mask)
    with self.subTest(name='empty_mask_masked_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
    with self.subTest(name='empty_mask_masked_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_invalid_mask(self):
    """Tests using an invalid mask."""
    invalid_mask = {
        'MaskedModule_0': {
            'not_kernel':
                jnp.ones(self._unmasked_model.params['Dense_0']['kernel'].shape)
        }
    }
    with self.assertRaisesRegex(ValueError, 'Mask is invalid for model.'):
      self._masked_model(self._input, mask=invalid_mask)

  def test_shuffled_mask_invalid_model(self):
    """Tests shuffled mask with model containing no masked layers."""
    with self.assertRaisesRegex(
        ValueError, 'Model does not support masking, i.e. no layers are '
        'wrapped by a MaskedModule.'):
      masked.shuffled_mask(self._unmasked_model, self._rng, 0.5)

  def test_shuffled_mask_invalid_sparsity(self):
    """Tests shuffled mask with invalid sparsity."""
    with self.subTest(name='sparsity_too_small'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, -0.5, is not in range \[0, 1\]'):
        masked.shuffled_mask(self._masked_model, self._rng, -0.5)
    with self.subTest(name='sparsity_too_large'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, 1.5, is not in range \[0, 1\]'):
        masked.shuffled_mask(self._masked_model, self._rng, 1.5)

  def test_shuffled_mask_sparsity_full(self):
    """Tests shuffled mask generation, for 100% sparsity."""
    mask = masked.shuffled_mask(self._masked_model, self._rng, 1.0)
    with self.subTest(name='shuffled_full_mask'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_full_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())
    with self.subTest(name='shuffled_full_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='shuffled_full_mask_dense_values'):
      self.assertTrue((masked_output == 0).all())
    with self.subTest(name='shuffled_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_shuffled_mask_sparsity_empty(self):
    """Tests shuffled mask generation, for 0% sparsity."""
    mask = masked.shuffled_mask(self._masked_model, self._rng, 0.0)
    with self.subTest(name='shuffled_empty_mask'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_empty_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())
    with self.subTest(name='shuffled_empty_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='shuffled_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
    with self.subTest(name='shuffled_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_shuffled_mask_sparsity_half_full(self):
    """Tests shuffled mask generation, for a half-full mask."""
    mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)
    param_len = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].size
    with self.subTest(name='shuffled_mask_values'):
      self.assertEqual(
          jnp.sum(mask['MaskedModule_0']['kernel']), param_len // 2)

  def test_shuffled_mask_sparsity_full_twolayer(self):
    """Tests shuffled mask generation for two layers, and 100% sparsity."""
    mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 1.0)
    with self.subTest(name='shuffled_full_mask_layer1'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_full_mask_values_layer1'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())
    with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    with self.subTest(name='shuffled_full_mask_layer2'):
      self.assertIn('MaskedModule_1', mask)
    with self.subTest(name='shuffled_full_mask_values_layer2'):
      self.assertTrue((mask['MaskedModule_1']['kernel'] == 0).all())
    # Fixed: subtest name previously duplicated the layer1 name.
    with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
      self.assertIsNone(mask['MaskedModule_1']['bias'])
    masked_output = self._masked_model_twolayer(self._input, mask=mask)
    with self.subTest(name='shuffled_full_mask_dense_values'):
      self.assertTrue((masked_output == 0).all())
    with self.subTest(name='shuffled_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output_twolayer.shape)

  def test_shuffled_mask_sparsity_empty_twolayer(self):
    """Tests shuffled mask generation for two layers, for 0% sparsity."""
    mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 0.0)
    with self.subTest(name='shuffled_empty_mask_layer1'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_empty_mask_values_layer1'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())
    with self.subTest(name='shuffled_empty_mask_layer2'):
      self.assertIn('MaskedModule_1', mask)
    with self.subTest(name='shuffled_empty_mask_values_layer2'):
      self.assertTrue((mask['MaskedModule_1']['kernel'] == 1).all())
    masked_output = self._masked_model_twolayer(self._input, mask=mask)
    with self.subTest(name='shuffled_empty_dense_values'):
      self.assertTrue(
          jnp.isclose(masked_output, self._unmasked_output_twolayer).all())
    with self.subTest(name='shuffled_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output_twolayer.shape)

  def test_random_invalid_model(self):
    """Tests random mask with model containing no masked layers."""
    with self.assertRaisesRegex(
        ValueError, 'Model does not support masking, i.e. no layers are '
        'wrapped by a MaskedModule.'):
      masked.random_mask(self._unmasked_model, self._rng, 0.5)

  def test_random_invalid_sparsity(self):
    """Tests random mask with invalid sparsity."""
    with self.subTest(name='random_sparsity_too_small'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, -0.5, is not in range \[0, 1\]'):
        masked.random_mask(self._masked_model, self._rng, -0.5)
    with self.subTest(name='random_sparsity_too_large'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, 1.5, is not in range \[0, 1\]'):
        masked.random_mask(self._masked_model, self._rng, 1.5)

  def test_random_mask_sparsity_full(self):
    """Tests random mask generation, for 100% sparsity."""
    mask = masked.random_mask(self._masked_model, self._rng, 1.)
    with self.subTest(name='random_full_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='random_full_mask_dense_values'):
      # Fixed: was `(masked_output.all() == 0).all()`, which passed if *any*
      # element was zero; assert every output element is zero instead.
      self.assertTrue((masked_output == 0).all())
    with self.subTest(name='random_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_random_mask_sparsity_empty(self):
    """Tests random mask generation, for 0% sparsity."""
    mask = masked.random_mask(self._masked_model, self._rng, 0.)
    with self.subTest(name='random_empty_mask_values'):
      self.assertEqual(
          jnp.sum(mask['MaskedModule_0']['kernel']),
          mask['MaskedModule_0']['kernel'].size)
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='random_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
    with self.subTest(name='random_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_random_mask_sparsity_half_full(self):
    """Tests random mask generation, for a half-full mask."""
    mask = masked.random_mask(self._masked_model, self._rng, 0.5)
    param_len = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].size
    half_full = param_len / 2
    with self.subTest(name='random_mask_values'):
      # Random masking is stochastic; allow a generous tolerance band.
      self.assertBetween(
          jnp.sum(mask['MaskedModule_0']['kernel']), 0.66 * half_full,
          1.33 * half_full)

  def test_simple_mask_one_layer(self):
    """Tests generation of a simple mask."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(self._masked_model.params['MaskedModule_0']
                          ['unmasked']['kernel'].shape),
            'bias':
                None,
        }
    }
    gen_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
    result, _ = jax.tree_flatten(
        jax.tree_util.tree_map(lambda x, *xs: (x == xs[0]).all(), mask,
                               gen_mask))
    self.assertTrue(all(result))

  def test_simple_mask_two_layer(self):
    """Tests generation of a simple mask."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0']
                          ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_1']
                          ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
    }
    gen_mask = masked.simple_mask(self._masked_model_twolayer, jnp.zeros,
                                  ['kernel'])
    result, _ = jax.tree_flatten(
        jax.tree_util.tree_map(lambda x, *xs: (x == xs[0]).all(), mask,
                               gen_mask))
    self.assertTrue(all(result))

  def test_shuffled_mask_neuron_mask_sparsity_empty(self):
    """Tests shuffled neuron mask generation, for 0% sparsity."""
    mask = masked.shuffled_neuron_mask(self._masked_model, self._rng, 0.0)
    with self.subTest(name='shuffled_neuron_empty_mask'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_neuron_empty_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())
    with self.subTest(name='shuffled_neuron_empty_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='shuffled_neuron_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
    with self.subTest(name='shuffled_neuron_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_shuffled_mask_neuron_mask_sparsity_half_full(self):
    """Tests shuffled mask generation, for a half-full mask."""
    mask = masked.shuffled_neuron_mask(self._masked_model, self._rng, 0.5)
    param_len = len(
        self._masked_model.params['MaskedModule_0']['unmasked']['kernel'][:, 0])
    mask_sum = jnp.sum(mask['MaskedModule_0']['kernel'][:, 0])
    with self.subTest(name='shuffled_mask_values'):
      # Check that single neuron has the correct sparsity.
      self.assertEqual(mask_sum, param_len // 2)
    with self.subTest(name='shuffled_mask_rows_different'):
      # Check that two rows are different.
      self.assertFalse(
          jnp.isclose(mask['MaskedModule_0']['kernel'][:, 0],
                      mask['MaskedModule_0']['kernel'][:, 1]).all())

  def test_symmetric_mask_sparsity_empty(self):
    """Tests symmetric mask generation, for 0% sparsity."""
    mask = masked.symmetric_mask(self._masked_model, self._rng, 0.0)
    # Fixed: subtest name previously copy-pasted from the shuffled-neuron test.
    with self.subTest(name='symmetric_empty_mask'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='symmetric_empty_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())
    with self.subTest(name='symmetric_empty_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='symmetric_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
    with self.subTest(name='symmetric_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_symmetric_mask_sparsity_half_full(self):
    """Tests symmetric mask generation, for a half-full mask."""
    mask = masked.symmetric_mask(self._masked_model, self._rng, 0.5)
    param_len = len(
        self._masked_model.params['MaskedModule_0']['unmasked']['kernel'][:, 0])
    mask_sum = jnp.sum(mask['MaskedModule_0']['kernel'][:, 0])
    with self.subTest(name='symmetric_mask_values'):
      # Check that single neuron has the correct sparsity.
      self.assertEqual(mask_sum, param_len // 2)
    # Fixed: subtest name previously said "different" while asserting sameness.
    with self.subTest(name='symmetric_mask_rows_same'):
      # Check that two rows are same.
      self.assertTrue(
          jnp.isclose(mask['MaskedModule_0']['kernel'][:, 0],
                      mask['MaskedModule_0']['kernel'][:, 1]).all())

  def test_propagate_masks_ablated_neurons_one_layer(self):
    """Tests mask propagation on a single layer model."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jax.random.normal(
                    self._rng,
                    self._masked_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape,
                    dtype=jnp.float32),
            'bias':
                None,
        },
    }
    refined_mask = masked.propagate_masks(mask)
    # Since this is a single layer, should not affect mask at all.
    self.assertTrue((mask['MaskedModule_0']['kernel'] ==
                     refined_mask['MaskedModule_0']['kernel']).all())

  def test_propagate_masks_ablated_neurons_two_layers(self):
    """Tests mask propagation on a two-layer model."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0']
                          ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.ones(self._masked_model_twolayer.params['MaskedModule_1']
                         ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
    }
    refined_mask = masked.propagate_masks(mask)
    with self.subTest(name='layer_1'):
      self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())
    # Since layer 1 is all zero, layer 2 is also effectively zero.
    with self.subTest(name='layer_2'):
      self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())

  def test_propagate_masks_ablated_neurons_two_layers_nonmasked(self):
    """Tests mask propagation where previous layer is not masked."""
    mask = {
        'Dense_0': {
            'kernel': None,
            'bias': None,
        },
        'MaskedModule_1': {
            'kernel':
                jax.random.normal(
                    self._rng,
                    self._masked_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape,
                    dtype=jnp.float32),
            'bias':
                None,
        },
    }
    refined_mask = masked.propagate_masks(mask)
    with self.subTest(name='layer_1'):
      self.assertIsNone(refined_mask['Dense_0']['kernel'])
    # Fixed: the previous comment here was copy-pasted from the two-layer
    # test; layer 1 is unmasked, so layer 2's mask must be left untouched.
    with self.subTest(name='layer_2'):
      # Since this is a single masked layer, should not affect mask at all.
      self.assertTrue((mask['MaskedModule_1']['kernel'] ==
                       refined_mask['MaskedModule_1']['kernel']).all())

  def test_propagate_masks_ablated_neurons_one_conv_layer(self):
    """Tests mask propagation on a single layer model."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jax.random.normal(
                    self._rng,
                    self._masked_conv_model.params['MaskedModule_0']['unmasked']
                    ['kernel'].shape,
                    dtype=jnp.float32),
            'bias':
                None,
        },
    }
    refined_mask = masked.propagate_masks(mask)
    # Since this is a single layer, should not affect mask at all.
    self.assertTrue((mask['MaskedModule_0']['kernel'] ==
                     refined_mask['MaskedModule_0']['kernel']).all())

  def test_propagate_masks_ablated_neurons_two_conv_layers(self):
    """Tests mask propagation on a two-layer convolutional model."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(
                    self._masked_conv_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.ones(
                    self._masked_conv_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
    }
    refined_mask = masked.propagate_masks(mask)
    with self.subTest(name='layer_1'):
      self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())
    # Since layer 1 is all zero, layer 2 is also effectively zero.
    with self.subTest(name='layer_2'):
      self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())

  def test_propagate_masks_ablated_neurons_three_conv_fc_layers(self):
    """Tests mask propagation on a three-layer convolutional model with dense."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(self._masked_conv_fc_model_threelayer
                          .params['MaskedModule_0']['unmasked']['kernel'].shape
                         ),
            'bias':
                None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.ones(self._masked_conv_fc_model_threelayer
                         .params['MaskedModule_1']['unmasked']['kernel'].shape),
            'bias':
                None,
        },
        'MaskedModule_2': {
            'kernel':
                jnp.ones(self._masked_conv_fc_model_threelayer
                         .params['MaskedModule_2']['unmasked']['kernel'].shape),
            'bias':
                None,
        },
    }
    refined_mask = masked.propagate_masks(mask)
    with self.subTest(name='layer_1'):
      self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())
    # Since layer 1 is all zero, layer 2 is also effectively zero.
    with self.subTest(name='layer_2'):
      self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())
    # Since layer 2 is all zero, layer 3 is also effectively zero.
    with self.subTest(name='layer_3'):
      self.assertTrue((refined_mask['MaskedModule_2']['kernel'] == 0).all())

  def test_propagate_masks_ablated_neurons_mixed_conv_dense_layers(self):
    """Tests mask propagation on a two-layer convolutional/dense model."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(
                    self._masked_mixed_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.ones(
                    self._masked_mixed_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
    }
    with self.assertRaisesRegex(
        ValueError, 'propagate_masks requires knowledge of the spatial '
        'dimensions of the previous layer. Use a functionally equivalent '
        'conv. layer in place of a dense layer in a model with a mixed '
        'conv/dense setting.'):
      masked.propagate_masks(mask)

  def test_mask_layer_sparsity_zero_mask(self):
    """Tests mask calculation with a zeroed mask."""
    zero_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
    self.assertEqual(
        masked.mask_layer_sparsity(zero_mask['MaskedModule_0']), 0.)

  def test_mask_layer_sparsity_half_mask(self):
    """Tests mask calculation with a half-filled mask."""
    half_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)
    self.assertAlmostEqual(
        masked.mask_layer_sparsity(half_mask['MaskedModule_0']), 0.5)

  def test_mask_layer_sparsity_ones_mask(self):
    """Tests mask calculation with a mask full of ones."""
    one_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
    self.assertEqual(
        masked.mask_layer_sparsity(one_mask['MaskedModule_0']), 1.)

  def test_mask_sparsity_zero_mask(self):
    """Tests mask calculation with a zeroed mask."""
    zero_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
    self.assertEqual(masked.mask_sparsity(zero_mask), 0.)

  def test_mask_sparsity_ones_mask(self):
    """Tests mask calculation with a mask full of ones."""
    one_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
    self.assertEqual(masked.mask_sparsity(one_mask), 1.)

  def test_mask_sparsity_mixed_mask(self):
    """Tests mask calculation with a mask different sparsity masked layers."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(
                    self._masked_conv_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.ones(
                    self._masked_conv_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape),
            'bias':
                None,
        },
    }
    mask_sparsity = masked.mask_sparsity(mask)
    true_sparsity = self._masked_conv_model_twolayer.params['MaskedModule_1'][
        'unmasked']['kernel'].size / (
            self._masked_conv_model_twolayer.params['MaskedModule_0']
            ['unmasked']['kernel'].size + self._masked_conv_model_twolayer
            .params['MaskedModule_1']['unmasked']['kernel'].size)
    self.assertAlmostEqual(mask_sparsity, 1.0 - true_sparsity)

  @parameterized.parameters(
      # Simple masked 1-layer model.
      (1,),
      # Simple masked 2-layer model.
      (2,),
      # Simple masked 10-layer model.
      (10,),
  )
  def test_generate_model_masks_depth_only(self, depth):
    mask = masked.generate_model_masks(depth)
    with self.subTest(name='test_model_mask_length'):
      self.assertLen(mask, depth)
    for i in range(depth):
      with self.subTest(name=f'test_model_mask_value_layer_{i}'):
        self.assertIsNone(mask[f'MaskedModule_{i}'])

  @parameterized.parameters(
      # Simple masked 1-layer model, no masked indices.
      (1, []),
      # Simple masked 2-layer model, second layer masked.
      (2, (1,)),
      # Simple masked 10-layer model, 4 layers masked.
      (10, (1, 2, 3, 9)),
  )
  def test_generate_model_masks_indices(self, depth, indices):
    mask = masked.generate_model_masks(depth, None, indices)
    with self.subTest(name='test_model_mask_length'):
      self.assertLen(mask, len(indices))
    for i in indices:
      with self.subTest(name=f'test_model_mask_value_layer_{i}'):
        self.assertIsNone(mask[f'MaskedModule_{i}'])

  @parameterized.parameters(
      # Existing 1-layer mask.
      (1, {'MaskedModule_0': np.ones(1)}, None),
      (2, {'MaskedModule_0': np.ones(1),
           'MaskedModule_1': np.ones(1)}, None),
      # Existing 2-layer mask, only using one due to mask indices.
      (2, {'MaskedModule_0': np.ones(1),
           'MaskedModule_1': np.ones(1),}, (1,)),
  )
  def test_generate_model_masks_existing_mask(self, depth, existing_mask,
                                              indices):
    mask = masked.generate_model_masks(depth, existing_mask, indices)
    # Need to differentiate from empty sequence by explicitly checking is None.
    if indices is None:
      indices = range(depth)
    with self.subTest(name='test_model_mask_length'):
      self.assertLen(mask, len(indices))
    for i in indices:
      with self.subTest(name=f'test_model_mask_value_layer_{i}'):
        self.assertIsNotNone(mask[f'MaskedModule_{i}'])
    # Ensure existing mask layers that aren't in indices aren't in output.
    for i in range(depth):
      if i not in indices:
        with self.subTest(
            name=f'test_model_mask_only_allowed_indices_layer_{i}'):
          self.assertNotIn(f'MaskedModule_{i}', mask)

  def test_generate_model_masks_invalid_depth_zero(self):
    with self.assertRaisesWithLiteralMatch(ValueError,
                                           'Invalid model depth: 0'):
      masked.generate_model_masks(0)

  def test_generate_model_masks_invalid_index_toohigh(self):
    with self.assertRaisesWithLiteralMatch(
        ValueError, 'Invalid indices for given depth (2): (1, 2)'):
      masked.generate_model_masks(2, None, (1, 2))

  def test_generate_model_masks_invalid_index_negative(self):
    with self.assertRaisesWithLiteralMatch(
        ValueError, 'Invalid indices for given depth (2): (-1, 2)'):
      masked.generate_model_masks(2, None, (-1, 2))

  def test_shuffled_neuron_no_input_ablation_mask_invalid_model(self):
    """Tests shuffled mask with model containing no masked layers."""
    with self.assertRaisesRegex(
        ValueError, 'Model does not support masking, i.e. no layers are '
        'wrapped by a MaskedModule.'):
      masked.shuffled_neuron_no_input_ablation_mask(self._unmasked_model,
                                                    self._rng, 0.5)

  def test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity(self):
    """Tests shuffled mask with invalid sparsity."""
    with self.subTest(name='sparsity_too_small'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, -0.5, is not in range \[0, 1\]'):
        masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                      self._rng, -0.5)
    with self.subTest(name='sparsity_too_large'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, 1.5, is not in range \[0, 1\]'):
        masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                      self._rng, 1.5)

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):
    """Tests shuffled mask generation, for 100% sparsity."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                         self._rng, 1.0)
    with self.subTest(name='shuffled_full_mask'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_full_mask_values'):
      self.assertEqual(jnp.count_nonzero(mask['MaskedModule_0']['kernel']),
                       jnp.prod(jnp.array(self._input_dimensions)))
    with self.subTest(name='shuffled_full_no_input_ablation'):
      # Check no row (neurons are columns) is completely ablated.
      self.assertTrue((jnp.count_nonzero(
          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())
    with self.subTest(name='shuffled_full_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='shuffled_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty(self):
    """Tests shuffled mask generation, for 0% sparsity."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                         self._rng, 0.0)
    with self.subTest(name='shuffled_empty_mask'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_empty_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())
    with self.subTest(name='shuffled_empty_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    masked_output = self._masked_model(self._input, mask=mask)
    with self.subTest(name='shuffled_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())
    with self.subTest(name='shuffled_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self):
    """Tests shuffled mask generation, for a half-full mask."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                         self._rng, 0.5)
    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].shape
    with self.subTest(name='shuffled_mask_values'):
      self.assertEqual(
          jnp.sum(mask['MaskedModule_0']['kernel']),
          param_shape[0]//2 * param_shape[1])
    with self.subTest(name='shuffled_half_no_input_ablation'):
      # Check no row (neurons are columns) is completely ablated.
      self.assertTrue((jnp.count_nonzero(
          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_quarter_full(self):
    """Tests shuffled mask generation, for a quarter-sparse mask."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                         self._rng, 0.25)
    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].shape
    with self.subTest(name='shuffled_mask_values'):
      self.assertEqual(
          jnp.sum(mask['MaskedModule_0']['kernel']),
          0.75 * param_shape[0] * param_shape[1])
    with self.subTest(name='shuffled_half_no_input_ablation'):
      # Check no row (neurons are columns) is completely ablated.
      self.assertTrue((jnp.count_nonzero(
          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer(self):
    """Tests shuffled mask generation for two layers, and 100% sparsity."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(
        self._masked_model_twolayer, self._rng, 1.0)
    with self.subTest(name='shuffled_full_mask_layer1'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_full_mask_values_layer1'):
      self.assertEqual(jnp.count_nonzero(mask['MaskedModule_0']['kernel']),
                       jnp.prod(jnp.array(self._input_dimensions)))
    with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])
    with self.subTest(name='shuffled_full_no_input_ablation_layer1'):
      # Check no row (neurons are columns) is completely ablated.
      self.assertTrue((jnp.count_nonzero(
          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())
    with self.subTest(name='shuffled_full_mask_layer2'):
      self.assertIn('MaskedModule_1', mask)
    with self.subTest(name='shuffled_full_mask_values_layer2'):
      self.assertEqual(jnp.count_nonzero(mask['MaskedModule_1']['kernel']),
                       jnp.prod(MaskedTwoLayerDense.NUM_FEATURES[0]))
    with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
      self.assertIsNone(mask['MaskedModule_1']['bias'])
    with self.subTest(name='shuffled_full_no_input_ablation_layer2'):
      # Note: check no *inputs* are ablated, and inputs < num_neurons.
      self.assertEqual(
          jnp.sum(jnp.count_nonzero(mask['MaskedModule_1']['kernel'], axis=0)),
          MaskedTwoLayerDense.NUM_FEATURES[0])
    masked_output = self._masked_model_twolayer(self._input, mask=mask)
    with self.subTest(name='shuffled_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output_twolayer.shape)

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer(self):
    """Tests shuffled mask generation for two layers, for 0% sparsity."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(
        self._masked_model_twolayer, self._rng, 0.0)
    with self.subTest(name='shuffled_empty_mask_layer1'):
      self.assertIn('MaskedModule_0', mask)
    with self.subTest(name='shuffled_empty_mask_values_layer1'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())
    with self.subTest(name='shuffled_empty_mask_layer2'):
      self.assertIn('MaskedModule_1', mask)
    with self.subTest(name='shuffled_empty_mask_values_layer2'):
      self.assertTrue((mask['MaskedModule_1']['kernel'] == 1).all())
    masked_output = self._masked_model_twolayer(self._input, mask=mask)
    with self.subTest(name='shuffled_empty_dense_values'):
      self.assertTrue(
          jnp.isclose(masked_output, self._unmasked_output_twolayer).all())
    with self.subTest(name='shuffled_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output_twolayer.shape)
# Run the absl test runner when this module is executed as a script.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/pruning.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for pruning FLAX masked models."""
from collections import abc
from typing import Any, Callable, Mapping, Optional, Union
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import masked
def weight_magnitude(weights):
  """Computes element-wise magnitude saliencies for a weight array.

  Args:
    weights: An array of layer weights.

  Returns:
    An array of the same shape holding the absolute value of each weight.
  """
  return jnp.abs(weights)
def prune(
    model: flax.deprecated.nn.Model,
    pruning_rate: Union[float, Mapping[str, float]],
    saliency_fn: Callable[[jnp.ndarray], jnp.ndarray] = weight_magnitude,
    mask: Optional[Mapping[str, Any]] = None,
    compare_fn: Callable[..., jnp.ndarray] = jnp.greater):
  """Returns a mask for a model where the params in each layer are pruned using a saliency function.

  Args:
    model: The model to create a pruning mask for.
    pruning_rate: The fraction of lowest magnitude saliency weights that are
      pruned. If a float, the same rate is used for all layers, otherwise if it
      is a mapping, it must contain a rate for all masked layers in the model.
    saliency_fn: A function that returns a float number used to rank
      the importance of individual weights in the layer.
    mask: If the model has an existing mask, the mask will be applied before
      pruning the model.
    compare_fn: A pairwise operator to compare saliency with threshold, and
      return True if the saliency indicates the value should not be masked.

  Returns:
    A pruned mask for the given model.
  """
  # No (or empty) existing mask: start from a fully-dense all-ones mask.
  if not mask:
    mask = masked.simple_mask(model, jnp.ones, masked.WEIGHT_PARAM_NAMES)
  # A scalar rate is expanded to a per-layer mapping covering every masked
  # layer, so the loop below only deals with the mapping form.
  if not isinstance(pruning_rate, abc.Mapping):
    pruning_rate_dict = {}
    for param_name, _ in masked.iterate_mask(mask):
      # Get the layer name from the parameter's full name/path.
      layer_name = param_name.split('/')[-2]
      pruning_rate_dict[layer_name] = pruning_rate
    pruning_rate = pruning_rate_dict
  for param_path, param_mask in masked.iterate_mask(mask):
    # Paths look like '.../<layer_name>/<param_name>'.
    split_param_path = param_path.split('/')
    layer_name = split_param_path[-2]
    param_name = split_param_path[-1]
    # If we don't have a pruning rate for the given layer, don't mask it.
    if layer_name in pruning_rate and mask[layer_name][param_name] is not None:
      param_value = model.params[layer_name][
          masked.MaskedModule.UNMASKED][param_name]
      # Here any existing mask is first applied to weight matrix.
      # Note: need to check explicitly is not None for np array.
      if param_mask is not None:
        saliencies = saliency_fn(param_mask * param_value)
      else:
        saliencies = saliency_fn(param_value)
      # TODO: Use partition here (partial sort) instead of sort,
      # since it's O(N), not O(N log N), however JAX doesn't support it.
      sorted_param = jnp.sort(jnp.abs(saliencies.flatten()))
      # Figure out the weight magnitude threshold: the value below which a
      # pruning_rate fraction of this layer's saliencies falls.
      threshold_index = jnp.round(pruning_rate[layer_name] *
                                  sorted_param.size).astype(jnp.int32)
      threshold = sorted_param[threshold_index]
      # Weights whose saliency passes compare_fn against the threshold
      # survive (mask value 1); all others are pruned (mask value 0).
      mask[layer_name][param_name] = jnp.array(
          compare_fn(saliencies, threshold), dtype=jnp.int32)
  return mask
================================================
FILE: rigl/experimental/jax/pruning/pruning_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.pruning.pruning."""
from typing import Mapping, Optional, Sequence
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.pruning import pruning
class MaskedDense(flax.deprecated.nn.Module):
  """Single-layer Dense Masked Network."""

  # Number of neurons in the single dense layer.
  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            mask = None):
    """Applies one masked dense layer to the flattened inputs."""
    inputs = inputs.reshape(inputs.shape[0], -1)
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Dense,
        # Explicit `is not None` check (consistent with the conv models in
        # this file): relying on truthiness would silently drop an
        # explicitly-passed empty mask mapping.
        mask=mask['MaskedModule_0'] if mask is not None else None)
class MaskedTwoLayerDense(flax.deprecated.nn.Module):
  """Two-layer Dense Masked Network."""

  # Number of neurons in each of the two dense layers.
  NUM_FEATURES: Sequence[int] = (32, 64)

  def apply(self,
            inputs,
            mask = None):
    """Applies two masked dense layers to the flattened inputs."""
    inputs = inputs.reshape(inputs.shape[0], -1)
    inputs = masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES[0],
        wrapped_module=flax.deprecated.nn.Dense,
        # Explicit `is not None` check (consistent with the conv models in
        # this file): relying on truthiness would silently drop an
        # explicitly-passed empty mask mapping.
        mask=mask['MaskedModule_0'] if mask is not None else None)
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES[1],
        wrapped_module=flax.deprecated.nn.Dense,
        mask=mask['MaskedModule_1'] if mask is not None else None)
class MaskedConv(flax.deprecated.nn.Module):
  """Single-layer Conv Masked Network."""

  # Number of filters in the single convolutional layer.
  NUM_FEATURES: int = 32

  def apply(self,
            inputs,
            mask = None):
    """Applies one masked 3x3 convolutional layer to the inputs."""
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES,
        kernel_size=(3, 3),
        wrapped_module=flax.deprecated.nn.Conv,
        # Pass this layer's sub-mask through when a model mask is provided.
        mask=mask['MaskedModule_0'] if mask is not None else None)
class MaskedTwoLayerConv(flax.deprecated.nn.Module):
  """Two-layer Conv Masked Network."""

  # Number of filters in each of the two convolutional layers.
  NUM_FEATURES: Sequence[int] = (16, 32)

  def apply(self,
            inputs,
            mask = None):
    """Applies a masked 5x5 conv layer followed by a masked 3x3 conv layer."""
    inputs = masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES[0],
        kernel_size=(5, 5),
        wrapped_module=flax.deprecated.nn.Conv,
        # Each layer receives its own sub-mask when a model mask is provided.
        mask=mask['MaskedModule_0'] if mask is not None else None)
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES[1],
        kernel_size=(3, 3),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=mask['MaskedModule_1'] if mask is not None else None)
class PruningTest(absltest.TestCase):
  """Tests the flax layer pruning module."""

  def setUp(self):
    """Creates dense/conv, one- and two-layer masked models for the tests."""
    super().setUp()
    self._rng = jax.random.PRNGKey(42)
    self._batch_size = 2
    # ((shape), dtype) tuple, as expected by Module.init_by_shape.
    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
    self._input = jnp.ones(*self._input_shape)
    _, initial_params = MaskedDense.init_by_shape(self._rng,
                                                  (self._input_shape,))
    self._masked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)
    _, initial_params = MaskedTwoLayerDense.init_by_shape(
        self._rng, (self._input_shape,))
    self._masked_model_twolayer = flax.deprecated.nn.Model(
        MaskedTwoLayerDense, initial_params)
    _, initial_params = MaskedConv.init_by_shape(self._rng,
                                                 (self._input_shape,))
    self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv,
                                                       initial_params)
    _, initial_params = MaskedTwoLayerConv.init_by_shape(
        self._rng, (self._input_shape,))
    self._masked_conv_model_twolayer = flax.deprecated.nn.Model(
        MaskedTwoLayerConv, initial_params)

  def test_prune_single_layer_dense_no_mask(self):
    """Tests pruning of single dense layer without an existing mask."""
    pruned_mask = pruning.prune(self._masked_model, 0.5)
    mask_sparsity = masked.mask_sparsity(pruned_mask)
    with self.subTest(name='test_mask_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_sparsity'):
      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)

  def test_prune_single_layer_local_pruning(self):
    """Test pruning of model with a single layer, and local pruning schedule."""
    pruned_mask = pruning.prune(self._masked_model, {
        'MaskedModule_0': 0.5,
    })
    mask_sparsity = masked.mask_sparsity(pruned_mask)
    with self.subTest(name='test_mask_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_sparsity'):
      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)

  def test_prune_single_layer_dense_with_mask(self):
    """Tests pruning of single dense layer with an existing mask."""
    # Existing mask is already sparser (0.95) than the pruning rate (0.5), so
    # the resulting sparsity stays at the existing mask's 0.95.
    pruned_mask = pruning.prune(
        self._masked_model,
        0.5,
        mask=masked.shuffled_mask(self._masked_model, self._rng, 0.95))
    mask_sparsity = masked.mask_sparsity(pruned_mask)
    with self.subTest(name='test_mask_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_sparsity'):
      self.assertAlmostEqual(mask_sparsity, 0.95, places=3)

  def test_prune_two_layers_dense_no_mask(self):
    """Tests pruning of model with two dense layers without an existing mask."""
    pruned_mask = pruning.prune(self._masked_model_twolayer, 0.5)
    mask_sparsity = masked.mask_sparsity(pruned_mask)
    with self.subTest(name='test_mask_layer1_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_layer2_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel'])
    with self.subTest(name='test_mask_sparsity'):
      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)

  def test_prune_two_layer_local_pruning_rate(self):
    """Test pruning of model with two layers, and a local pruning schedule."""
    # Only layer 1 has a pruning rate: layer 0 must remain fully dense.
    pruned_mask = pruning.prune(self._masked_model_twolayer, {
        'MaskedModule_1': 0.5,
    })
    mask_layer_0_sparsity = masked.mask_sparsity(pruned_mask['MaskedModule_0'])
    mask_layer_1_sparsity = masked.mask_sparsity(pruned_mask['MaskedModule_1'])
    with self.subTest(name='test_mask_layer1_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_layer2_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel'])
    with self.subTest(name='test_mask_layer_0_sparsity'):
      self.assertEqual(mask_layer_0_sparsity, 0.)
    with self.subTest(name='test_mask_layer_1_sparsity'):
      self.assertAlmostEqual(mask_layer_1_sparsity, 0.5, places=3)

  def test_prune_one_layer_conv_no_mask(self):
    """Tests pruning of model with one conv. layer without an existing mask."""
    pruned_mask = pruning.prune(self._masked_conv_model, 0.5)
    mask_sparsity = masked.mask_sparsity(pruned_mask)
    with self.subTest(name='test_mask_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_sparsity'):
      # Conv kernels have few weights, so rounding limits achievable
      # precision of the sparsity.
      self.assertAlmostEqual(mask_sparsity, 0.5, places=1)

  def test_prune_one_layer_conv_with_mask(self):
    """Tests pruning of model with one conv. layer with an existing mask."""
    # Bug fix: the existing mask must be generated from the conv model under
    # test; previously this used self._masked_model (the dense model), whose
    # shape-mismatched mask only "worked" via accidental broadcasting.
    pruned_mask = pruning.prune(
        self._masked_conv_model,
        0.5,
        mask=masked.shuffled_mask(self._masked_conv_model, self._rng, 0.95))
    mask_sparsity = masked.mask_sparsity(pruned_mask)
    with self.subTest(name='test_mask_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_sparsity'):
      # As above, the small conv kernel limits sparsity precision.
      self.assertAlmostEqual(mask_sparsity, 0.95, places=1)

  def test_prune_two_layer_conv_no_mask(self):
    """Tests pruning of model with two conv. layer without an existing mask."""
    pruned_mask = pruning.prune(self._masked_conv_model_twolayer, 0.5)
    mask_sparsity = masked.mask_sparsity(pruned_mask)
    with self.subTest(name='test_mask_layer1_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])
    with self.subTest(name='test_mask_layer2_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel'])
    with self.subTest(name='test_mask_sparsity'):
      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)
# Standard absltest entry point.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/symmetry.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Code for analyzing symmetries in NN."""
import functools
import math
import operator
from typing import Dict, Optional, Union
import jax.numpy as jnp
import numpy as np
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.utils import utils
def count_permutations_mask_layer(
    mask_layer,
    next_mask_layer = None,
    parameter_key = 'kernel'):
  """Calculates the number of permutations for a layer, given binary masks.

  Args:
    mask_layer: The binary weight mask of a dense/conv layer, where last
      dimension is number of neurons/filters.
    next_mask_layer: The binary weight mask of the following dense/conv layer,
      or None if this is the last layer.
    parameter_key: The name of the parameter to count the permutations of in
      each layer.

  Returns:
    A dictionary with stats on the permutation structure of a mask, including
    the number of symmetric permutations of the mask, number of unique mask
    columns, count of the zeroed out (structurally pruned) neurons, and total
    number of neurons/filters.
  """
  # Have to check 'is None' since mask_layer[parameter_key] is jnp.array.
  if not mask_layer or parameter_key not in mask_layer or mask_layer[
      parameter_key] is None:
    # Unmasked layer: only the identity permutation, no neurons counted.
    return {
        'permutations': 1,
        'zeroed_neurons': 0,
        'total_neurons': 0,
        'unique_neurons': 0,
    }
  mask = mask_layer[parameter_key]
  num_neurons = mask.shape[-1]
  # Initialize with stats for an empty (fully-ablated) mask.
  mask_stats = {
      'permutations': 0,
      'zeroed_neurons': num_neurons,
      'total_neurons': num_neurons,
      'unique_neurons': 0,
  }
  # Re-shape masks as 1D columns per neuron, in case they are >1D (e.g. conv).
  connection_mask = mask.reshape(-1, num_neurons)
  # Only consider non-zero columns (in JAX neurons/filters are last index).
  non_zero_neurons = ~jnp.all(connection_mask == 0, axis=0)
  # Count only zeroed neurons in the current layer.
  zeroed_count = num_neurons - jnp.count_nonzero(non_zero_neurons)
  # Special case where all neurons in current layer are ablated.
  if zeroed_count == num_neurons:
    return mask_stats
  # Have to check is None since next_mask_layer[parameter_key] is jnp.array.
  if next_mask_layer and parameter_key in next_mask_layer and next_mask_layer[
      parameter_key] is not None:
    next_mask = next_mask_layer[parameter_key]
    # Re-shape masks as 1D, in case they are 2D (e.g. convolutional).
    next_connection_mask = next_mask.T.reshape(-1, num_neurons)
    # Update with neurons that are non-zero in outgoing connections too.
    non_zero_neurons &= ~jnp.all(next_connection_mask == 0, axis=0)
    # Remove rows corresponding to neurons that are ablated.
    next_connection_mask = next_connection_mask[:, non_zero_neurons]
    connection_mask = connection_mask[:, non_zero_neurons]
    # Combine the outgoing and incoming masks in one vector per-neuron.
    connection_mask = jnp.concatenate(
        (connection_mask, next_connection_mask), axis=0)
  else:
    connection_mask = connection_mask[:, non_zero_neurons]
  # Effectively no connections between these two layers.
  if not connection_mask.size:
    return mask_stats
  # Note: np.unique not implemented in JAX numpy yet.
  _, unique_counts = np.unique(connection_mask, axis=-1, return_counts=True)
  # Convert from device array.
  mask_stats['zeroed_neurons'] = int(zeroed_count)
  # Neurons with identical mask columns are interchangeable, so the number of
  # symmetric permutations is the product of factorials of the column
  # multiplicities. Use math.factorial directly: the np.math alias was
  # deprecated in NumPy 1.25 and removed in NumPy 2.0.
  mask_stats['permutations'] = functools.reduce(
      operator.mul, (math.factorial(t) for t in unique_counts))
  mask_stats['unique_neurons'] = len(unique_counts)
  return mask_stats
def count_permutations_mask(mask):
  """Computes permutation statistics for a whole model's mask.

  Args:
    mask: Model masks to check, similar to Model.params.

  Returns:
    A dictionary with stats on the permutation structure of the mask: number
    of symmetric permutations, number of unique mask columns, count of zeroed
    (structurally pruned) neurons, and total number of neurons/filters.
  """
  # Per-layer stats for each consecutive pair of layer masks; the final layer
  # is paired with None by pairwise_longest.
  # Note: I tried doing this with more_itertools.pairwise/itertools.chain, but
  # there is a type conflict in passing iterators of different types to
  # itertools.chain.
  layer_stats = [
      count_permutations_mask_layer(current, following)
      for current, following in utils.pairwise_longest(mask.values())
  ]
  # Neuron counts accumulate additively across layers.
  combined = {
      key: sum(stats[key] for stats in layer_stats)
      for key in ('total_neurons', 'unique_neurons', 'zeroed_neurons')
  }
  # Permutation counts combine multiplicatively across layers.
  combined['permutations'] = functools.reduce(
      operator.mul, (stats['permutations'] for stats in layer_stats))
  return combined
def get_mask_stats(mask):
  """Collects a set of summary statistics for a model mask.

  Args:
    mask: A model mask to calculate the statistics of.

  Returns:
    A dictionary, containing a set of mask statistics.
  """
  stats = count_permutations_mask(mask)
  permutations = stats['permutations']
  # Digit count of the (possibly huge) permutation count, plus a log10 value
  # that is well-defined even when the count is zero.
  stats['permutation_num_digits'] = len(str(permutations))
  stats['permutation_log10'] = math.log10(permutations + 1)
  stats['sparsity'] = masked.mask_sparsity(mask)
  return stats
================================================
FILE: rigl/experimental/jax/pruning/symmetry_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.pruning.symmetry."""
import functools
import math
import operator
from typing import Mapping, Optional, Sequence
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
import numpy as np
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.pruning import symmetry
class MaskedDense(flax.deprecated.nn.Module):
  """A masked network consisting of one dense layer.

  Attributes:
    NUM_FEATURES: The number of neurons in the single dense layer.
  """
  NUM_FEATURES: int = 16

  def apply(self,
            inputs,
            mask = None):
    flat_inputs = inputs.reshape(inputs.shape[0], -1)
    layer_mask = None if mask is None else mask['MaskedModule_0']
    return masked.MaskedModule(
        flat_inputs,
        features=self.NUM_FEATURES,
        wrapped_module=flax.deprecated.nn.Dense,
        mask=layer_mask)
class MaskedConv(flax.deprecated.nn.Module):
  """A masked network consisting of one convolutional layer.

  Attributes:
    NUM_FEATURES: The number of filters in the single conv layer.
  """
  NUM_FEATURES: int = 16

  def apply(self,
            inputs,
            mask = None):
    layer_mask = None if mask is None else mask['MaskedModule_0']
    return masked.MaskedModule(
        inputs,
        features=self.NUM_FEATURES,
        kernel_size=(3, 3),
        wrapped_module=flax.deprecated.nn.Conv,
        mask=layer_mask)
class MaskedTwoLayerDense(flax.deprecated.nn.Module):
  """A masked network of two dense layers with a ReLU in between.

  Attributes:
    NUM_FEATURES: The number of neurons in the dense layers.
  """
  NUM_FEATURES: Sequence[int] = (16, 32)

  def apply(self,
            inputs,
            mask = None):
    first_mask = None if mask is None else mask['MaskedModule_0']
    second_mask = None if mask is None else mask['MaskedModule_1']
    hidden = masked.MaskedModule(
        inputs.reshape(inputs.shape[0], -1),
        features=self.NUM_FEATURES[0],
        wrapped_module=flax.deprecated.nn.Dense,
        mask=first_mask)
    activations = flax.deprecated.nn.relu(hidden)
    return masked.MaskedModule(
        activations,
        features=self.NUM_FEATURES[1],
        wrapped_module=flax.deprecated.nn.Dense,
        mask=second_mask)
class SymmetryTest(parameterized.TestCase):
"""Tests symmetry analysis methods."""
def setUp(self):
  """Builds the dense, conv, and two-layer dense models used by the tests."""
  super().setUp()
  self._rng = jax.random.PRNGKey(42)
  self._batch_size = 2
  # ((shape), dtype) tuples as expected by Module.init_by_shape; the conv
  # model takes NHWC input, the dense models take flattened input.
  self._input_shape = ((self._batch_size, 2, 2, 1), jnp.float32)
  self._flat_input_shape = ((self._batch_size, 2 * 2 * 1), jnp.float32)
  _, initial_params = MaskedDense.init_by_shape(self._rng,
                                                (self._flat_input_shape,))
  self._masked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)
  _, initial_params = MaskedConv.init_by_shape(self._rng,
                                               (self._input_shape,))
  self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv,
                                                     initial_params)
  _, initial_params = MaskedTwoLayerDense.init_by_shape(
      self._rng, (self._flat_input_shape,))
  self._masked_two_layer_model = flax.deprecated.nn.Model(
      MaskedTwoLayerDense, initial_params)
def test_count_permutations_layer_mask_full(self):
  """Tests count of weight permutations in a full mask."""
  # All-ones mask: every neuron's mask column is identical, so all
  # NUM_FEATURES neurons are interchangeable -> NUM_FEATURES! permutations.
  mask_layer = {
      'kernel':
          jnp.ones(self._masked_model.params['MaskedModule_0']['unmasked']
                   ['kernel'].shape),
  }
  stats = symmetry.count_permutations_mask_layer(mask_layer)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'],
                     math.factorial(MaskedDense.NUM_FEATURES))
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_layer_mask_empty(self):
  """Tests count of weight permutations in an empty mask."""
  # All-zeros mask: every neuron is structurally pruned, so there are no
  # unique columns and no permutations.
  mask_layer = {
      'kernel':
          jnp.zeros(self._masked_model.params['MaskedModule_0']['unmasked']
                    ['kernel'].shape),
  }
  stats = symmetry.count_permutations_mask_layer(mask_layer)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 0)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 0)
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_conv_layer_mask_full(self):
  """Tests count of weight permutations in a full mask for a conv. layer."""
  # All-ones conv mask: all filters are interchangeable -> NUM_FEATURES!.
  mask_layer = {
      'kernel':
          jnp.ones(self._masked_conv_model.params['MaskedModule_0']
                   ['unmasked']['kernel'].shape),
  }
  stats = symmetry.count_permutations_mask_layer(mask_layer)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'],
                     math.factorial(MaskedConv.NUM_FEATURES))
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)
def test_count_permutations_conv_layer_mask_empty(self):
  """Tests count of weight permutations in an empty mask for a conv. layer."""
  # All-zeros conv mask: every filter is structurally pruned.
  mask_layer = {
      'kernel':
          jnp.zeros(self._masked_conv_model.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape),
  }
  stats = symmetry.count_permutations_mask_layer(mask_layer)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 0)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 0)
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], MaskedConv.NUM_FEATURES)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)
def test_count_permutations_layer_mask_known_perm(self):
  """Tests count of weight permutations in a mask with known # permutations."""
  param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
      'kernel'].shape
  # Create two unique random mask rows (per-neuron mask columns).
  row_type_one = jax.random.bernoulli(
      self._rng, p=0.3, shape=(param_shape[0],)).astype(jnp.int32)
  row_type_two = jax.random.bernoulli(
      self._rng, p=0.9, shape=(param_shape[0],)).astype(jnp.int32)
  # Create mask by repeating the two unique rows: repeat_one copies of the
  # first column type, repeat_two of the second, giving
  # repeat_one! * repeat_two! symmetric permutations.
  repeat_one = param_shape[-1] // 3
  repeat_two = param_shape[-1] - repeat_one
  mask_layer = {'kernel': jnp.concatenate(
      (jnp.repeat(row_type_one[:, jnp.newaxis], repeat_one, axis=-1),
       jnp.repeat(row_type_two[:, jnp.newaxis], repeat_two, axis=-1)),
      axis=-1)}
  stats = symmetry.count_permutations_mask_layer(mask_layer)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 2)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'],
                     math.factorial(repeat_one) * math.factorial(repeat_two))
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], param_shape[-1])
def test_count_permutations_layer_mask_known_perm_zeros(self):
  """Tests count of weight permutations in a mask with zeroed neurons."""
  param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
      'kernel'].shape
  # One random non-zero column type plus an all-zero column type; the zero
  # columns count as ablated neurons and contribute no permutations.
  row_type_one = jax.random.bernoulli(
      self._rng, p=0.3, shape=(param_shape[0],)).astype(jnp.int32)
  row_type_two = jnp.zeros(shape=(param_shape[0],), dtype=jnp.int32)
  # Create mask by repeating the two unique rows.
  repeat_one = param_shape[-1] // 3
  repeat_two = param_shape[-1] - repeat_one
  mask_layer = {'kernel': jnp.concatenate(
      (jnp.repeat(row_type_one[:, jnp.newaxis], repeat_one, axis=-1),
       jnp.repeat(row_type_two[:, jnp.newaxis], repeat_two, axis=-1)),
      axis=-1)}
  stats = symmetry.count_permutations_mask_layer(mask_layer)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], math.factorial(repeat_one))
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], repeat_two)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], param_shape[-1])
def test_count_permutations_shuffled_full_mask(self):
  """Tests count of weight permutations on a generated fully-sparse mask."""
  # sparsity=1 ablates every neuron in the dense model's single layer.
  mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=1)
  stats = symmetry.count_permutations_mask(mask)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 0)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 0)
  with self.subTest(name='count_permutations_zeroed'):
    # The model under test is MaskedDense, so compare against its feature
    # count (previously referenced MaskedConv, which only passed because both
    # classes define NUM_FEATURES = 16).
    self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_shuffled_empty_mask(self):
  """Tests count of weight permutations on a generated empty mask."""
  # sparsity=0 produces an all-ones mask: all neurons interchangeable.
  mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=0)
  stats = symmetry.count_permutations_mask(mask)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)
  with self.subTest(name='count_permutations_permutations'):
    # The model under test is MaskedDense, so compare against its feature
    # count (previously referenced MaskedConv, which only passed because both
    # classes define NUM_FEATURES = 16).
    self.assertEqual(stats['permutations'],
                     math.factorial(MaskedDense.NUM_FEATURES))
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_mask_layer_twolayer_known_symmetric(self):
  """Tests count of permutations in a known mask with 2 permutations."""
  # Layer 0 has two identical columns (neurons 0 and 1) and one distinct
  # column; layer 1's outgoing connections keep neurons 0 and 1 identical,
  # so exactly 2! = 2 permutations remain.
  mask = {
      'MaskedModule_0': {
          'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T,
      },
      'MaskedModule_1': {
          'kernel': jnp.array(((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T,
      },
  }
  stats = symmetry.count_permutations_mask_layer(mask['MaskedModule_0'],
                                                 mask['MaskedModule_1'])
  with self.subTest(name='count_permutations_unique'):
    self.assertEqual(stats['unique_neurons'], 2)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 2)
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'],
                     mask['MaskedModule_0']['kernel'].shape[-1])
# Note: Can't pass jnp.array here since global, InitGoogle() not called yet.
# Each parameter tuple is (mask, unique, permutations, zeroed, total).
@parameterized.parameters(
    # Tests mask with 1 permutation only if both layers are considered.
    ({
        'MaskedModule_0': {
            'kernel': np.array(((1, 0), (1, 0), (0, 1))).T,
        },
        'MaskedModule_1': {
            'kernel':
                np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T,
        },
    }, 3, 1, 0, 3),
    # Tests mask zero count with an ablated neuron in first layer.
    ({
        'MaskedModule_0': {
            'kernel': np.array(((1, 0), (1, 0), (0, 0))).T,
        },
        'MaskedModule_1': {
            'kernel':
                np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T,
        },
    }, 2, 1, 1, 3),
    # Tests mask zero count with first layer completely ablated.
    ({
        'MaskedModule_0': {
            'kernel': np.array(((0, 0), (0, 0), (0, 0))).T,
        },
        'MaskedModule_1': {
            'kernel':
                np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T,
        },
    }, 0, 0, 3, 3),
    # Tests mask zero count with second layer completely ablated.
    ({
        'MaskedModule_0': {
            'kernel': np.array(((1, 0), (1, 0), (0, 1))).T,
        },
        'MaskedModule_1': {
            'kernel':
                np.array(((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))).T,
        },
    }, 0, 0, 3, 3),
    # Tests layer 1 permutation matrix mask, having only 1 permutation.
    ({
        'MaskedModule_0': {
            'kernel': np.array(((1, 0, 0), (0, 1, 0), (0, 0, 1))).T,
        },
        'MaskedModule_1': {
            'kernel':
                np.array(((1, 1, 1), (0, 1, 1), (1, 1, 1), (1, 1, 1))).T,
        },
    }, 3, 1, 0, 3),
)
def test_count_permutations_mask_layer_twolayer(self, mask, unique,
                                                permutations, zeroed, total):
  """Test mask permutations if both layers are considered.

  Args:
    mask: A two-layer model mask, keyed by masked module name.
    unique: Expected number of unique, non-ablated neuron mask columns.
    permutations: Expected number of symmetric permutations.
    zeroed: Expected number of zeroed (structurally pruned) neurons.
    total: Expected total neuron count of the first layer.
  """
  stats = symmetry.count_permutations_mask_layer(mask['MaskedModule_0'],
                                                 mask['MaskedModule_1'])
  with self.subTest(name='count_permutations_unique'):
    self.assertEqual(stats['unique_neurons'], unique)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], permutations)
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], zeroed)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], total)
def test_count_permutations_mask_full(self):
  """Tests count of weight permutations in a full mask."""
  mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
  stats = symmetry.count_permutations_mask(mask)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'],
                     math.factorial(MaskedDense.NUM_FEATURES))
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    # The model under test is MaskedDense, so compare against its feature
    # count (previously referenced MaskedConv, which only passed because both
    # classes define NUM_FEATURES = 16).
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_mask_bn_layer_full(self):
  """Tests count of permutations on a mask for model with non-masked layers."""
  # NOTE(review): this currently duplicates test_count_permutations_mask_full
  # on the single-layer dense model, though the docstring mentions non-masked
  # (e.g. batch-norm) layers — confirm the intended model under test.
  mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])
  stats = symmetry.count_permutations_mask(mask)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 1)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'],
                     math.factorial(MaskedDense.NUM_FEATURES))
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    # The model under test is MaskedDense, so compare against its feature
    # count (previously referenced MaskedConv, which only passed because both
    # classes define NUM_FEATURES = 16).
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_mask_empty(self):
  """Tests count of weight permutations in an empty mask."""
  mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])
  stats = symmetry.count_permutations_mask(mask)
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 0)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(stats['permutations'], 0)
  with self.subTest(name='count_permutations_zeroed'):
    # The model under test is MaskedDense, so compare against its feature
    # count (previously referenced MaskedConv, which only passed because both
    # classes define NUM_FEATURES = 16).
    self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)
def test_count_permutations_mask_twolayer_full(self):
  """Tests count of weight permutations in a full mask for 2 layers."""
  full_mask = masked.simple_mask(self._masked_two_layer_model, jnp.ones,
                                 ['kernel'])
  permutation_stats = symmetry.count_permutations_mask(full_mask)
  # With every weight unmasked, each layer contributes factorial(width)
  # permutations, and the per-layer counts multiply.
  layer_widths = MaskedTwoLayerDense.NUM_FEATURES
  expected_permutations = functools.reduce(
      operator.mul, (math.factorial(width) for width in layer_widths))
  with self.subTest(name='count_permutations_mask_unique'):
    self.assertEqual(permutation_stats['unique_neurons'], 2)
  with self.subTest(name='count_permutations_permutations'):
    self.assertEqual(permutation_stats['permutations'], expected_permutations)
  with self.subTest(name='count_permutations_zeroed'):
    self.assertEqual(permutation_stats['zeroed_neurons'], 0)
  with self.subTest(name='count_permutations_total'):
    self.assertEqual(permutation_stats['total_neurons'], sum(layer_widths))
def test_count_permutations_mask_twolayers_empty(self):
  """Tests count of weight permutations in an empty mask for 2 layers."""
  empty_mask = masked.simple_mask(self._masked_two_layer_model, jnp.zeros,
                                  ['kernel'])
  permutation_stats = symmetry.count_permutations_mask(empty_mask)
  total_neurons = sum(MaskedTwoLayerDense.NUM_FEATURES)
  expectations = (
      ('count_permutations_mask_unique', 'unique_neurons', 0),
      ('count_permutations_permutations', 'permutations', 0),
      ('count_permutations_zeroed', 'zeroed_neurons', total_neurons),
      ('count_permutations_total', 'total_neurons', total_neurons),
  )
  for subtest_name, stat_key, expected_value in expectations:
    with self.subTest(name=subtest_name):
      self.assertEqual(permutation_stats[stat_key], expected_value)
def test_count_permutations_mask_twolayer_known_symmetric(self):
  """Tests count of permutations in a known mask with 4 permutations.

  Layer 0: neurons 0/1 share the pattern (1, 0), neuron 2 is distinct.
  Layer 1: neurons 1/2 share the pattern (0, 0, 1), neuron 3 is all-zero,
  giving 4 unique neurons, 1 zeroed neuron and 2! * 2! = 4 permutations.
  """
  mask = {
      'MaskedModule_0': {
          'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T
      },
      'MaskedModule_1': {
          'kernel': jnp.array(((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T
      }
  }
  stats = symmetry.count_permutations_mask(mask)
  with self.subTest(name='count_permutations_full_mask_unique'):
    self.assertEqual(stats['unique_neurons'], 4)
  with self.subTest(name='count_permutations_full_mask_permutations'):
    self.assertEqual(stats['permutations'], 4)
  with self.subTest(name='count_permutations_full_mask_zeroed'):
    self.assertEqual(stats['zeroed_neurons'], 1)
  # Subtest name fixed to lower-case 'count_...' for consistency with the
  # sibling subtests above (was 'Count_permutations_full_mask_total').
  with self.subTest(name='count_permutations_full_mask_total'):
    self.assertEqual(
        stats['total_neurons'], mask['MaskedModule_0']['kernel'].shape[-1] +
        mask['MaskedModule_1']['kernel'].shape[-1])
def test_count_permutations_mask_twolayer_known_non_symmetric(self):
  """Tests mask with 1 permutation only if both layers are considered."""
  mask = {
      'MaskedModule_0': {
          'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T
      },
      'MaskedModule_1': {
          'kernel': jnp.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T
      }
  }
  stats = symmetry.count_permutations_mask(mask)
  expected_total = (mask['MaskedModule_0']['kernel'].shape[-1] +
                    mask['MaskedModule_1']['kernel'].shape[-1])
  expectations = (
      ('count_permutations_unique', 'unique_neurons', 6),
      ('count_permutations_permutations', 'permutations', 1),
      ('count_permutations_zeroed', 'zeroed_neurons', 1),
      ('count_permutations_total', 'total_neurons', expected_total),
  )
  for subtest_name, stat_key, expected_value in expectations:
    with self.subTest(name=subtest_name):
      self.assertEqual(stats[stat_key], expected_value)
def test_get_mask_stats_keys_values(self):
  """Tests the returned dict has the required keys, and value types/ranges."""
  mask = {
      'MaskedModule_0': {
          'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T
      },
      'MaskedModule_1': {
          'kernel': jnp.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T
      }
  }
  mask_stats = symmetry.get_mask_stats(mask)
  expected_total = (mask['MaskedModule_0']['kernel'].shape[-1] +
                    mask['MaskedModule_1']['kernel'].shape[-1])
  # Each entry pairs a required key with the check applied to its value.
  checks = (
      ('sparsity', lambda v: self.assertBetween(v, 0.0, 1.0)),
      ('permutation_num_digits', lambda v: self.assertGreaterEqual(v, 0.0)),
      ('permutation_log10', lambda v: self.assertGreaterEqual(v, 0.0)),
      ('unique_neurons', lambda v: self.assertEqual(v, 6)),
      ('permutations', lambda v: self.assertEqual(v, 1)),
      ('zeroed_neurons', lambda v: self.assertEqual(v, 1)),
      ('total_neurons', lambda v: self.assertEqual(v, expected_total)),
  )
  for key, check_value in checks:
    with self.subTest(name=f'{key}_exists'):
      self.assertIn(key, mask_stats)
    with self.subTest(name=f'{key}_value'):
      check_value(mask_stats[key])
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/random_mask.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Symmetry: Train model with randomly sampled sparse mask."""
import ast
from os import path
from typing import List, Sequence
import uuid
from absl import app
from absl import flags
from absl import logging
import flax
from flax.metrics import tensorboard
from flax.training import lr_schedule
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.pruning import mask_factory
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.pruning import symmetry
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils
experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
logging.info('Saving experimental results to %s', experiment_dir)
host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
jax.device_count(), host_count, local_device_count)
if jax.host_id() == 0:
summary_writer = tensorboard.SummaryWriter(experiment_dir)
dataset = dataset_factory.create_dataset(
FLAGS.dataset,
FLAGS.batch_size,
FLAGS.batch_size_test,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)
logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
base_model, _ = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
masked_layer_indices=FLAGS.masked_layer_indices)
logging.info('Generating random mask based on model')
# Re-initialize the RNG to maintain same training pattern (as in prune code).
mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed)
mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng,
FLAGS.mask_sparsity)
if jax.host_id() == 0:
mask_stats = symmetry.get_mask_stats(mask)
logging.info('Mask stats: %s', str(mask_stats))
for label, value in mask_stats.items():
try:
summary_writer.scalar(f'mask/{label}', value, 0)
# This is needed because permutations (long int) can't be cast to float32.
except (OverflowError, ValueError):
summary_writer.text(f'mask/{label}', str(value), 0)
logging.error('Could not write mask/%s to tensorflow summary as float32'
', writing as string instead.', label)
if FLAGS.dump_json:
mask_stats['permutations'] = str(mask_stats['permutations'])
utils.dump_dict_json(
mask_stats, path.join(experiment_dir, 'mask_stats.json'))
mask = masked.propagate_masks(mask)
if jax.host_id() == 0:
mask_stats = symmetry.get_mask_stats(mask)
logging.info('Propagated mask stats: %s', str(mask_stats))
for label, value in mask_stats.items():
try:
summary_writer.scalar(f'propagated_mask/{label}', value, 0)
# This is needed because permutations (long int) can't be cast to float32.
except (OverflowError, ValueError):
summary_writer.text(f'propagated_mask/{label}', str(value), 0)
logging.error('Could not write mask/%s to tensorflow summary as float32'
', writing as string instead.', label)
if FLAGS.dump_json:
mask_stats['permutations'] = str(mask_stats['permutations'])
utils.dump_dict_json(
mask_stats, path.join(experiment_dir, 'propagated_mask_stats.json'))
model, initial_state = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
masks=mask)
if FLAGS.optimizer == 'Adam':
optimizer = flax.optim.Adam(
learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
elif FLAGS.optimizer == 'Momentum':
optimizer = flax.optim.Momentum(
learning_rate=FLAGS.lr,
beta=FLAGS.momentum,
weight_decay=FLAGS.weight_decay,
nesterov=False)
steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
if FLAGS.lr_schedule == 'constant':
lr_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.lr, steps_per_epoch)
elif FLAGS.lr_schedule == 'stepped':
lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, lr_schedule_steps)
elif FLAGS.lr_schedule == 'cosine':
lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, FLAGS.epochs)
else:
raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}')
if jax.host_id() == 0:
trainer = training.Trainer(
optimizer,
model,
initial_state,
dataset,
rng,
summary_writer=summary_writer,
)
else:
trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)
_, best_metrics = trainer.train(
FLAGS.epochs,
lr_fn=lr_fn,
update_iter=FLAGS.update_iterations,
update_epoch=FLAGS.update_epoch,
)
logging.info('Best metrics: %s', str(best_metrics))
if jax.host_id() == 0:
if FLAGS.dump_json:
utils.dump_dict_json(best_metrics,
path.join(experiment_dir, 'best_metrics.json'))
for label, value in best_metrics.items():
summary_writer.scalar(f'best/{label}', value,
FLAGS.epochs * steps_per_epoch)
summary_writer.close()
def main(argv: List[str]):
  """CLI entry point: rejects extra arguments and launches training."""
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')
  run_training()


if __name__ == '__main__':
  app.run(main)
================================================
FILE: rigl/experimental/jax/random_mask_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.random_mask."""
import glob
from os import path
import tempfile
from absl.testing import absltest
from absl.testing import flagsaver
from rigl.experimental.jax import random_mask
class RandomMaskTest(absltest.TestCase):
  """Smoke tests for the random-mask training driver."""

  def _run_driver_and_check_summary(self, **extra_flags):
    """Runs the driver with flag overrides and checks a TF summary is written.

    Args:
      **extra_flags: Driver flag overrides (e.g. model, mask_type) applied on
        top of a fast single-epoch configuration.
    """
    experiment_dir = tempfile.mkdtemp()
    eval_flags = dict(epochs=1, experiment_dir=experiment_dir)
    eval_flags.update(extra_flags)
    with flagsaver.flagsaver(**eval_flags):
      random_mask.main([])
    with self.subTest(name='tf_summary_file_exists'):
      # Exactly one event file is expected under the work-unit subdirectory.
      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
      files = glob.glob(outfile)
      self.assertTrue(len(files) == 1 and path.exists(files[0]))

  def test_run_fc(self):
    """Test random mask driver with fully-connected model."""
    self._run_driver_and_check_summary(model='MNIST_FC')

  def test_run_conv(self):
    """Test random mask driver with CNN model."""
    self._run_driver_and_check_summary(model='MNIST_CNN')

  def test_run_random(self):
    """Test random mask driver with a random per-weight sparsity mask."""
    # Docstring fixed: this case exercises mask_type='random', not per-neuron.
    self._run_driver_and_check_summary(mask_type='random')

  def test_run_per_neuron(self):
    """Test random mask driver with per-neuron sparsity."""
    self._run_driver_and_check_summary(mask_type='per_neuron')

  def test_run_symmetric(self):
    """Test random mask driver with a symmetric mask."""
    # Docstring fixed: this case exercises mask_type='symmetric'.
    self._run_driver_and_check_summary(mask_type='symmetric')
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/requirements.txt
================================================
absl-py>=0.10.0
flax>=0.2.2
jax>=0.2.0
jaxlib>=0.1.55
tensorboard>=2.3.0
tensorflow>=2.3.1
tensorflow_datasets>=3.2.1
================================================
FILE: rigl/experimental/jax/run.sh
================================================
#!/bin/bash
# Shebang moved to line 1: a shebang is only honored by the kernel when it is
# the very first line of the file; it previously sat below the license block.
#
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Abort on first failure and echo each command.
set -e
set -x

# Create and enter a fresh virtualenv for the test run.
virtualenv -p python3 .
source ./bin/activate
# NOTE(review): the 'weight_symmetry' path/module prefix below does not match
# this package's rigl/experimental/jax location -- confirm before running.
pip install -r weight_symmetry/requirements.txt

# One test module per line; readarray splits on newlines.
TEST_NAMES='training.training_test
train_test
fixed_param_test
shuffled_mask_test
models.model_factory_test
models.cifar10_cnn_test
models.mnist_cnn_test
models.mnist_fc_test
utils.utils_test
prune_test
random_mask_test
pruning.mask_factory_test
pruning.init_test
pruning.symmetry_test
pruning.pruning_test
pruning.masked_test
datasets.dataset_factory_test
datasets.dataset_base_test
datasets.cifar10_test
datasets.mnist_test'

# Quote the herestring and the array expansion so entries are never re-split.
IFS=$'\n' readarray -t tests <<<"$TEST_NAMES"

for test in "${tests[@]}"; do
  python3 -m "weight_symmetry.${test}"
done
================================================
FILE: rigl/experimental/jax/shuffled_mask.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Symmetry: Train model with randomly shuffled sparse mask."""
# TODO: Refactor drivers to separate logic from flags/IO.
import ast
from os import path
from typing import List, Sequence
import uuid
from absl import app
from absl import flags
from absl import logging
import flax
from flax.metrics import tensorboard
from flax.training import lr_schedule
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.pruning import mask_factory
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.pruning import symmetry
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils
experiment_dir = '{}/{}/'.format(FLAGS.experiment_dir, work_unit_id)
logging.info('Saving experimental results to %s', experiment_dir)
host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
jax.device_count(), host_count, local_device_count)
if jax.host_id() == 0:
summary_writer = tensorboard.SummaryWriter(experiment_dir)
dataset = dataset_factory.create_dataset(
FLAGS.dataset,
FLAGS.batch_size,
FLAGS.batch_size_test,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)
logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
base_model, _ = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes)
logging.info('Generating random mask based on model')
# Re-initialize the RNG to maintain same training pattern (as in prune code).
mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed)
mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng,
FLAGS.mask_sparsity)
if jax.host_id() == 0:
mask_stats = symmetry.get_mask_stats(mask)
logging.info('Mask stats: %s', str(mask_stats))
for label, value in mask_stats.items():
try:
summary_writer.scalar(f'mask/{label}', value, 0)
# This is needed because permutations (long int) can't be cast to float32.
except (OverflowError, ValueError):
summary_writer.text(f'mask/{label}', str(value), 0)
logging.error('Could not write mask/%s to tensorflow summary as float32'
', writing as string instead.', label)
if FLAGS.dump_json:
mask_stats['permutations'] = str(mask_stats['permutations'])
utils.dump_dict_json(
mask_stats, path.join(experiment_dir, 'mask_stats.json'))
mask = masked.propagate_masks(mask)
if jax.host_id() == 0:
mask_stats = symmetry.get_mask_stats(mask)
logging.info('Propagated mask stats: %s', str(mask_stats))
for label, value in mask_stats.items():
try:
summary_writer.scalar(f'propagated_mask/{label}', value, 0)
# This is needed because permutations (long int) can't be cast to float32.
except (OverflowError, ValueError):
summary_writer.text(f'propagated_mask/{label}', str(value), 0)
logging.error('Could not write mask/%s to tensorflow summary as float32'
', writing as string instead.', label)
if FLAGS.dump_json:
mask_stats['permutations'] = str(mask_stats['permutations'])
utils.dump_dict_json(
mask_stats, path.join(experiment_dir, 'propagated_mask_stats.json'))
model, initial_state = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
masks=mask)
if FLAGS.optimizer == 'Adam':
optimizer = flax.optim.Adam(
learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
elif FLAGS.optimizer == 'Momentum':
optimizer = flax.optim.Momentum(
learning_rate=FLAGS.lr,
beta=FLAGS.momentum,
weight_decay=FLAGS.weight_decay,
nesterov=False)
steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
if FLAGS.lr_schedule == 'constant':
lr_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.lr, steps_per_epoch)
elif FLAGS.lr_schedule == 'stepped':
lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, lr_schedule_steps)
elif FLAGS.lr_schedule == 'cosine':
lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, FLAGS.epochs)
else:
raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))
if jax.host_id() == 0:
trainer = training.Trainer(
optimizer,
model,
initial_state,
dataset,
rng,
summary_writer=summary_writer,
)
else:
trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)
_, best_metrics = trainer.train(
FLAGS.epochs,
lr_fn=lr_fn,
update_iter=FLAGS.update_iterations,
update_epoch=FLAGS.update_epoch,
)
logging.info('Best metrics: %s', str(best_metrics))
if jax.host_id() == 0:
if FLAGS.dump_json:
utils.dump_dict_json(best_metrics,
path.join(experiment_dir, 'best_metrics.json'))
for label, value in best_metrics.items():
summary_writer.scalar('best/{}'.format(label), value,
FLAGS.epochs * steps_per_epoch)
summary_writer.close()
def main(argv: List[str]):
  """CLI entry point: rejects extra arguments and launches training."""
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')
  run_training()


if __name__ == '__main__':
  app.run(main)
================================================
FILE: rigl/experimental/jax/shuffled_mask_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.shuffled_mask."""
import glob
from os import path
import tempfile
from absl.testing import absltest
from absl.testing import flagsaver
from rigl.experimental.jax import shuffled_mask
class ShuffledMaskTest(absltest.TestCase):
  """Smoke tests for the shuffled-mask training driver."""

  def _run_driver_and_check_summary(self, **extra_flags):
    """Runs the driver with flag overrides and checks a TF summary is written.

    Args:
      **extra_flags: Driver flag overrides (e.g. model, mask_type) applied on
        top of a fast single-epoch configuration.
    """
    experiment_dir = tempfile.mkdtemp()
    eval_flags = dict(epochs=1, experiment_dir=experiment_dir)
    eval_flags.update(extra_flags)
    with flagsaver.flagsaver(**eval_flags):
      shuffled_mask.main([])
    with self.subTest(name='tf_summary_file_exists'):
      # Exactly one event file is expected under the work-unit subdirectory.
      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
      files = glob.glob(outfile)
      self.assertTrue(len(files) == 1 and path.exists(files[0]))

  def test_run_fc(self):
    """Tests if the driver for shuffled training runs correctly with FC NN."""
    self._run_driver_and_check_summary(model='MNIST_FC')

  def test_run_conv(self):
    """Tests if the driver for shuffled training runs correctly with CNN."""
    self._run_driver_and_check_summary(model='MNIST_CNN')

  def test_run_random(self):
    """Tests shuffled mask driver with a random per-weight sparsity mask."""
    # Docstring fixed: previously copy-pasted from the random-mask test file.
    self._run_driver_and_check_summary(mask_type='random')

  def test_run_per_neuron(self):
    """Tests shuffled mask driver with per-neuron sparsity."""
    self._run_driver_and_check_summary(mask_type='per_neuron')

  def test_run_symmetric(self):
    """Tests shuffled mask driver with a symmetric mask."""
    self._run_driver_and_check_summary(mask_type='symmetric')
# Standard absl test entry point.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/train.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Symmetry: Train Model.
Trains a model from scratch, saving the relevant early weight snapshots.
"""
import ast
from os import path
from typing import List, Sequence
import uuid
from absl import app
from absl import flags
from absl import logging
import flax
from flax.metrics import tensorboard
from flax.training import lr_schedule
import jax
import jax.numpy as np
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.training import training
FLAGS = flags.FLAGS

# All registered model/dataset names; the first entry is the flag default.
MODEL_LIST: Sequence[str] = tuple(model_factory.MODELS.keys())
DATASET_LIST: Sequence[str] = tuple(dataset_factory.DATASETS.keys())

flags.DEFINE_enum('model', MODEL_LIST[0], MODEL_LIST,
                  'Model to train.')
flags.DEFINE_enum('dataset', DATASET_LIST[0], DATASET_LIST,
                  'Dataset to train on.')
flags.DEFINE_enum('optimizer', 'Adam', ['Momentum', 'Adam'],
                  'Optimizer to use.')
flags.DEFINE_float('learning_rate', 0.01, 'Learning rate.', short_name='lr')
flags.DEFINE_float('weight_decay', 1e-5, 'Weight decay penalty.',
                   short_name='wd')
flags.DEFINE_float('momentum', 0.9, 'Momentum weighting.')
# NOTE(review): validated only inside run_training; kept as a string flag
# rather than an enum to preserve the existing interface.
flags.DEFINE_string(
    'lr_schedule', default='stepped',
    help=('Learning rate schedule type; constant, stepped or cosine.'))
flags.DEFINE_string(
    'lr_schedule_steps', default='[[50, 0.01], [70, 0.001], [90, 0.0001]]',
    help=('Learning rate schedule steps as a Python list; '
          '[[step1_epoch, step1_lr_scale], '
          '[step2_epoch, step2_lr_scale], ...]'))
flags.DEFINE_integer(
    'batch_size', 128, 'Training minibatch size.', lower_bound=1)
flags.DEFINE_integer(
    'batch_size_test',
    128,
    'Test minibatch size.',
    lower_bound=1)
flags.DEFINE_integer(
    'epochs', 100, 'Number of epochs to train over.', lower_bound=1)
flags.DEFINE_integer('random_seed', 42, 'Random seed.')
flags.DEFINE_integer('shuffle_buffer_size', 1024,
                     'Dataset shuffle buffer size.')
flags.DEFINE_string(
    'experiment_dir', '/tmp/experiments',
    'Path to store experiment output in, i.e. models, snapshots.')
# Help text fixed: this flag is passed as update_iter (an iteration interval);
# the old text ('Epoch interval ...') was copy-pasted from update_epoch below.
flags.DEFINE_integer(
    'update_iterations',
    1000,
    'Iteration interval after which to evaluate test error.',
    lower_bound=1)
flags.DEFINE_integer(
    'update_epoch', 10, 'Epoch interval after which to evaluate test error.',
    lower_bound=1)
def run_training():
  """Trains a model according to this module's command-line flags.

  Builds the dataset, model, optimizer and learning-rate schedule from flags,
  runs the training loop, and (on host 0 only) writes TensorBoard summaries to
  a fresh per-work-unit subdirectory of --experiment_dir.

  Raises:
    ValueError: If --lr_schedule is not 'constant', 'stepped' or 'cosine'.
  """
  # Use logging instead of print so this lands with the rest of the run's logs.
  logging.info('Logging to %s', FLAGS.log_dir)

  # Unique subdirectory so repeated runs never clobber each other's output.
  work_unit_id = uuid.uuid4()
  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
  logging.info('Saving experimental results to %s', experiment_dir)

  host_count = jax.host_count()
  local_device_count = jax.local_device_count()
  logging.info('Device count: %d, host count: %d, local device count: %d',
               jax.device_count(), host_count, local_device_count)

  # Only the first host writes summaries.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(experiment_dir)

  dataset = dataset_factory.create_dataset(
      FLAGS.dataset,
      FLAGS.batch_size,
      FLAGS.batch_size_test,
      shuffle_buffer_size=FLAGS.shuffle_buffer_size)

  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  # Leading 1 is a dummy batch dimension used for shape inference.
  input_shape = (1,) + dataset.shape
  model, initial_state = model_factory.create_model(
      FLAGS.model,
      rng, ((input_shape, np.float32),),
      num_classes=dataset.num_classes)

  # --optimizer is an enum flag restricted to exactly these two values, so no
  # trailing else branch is needed.
  if FLAGS.optimizer == 'Adam':
    optimizer = flax.optim.Adam(
        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
  elif FLAGS.optimizer == 'Momentum':
    optimizer = flax.optim.Momentum(
        learning_rate=FLAGS.lr,
        beta=FLAGS.momentum,
        weight_decay=FLAGS.weight_decay,
        nesterov=False)

  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
  if FLAGS.lr_schedule == 'constant':
    lr_fn = lr_schedule.create_constant_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch)
  elif FLAGS.lr_schedule == 'stepped':
    # --lr_schedule_steps is a Python-literal string, e.g. '[[50, 0.01], ...]'.
    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, lr_schedule_steps)
  elif FLAGS.lr_schedule == 'cosine':
    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
        FLAGS.lr, steps_per_epoch, FLAGS.epochs)
  else:
    raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))

  if jax.host_id() == 0:
    trainer = training.Trainer(
        optimizer,
        model,
        initial_state,
        dataset,
        rng,
        summary_writer=summary_writer,
    )
  else:
    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)

  _, best_metrics = trainer.train(
      FLAGS.epochs,
      lr_fn=lr_fn,
      update_iter=FLAGS.update_iterations,
      update_epoch=FLAGS.update_epoch)
  logging.info('Best metrics: %s', str(best_metrics))

  if jax.host_id() == 0:
    for label, value in best_metrics.items():
      summary_writer.scalar('best/{}'.format(label), value,
                            FLAGS.epochs * steps_per_epoch)
    summary_writer.close()
def main(argv):
  """CLI entry point: rejects extra arguments and launches training."""
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')
  run_training()


if __name__ == '__main__':
  app.run(main)
================================================
FILE: rigl/experimental/jax/train_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.train."""
import glob
from os import path
import tempfile
from absl.testing import absltest
from absl.testing import flagsaver
from rigl.experimental.jax import train
class TrainTest(absltest.TestCase):
  """Smoke test for the plain training driver."""

  def test_train_driver_run(self):
    """Tests that the training driver runs, and outputs a TF summary."""
    experiment_dir = tempfile.mkdtemp()
    # Single epoch into a throwaway directory keeps the run fast.
    overrides = dict(epochs=1, experiment_dir=experiment_dir)
    with flagsaver.flagsaver(**overrides):
      train.main([])
    with self.subTest(name='tf_summary_file_exists'):
      pattern = path.join(experiment_dir, '*', 'events.out.tfevents.*')
      matches = glob.glob(pattern)
      self.assertTrue(len(matches) == 1 and path.exists(matches[0]))
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/training/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/training/training.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common training code.
This module contains utility functions for training NN.
Attributes:
LABELKEY: The key used to retrieve a label from the batch dictionary.
DATAKEY: The key used to retrieve an input image from the batch dictionary.
PruningRateFnType: Typing alias for a valid pruning rate function.
"""
from collections import abc
import functools
import time
from typing import Callable, Dict, Mapping, Optional, Tuple, Union
from absl import logging
import flax
from flax import jax_utils
from flax.training import common_utils
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_base
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.pruning import pruning
from rigl.experimental.jax.pruning import symmetry
from rigl.experimental.jax.utils import utils
import tensorflow.compat.v2 as tf
LABELKEY = dataset_base.ImageDataset.LABELKEY
DATAKEY = dataset_base.ImageDataset.DATAKEY
PruningRateFnType = Union[Mapping[str, Callable[[int], float]], Callable[[int],
float]]
def _shard_batch(xs):
"""Shards a batch for a pmap, based on the number of devices."""
local_device_count = jax.local_device_count()
def _prepare(x):
return x.reshape((local_device_count, -1) + x.shape[1:])
return jax.tree_map(_prepare, xs)
def train_step(
    optimizer: flax.optim.Optimizer, batch: Mapping[str, jnp.array],  # pytype: disable=module-attr
    rng: Callable[[int], jnp.array], state: flax.deprecated.nn.Collection,
    learning_rate_fn: Callable[[int], float]
) -> Tuple[flax.optim.Optimizer, flax.deprecated.nn.Collection, float, float]:  # pytype: disable=module-attr
  """Performs training for one minibatch.

  Note: intended to run inside jax.pmap with axis_name='batch' (it uses
  jax.lax.pmean to average gradients across replicas).

  Args:
    optimizer: Optimizer to use.
    batch: Minibatch to train with.
    rng: Random number generator, i.e. jax.random.PRNGKey, to use for model
      training, e.g. dropout.
    state: Model state.
    learning_rate_fn: A function that returns the learning rate given the step.

  Returns:
    A tuple consisting of the new optimizer, new state, mini-batch loss, and
    gradient norm.
  """

  def loss_fn(
      model: flax.deprecated.nn.Model
  ) -> Tuple[float, Tuple[flax.deprecated.nn.Collection, jnp.array]]:
    """Evaluates the loss function.

    Args:
      model: The model with which to evaluate the loss.

    Returns:
      Tuple of the loss for the mini-batch, and model state.
    """
    # The stateful context captures updated collections (e.g. batch norm
    # statistics); the stochastic context provides the PRNG for e.g. dropout.
    with flax.deprecated.nn.stateful(state) as new_state:
      with flax.deprecated.nn.stochastic(rng):
        logits = model(batch[DATAKEY])
    loss = utils.cross_entropy_loss(logits, batch[LABELKEY])
    return loss, new_state

  lr = learning_rate_fn(optimizer.state.step)
  # has_aux=True because loss_fn also returns the new model state.
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, new_state), grad = grad_fn(optimizer.target)
  # Average gradients across all replicas participating in the pmap.
  grad = jax.lax.pmean(grad, 'batch')
  new_opt = optimizer.apply_gradient(grad, learning_rate=lr)
  # Norm of the replica-averaged gradient, used for logging/metrics.
  grad_norm = jnp.linalg.norm(utils.param_as_array(grad))
  return new_opt, new_state, loss, grad_norm
class Trainer:
  """Training class with the state and methods for training a neural network.

  Attributes:
    optimizer: Optimizer used for training, None if training hasn't begun.
    state: Model state used for training.
  """

  def __init__(
      self,
      optimizer_def: flax.optim.OptimizerDef,  # pytype: disable=module-attr
      initial_model: flax.deprecated.nn.Model,
      initial_state: flax.deprecated.nn.Collection,
      dataset: jnp.array,
      rng: Callable[[int], jnp.array] = None,
      summary_writer: Optional[tf.summary.SummaryWriter] = None,
  ):
    """Creates a Trainer object.

    Args:
      optimizer_def: The flax optimizer def (i.e. not instantiated with a model
        using .create) to use for training.
      initial_model: The initial model to train.
      initial_state: The initial state of the model.
      dataset: The training dataset.
      rng: Random number generator, i.e. jax.random.PRNGKey, to use for model
        training, e.g. dropout.
      summary_writer: An optional tensorboard summary writer for logging
        training metrics.
    """
    # Fix: the constructor previously failed to store its arguments (and had
    # an unterminated docstring); these attributes are all read by train().
    self._optimizer_def = optimizer_def
    self._initial_model = initial_model
    self.state = initial_state
    self._dataset = dataset
    self._summary_writer = summary_writer
    # Per the class docstring, optimizer is None until training begins.
    self.optimizer = None
    self._rng = rng
    if self._rng is None:
      self._rng = jax.random.PRNGKey(42)

  def _update_optimizer(self, model: flax.deprecated.nn.Model):
    """Updates the optimizer based on the given model."""
    # Replicated so it can be passed directly to the pmapped train step.
    self.optimizer = jax_utils.replicate(
        self._optimizer_def.create(model))

  def train(
      self,
      num_epochs: int,
      lr_fn: Optional[Callable[[int], float]] = None,
      pruning_rate_fn: Optional[PruningRateFnType] = None,
      update_iter: int = 100,
      update_epoch: int = 10
  ) -> Tuple[flax.deprecated.nn.Model, Mapping[str, Union[int, float, Mapping[
      str, float]]]]:
    """Trains the model over the given number of epochs.

    Args:
      num_epochs: The total number of epochs to train over.
      lr_fn: The learning rate function, takes the current iteration/step as an
        argument and returns the current learning rate, uses constant learning
        rate if no function is provided.
      pruning_rate_fn: The pruning rate function, takes the current epoch as an
        argument and returns the current pruning rate, no further pruning is
        performed during training if not set. Can be a dictionary, containing
        the pruning rate schedule functions for each layer, or a single function
        for all layers.
      update_iter: Period of iterations in which to log/update per-batch
        metrics.
      update_epoch: Period of epochs in which to log/update full training/test
        metrics.

    Returns:
      Tuple consisting of the best model found during training, and metrics.

    Raises:
      ValueError: If the batch size of the data set is not evenly divisible by
        number of devices, or the model batch size is not the training
        data batch size/number of jax devices.
    """
    best_test_acc = 0
    best_train_loss = None
    best_iter = None

    if lr_fn is None:
      # Lazily read the optimizer's constant learning rate (the optimizer is
      # created below by _update_optimizer, before this lambda is called).
      lr_fn = lambda _: self.optimizer.optimizer_def.hyper_params.learning_rate

    host_count = jax.host_count()
    device_count = jax.device_count()
    local_device_count = jax.local_device_count()
    logging.info('JAX hosts %d, devices: %d, local devices: %d', host_count,
                 device_count, local_device_count)

    # TODO(b/155550457): Implement multi-host training.
    if host_count > 1:
      raise NotImplementedError('Multi-host training is not supported yet, '
                                'see b/155550457.')

    # Fix: the error messages previously reported local_device_count while the
    # divisibility check uses device_count; they now agree.
    if self._dataset.batch_size % device_count > 0:
      raise ValueError(
          'Train batch size ({}) must be divisible by number of local devices '
          '({})'.format(self._dataset.batch_size, device_count))
    if self._dataset.batch_size_test % device_count > 0:
      raise ValueError(
          'Test batch size ({}) must be divisible by number of local devices '
          '({})'.format(self._dataset.batch_size_test, device_count))

    # Required to use state and optimizer with jax.pmap.
    state = jax_utils.replicate(self.state)
    self._update_optimizer(self._initial_model)

    p_train_step = jax.pmap(
        functools.partial(train_step, learning_rate_fn=lr_fn),
        axis_name='batch')
    # Function to sync the batch statistics across replicas.
    p_synchronized_batch_stats = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')
    # Both metrics are measured relative to the initial model's parameters.
    p_cosine_similarity = functools.partial(utils.cosine_similarity_model,
                                            self._initial_model)
    p_vector_difference_norm = functools.partial(
        utils.vector_difference_norm_model, self._initial_model)

    pruning_rate = None
    mask = None
    cumulative_grad_norm = 0
    start_time = time.time()

    # Main training loop.
    for epoch in range(num_epochs):
      if epoch % update_epoch == 0 or epoch == num_epochs - 1:
        epoch_start_time = time.time()

      # If we get different schedules for different layers.
      if isinstance(pruning_rate_fn, abc.Mapping):
        next_pruning_rate = {
            layer: layer_fn(epoch)
            for layer, layer_fn in pruning_rate_fn.items()
        }
      elif pruning_rate_fn:
        next_pruning_rate = pruning_rate_fn(epoch)

      # If pruning rate has changed/is first epoch, we need to update mask.
      # Note: pruning_rate could be zero, so must explicitly check it's None.
      if pruning_rate_fn and (pruning_rate is None or
                              pruning_rate != next_pruning_rate):
        pruning_rate = next_pruning_rate
        logging.info('[%d] Pruning Rate: %s', epoch, str(pruning_rate))

        # Unreplicate optimizer/current model, and mask.
        self.optimizer = jax_utils.unreplicate(self.optimizer)
        mask = jax_utils.unreplicate(mask) if mask else None

        # Performs pruning to get updated mask.
        mask = pruning.prune(self.optimizer.target, pruning_rate, mask=mask)

        logging.info('[%d] Mask Sparsity: %0.3f', epoch,
                     masked.mask_sparsity(mask))
        for layer, layer_mask in sorted(mask.items()):
          if layer_mask:
            logging.info('[%d] Layer: %s, Mask Sparsity: %0.3f', epoch, layer,
                         masked.mask_layer_sparsity(layer_mask))

        if jax.host_id() == 0:
          mask_stats = symmetry.get_mask_stats(mask)
          logging.info('Mask stats: %s', str(mask_stats))

          if self._summary_writer:
            for label, value in mask_stats.items():
              try:
                self._summary_writer.scalar(f'mask_{epoch}/{label}', value, 0)
              # Needed when permutations (long int) can't be cast to float32.
              except (OverflowError, ValueError):
                self._summary_writer.text(f'mask_{epoch}/{label}', str(value),
                                          0)
                logging.error(
                    'Could not write mask_%d/%s to tensorflow summary as float32'
                    ', writing as string instead.', epoch, label)

        # Creates a new optimizer, based on a new model with new mask.
        self._update_optimizer(
            model_factory.update_model(self.optimizer.target, masks=mask))

      # Begins epoch.
      for batch in self._dataset.get_train():
        # Note: Because of replicate, step has # device identical vals.
        step = jax_utils.unreplicate(self.optimizer.state.step)
        if step % update_iter == 0:
          batch_start_time = time.time()

        # These are required for pmap call.
        self._rng, step_key = jax.random.split(self._rng)
        batch = _shard_batch(batch)
        sharded_keys = common_utils.shard_prng_key(step_key)

        (self.optimizer, state, opt_loss,
         grad_norm) = p_train_step(self.optimizer, batch, sharded_keys, state)
        if state.state:
          state = p_synchronized_batch_stats(state)

        grad_norm = jax_utils.unreplicate(grad_norm)
        cumulative_grad_norm += grad_norm

        # Per-iteration status/metrics update.
        if jax.host_id() == 0 and step % update_iter == 0:
          batch_time = time.time() - batch_start_time
          if self._summary_writer is not None:
            self._summary_writer.scalar('training/train_batch_loss',
                                        jnp.mean(opt_loss),
                                        step)
            self._summary_writer.scalar('training/gradient_norm', grad_norm,
                                        step)
          logging.info('[epoch %d] %d, loss %0.5f, lr %0.3f, %0.3f sec', epoch,
                       step, jnp.mean(opt_loss), lr_fn(step), batch_time)

      # Per-epoch status/metrics update.
      if (jax.host_id() == 0 and
          (epoch % update_epoch == 0 or epoch == num_epochs - 1)):
        epoch_time = time.time() - epoch_start_time
        cosine_distance = p_cosine_similarity(
            jax_utils.unreplicate(self.optimizer.target))
        vector_difference_norm = p_vector_difference_norm(
            jax_utils.unreplicate(self.optimizer.target))
        train_metrics = eval_model(self.optimizer.target, state,
                                   self._dataset.get_train())
        test_metrics = eval_model(self.optimizer.target, state,
                                  self._dataset.get_test())
        train_loss = train_metrics['loss']
        train_acc = train_metrics['accuracy']
        test_loss = test_metrics['loss']
        test_acc = test_metrics['accuracy']

        if jax.host_id() == 0:
          metrics = {
              'wallclock_time':
                  float(epoch_time),
              'train_accuracy':
                  float(train_acc),
              'train_avg_loss':
                  float(train_loss),
              'test_accuracy':
                  float(test_acc),
              'test_avg_loss':
                  float(test_loss),
              'lr':
                  float(lr_fn(step)),
              'cosine_distance':
                  float(cosine_distance),
              'cumulative_gradient_norm':
                  float(cumulative_grad_norm),
              'vector_difference_norm':
                  float(vector_difference_norm),
          }
          if self._summary_writer is not None:
            for label, value in metrics.items():
              self._summary_writer.scalar('training/{}'.format(label), value,
                                          step)

        # Track the model with the best test accuracy seen so far.
        if test_acc >= best_test_acc:
          best_model = self.optimizer.target
          best_test_acc = test_acc
          best_test_metrics = {
              'train_avg_loss': float(train_loss),
              'train_accuracy': float(train_acc),
              'test_avg_loss': float(test_loss),
              'test_accuracy': float(test_acc),
              'step': int(step),
              'cosine_distance': float(cosine_distance),
              'cumulative_gradient_norm': float(cumulative_grad_norm),
              'vector_difference_norm': float(vector_difference_norm),
          }
          best_iter = step

        # Separately track the snapshot with the best training loss.
        if best_train_loss is None or train_loss <= best_train_loss:
          best_train_loss = train_loss
          best_train_metrics = {
              'train_avg_loss': float(train_loss),
              'train_accuracy': float(train_acc),
              'test_avg_loss': float(test_loss),
              'test_accuracy': float(test_acc),
              'step': int(step),
              'cosine_distance': float(cosine_distance),
              'cumulative_gradient_norm': float(cumulative_grad_norm),
              'vector_difference_norm': float(vector_difference_norm),
          }

        log_format_str = (
            '[epoch %d] train avg. loss %0.4f, train acc. %0.4f, test avg. '
            'loss %0.4f, test acc. %0.4f, %0.4f sec, cosine sim.: %0.3f, cum. '
            'grad. norm: %0.3f, vector diff: %0.3f')
        log_vars = [
            epoch, train_loss, train_acc, test_loss, test_acc, epoch_time,
            float(cosine_distance),
            float(cumulative_grad_norm),
            float(vector_difference_norm)
        ]
        logging.info(log_format_str, *log_vars)
      # End epoch.

    training_time = time.time() - start_time
    logging.info('Training finished, Total wallclock time: %0.2f sec',
                 training_time)

    if jax.host_id() == 0 and self._summary_writer is not None:
      for label, value in best_test_metrics.items():
        self._summary_writer.scalar('best_test_acc/{}'.format(label), value,
                                    best_iter)
    logging.info('Best Test Accuracy: iteration %d, test acc. %0.5f',
                 best_test_metrics['step'], best_test_acc)

    if jax.host_id() == 0 and self._summary_writer is not None:
      for label, value in best_test_metrics.items():
        self._summary_writer.scalar(
            'best_train_loss/{}'.format(label),
            value,
            step=best_train_metrics['step'])
    logging.info('Best Train Loss: iteration %d, test loss. %0.5f',
                 best_train_metrics['step'], best_train_loss)

    return (best_model, best_test_metrics)
def _eval_step(model: flax.deprecated.nn.Model,
               state: flax.deprecated.nn.Collection,
               batch: Mapping[str, jnp.array]) -> Dict[str, jnp.array]:
  """Evaluates a mini-batch of data.

  Note: runs inside jax.pmap with axis_name='batch'.

  Args:
    model: The model to use to evaluate.
    state: Model state containing state for stateful flax.deprecated.nn
      functions, such as batch normalization.
    batch: Mini-batch of data to evaluate on.

  Returns:
    Dictionary consisting of the mini-batch the loss and accuracy.
  """
  # Average stateful collections (e.g. batch norm statistics) across replicas.
  state = jax.lax.pmean(state, 'batch')
  with flax.deprecated.nn.stateful(state, mutable=False):
    logits = model(batch[DATAKEY], train=False)
  return utils.compute_metrics(logits, batch[LABELKEY])
def eval_model(model: flax.deprecated.nn.Model,
               state: flax.deprecated.nn.Collection,
               eval_dataset: jnp.array) -> Dict[str, float]:
  """Evaluates the given model using the given dataset.

  Args:
    model: The model the evaluate.
    state: Model state containing state for stateful flax.deprecated.nn
      functions, such as batch normalization.
    eval_dataset: Dataset to evaluate the model over.

  Returns:
    Dictionary containing the average loss and accuracy of the model on the
    given dataset.
  """
  pmapped_eval_step = jax.pmap(_eval_step, axis_name='batch')

  sizes = []
  per_batch_metrics = []
  for batch in eval_dataset:
    # Record the un-sharded batch size before reshaping for pmap.
    sizes.append(len(batch[LABELKEY]))
    sharded = _shard_batch(batch)
    per_batch_metrics.append(pmapped_eval_step(model, state, sharded))

  # Note: use weighted mean, since we do mean of means with potentially
  # different batch sizes otherwise.
  size_array = jnp.asarray(sizes)
  weights = size_array / jnp.sum(size_array)
  stacked_metrics = common_utils.get_metrics(per_batch_metrics)
  return jax.tree_map(lambda metric: jnp.sum(weights * metric),
                      stacked_metrics)
================================================
FILE: rigl/experimental/jax/training/training_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.training.training."""
import functools
import math
from absl.testing import absltest
import flax
from flax import jax_utils
from flax.metrics import tensorboard
from flax.training import common_utils
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.training import training
class TrainingTest(absltest.TestCase):
  """Tests functions for training loop and training convenience functions."""

  def setUp(self):
    super().setUp()
    # Note: Tests are run on GPU/TPU.
    self._batch_size = 128
    self._batch_size_test = 128
    self._shuffle_buffer_size = 1024
    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
    self._num_classes = 10
    self._num_epochs = 1
    self._learning_rate_fn = lambda _: 0.01
    self._weight_decay = 0.0001
    self._momentum = 0.9
    self._rng = jax.random.PRNGKey(42)
    # Loss sanity bounds: strictly positive, at most twice the uniform loss.
    self._min_loss = jnp.finfo(float).eps
    self._max_loss = 2.0 * math.log(self._num_classes)
    self._dataset_name = 'MNIST'
    self._model_name = 'MNIST_CNN'
    self._summarywriter = tensorboard.SummaryWriter('/tmp/')
    self._dataset = dataset_factory.create_dataset(
        self._dataset_name,
        self._batch_size,
        self._batch_size_test,
        shuffle_buffer_size=self._shuffle_buffer_size)
    self._model, self._state = model_factory.create_model(
        self._model_name,
        self._rng, (self._input_shape,),
        num_classes=self._num_classes)
    self._optimizer = flax.optim.Momentum(  # pytype: disable=module-attr
        learning_rate=self._learning_rate_fn(0),
        beta=self._momentum,
        weight_decay=self._weight_decay)

  def _check_train_outputs(self, best_model, best_metrics):
    """Shared sanity checks on the model/metrics returned by Trainer.train."""
    with self.subTest(name='best_model_type'):
      self.assertIsInstance(best_model, flax.deprecated.nn.Model)
    with self.subTest(name='train_accuracy'):
      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)
    with self.subTest(name='train_avg_loss'):
      self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,
                         self._max_loss)
    with self.subTest(name='step'):
      self.assertGreater(best_metrics['step'], 0)
    with self.subTest(name='cosine_distance'):
      self.assertBetween(best_metrics['cosine_distance'], 0., 1.)
    with self.subTest(name='cumulative_gradient_norm'):
      self.assertGreater(best_metrics['cumulative_gradient_norm'], 0)
    with self.subTest(name='test_accuracy'):
      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)
    with self.subTest(name='test_avg_loss'):
      self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,
                         self._max_loss)

  def test_train_one_step(self):
    """Tests training loop over one step."""
    batch = next(self._dataset.get_train())
    replicated_state = jax_utils.replicate(self._state)
    replicated_opt = jax_utils.replicate(self._optimizer.create(self._model))
    self._rng, step_key = jax.random.split(self._rng)
    sharded_batch = training._shard_batch(batch)
    sharded_keys = common_utils.shard_prng_key(step_key)
    p_train_step = jax.pmap(
        functools.partial(
            training.train_step, learning_rate_fn=self._learning_rate_fn),
        axis_name='batch')
    _, _, loss, gradient_norm = p_train_step(replicated_opt, sharded_batch,
                                             sharded_keys, replicated_state)
    with self.subTest(name='test_loss_range'):
      self.assertBetween(jnp.mean(loss), self._min_loss, self._max_loss)
    with self.subTest(name='test_gradient_norm'):
      self.assertGreaterEqual(jax_utils.unreplicate(gradient_norm), 0)

  def test_train_one_epoch(self):
    """Tests training loop over one epoch."""
    trainer = training.Trainer(self._optimizer, self._model, self._state,
                               self._dataset)
    with self.subTest(name='trainer_instantiation'):
      self.assertIsInstance(trainer, training.Trainer)
    best_model, best_metrics = trainer.train(self._num_epochs)
    self._check_train_outputs(best_model, best_metrics)

  def test_train_one_epoch_tensorboard(self):
    """Tests training loop over one epoch, with tensorboard."""
    trainer = training.Trainer(
        self._optimizer,
        self._model,
        self._state,
        self._dataset,
        summary_writer=self._summarywriter)
    with self.subTest(name='TrainerInstantiation'):
      self.assertIsInstance(trainer, training.Trainer)
    best_model, best_metrics = trainer.train(self._num_epochs)
    self._check_train_outputs(best_model, best_metrics)

  def test_train_one_epoch_pruning_global_schedule(self):
    """Tests training loop over one epoch with global pruning rate schedule."""
    trainer = training.Trainer(self._optimizer, self._model, self._state,
                               self._dataset)
    with self.subTest(name='trainer_instantiation'):
      self.assertIsInstance(trainer, training.Trainer)
    best_model, best_metrics = trainer.train(
        self._num_epochs, pruning_rate_fn=lambda _: 0.5)
    self._check_train_outputs(best_model, best_metrics)

  def test_train_one_epoch_pruning_local_schedule(self):
    """Tests training loop over one epoch with local pruning rate schedule."""
    trainer = training.Trainer(self._optimizer, self._model, self._state,
                               self._dataset)
    with self.subTest(name='trainer_instantiation'):
      self.assertIsInstance(trainer, training.Trainer)
    best_model, best_metrics = trainer.train(
        self._num_epochs, pruning_rate_fn={'MaskedModule_0': lambda _: 0.5})
    self._check_train_outputs(best_model, best_metrics)

  def test_eval_batch(self):
    """Tests model per-batch evaluation function."""
    replicated_state = jax_utils.replicate(self._state)
    replicated_opt = jax_utils.replicate(self._optimizer.create(self._model))
    sharded_batch = training._shard_batch(next(self._dataset.get_test()))
    metrics = jax.pmap(training._eval_step, axis_name='batch')(
        replicated_opt.target, replicated_state, sharded_batch)
    with self.subTest(name='test_eval_batch_loss'):
      self.assertBetween(jnp.mean(metrics['loss']), self._min_loss,
                         self._max_loss)
    with self.subTest(name='test_eval_batch_accuracy'):
      self.assertBetween(jnp.mean(metrics['accuracy']), 0., 1.)

  def test_eval(self):
    """Tests model evaluation function."""
    replicated_state = jax_utils.replicate(self._state)
    replicated_opt = jax_utils.replicate(self._optimizer.create(self._model))
    metrics = training.eval_model(replicated_opt.target, replicated_state,
                                  self._dataset.get_test())
    with self.subTest(name='test_eval_loss'):
      self.assertBetween(metrics['loss'], 0.,
                         2.0 * math.log(self._num_classes))
    with self.subTest(name='test_eval_accuracy'):
      self.assertBetween(metrics['accuracy'], 0., 1.)


if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/experimental/jax/utils/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/utils/utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convenience Functions for NN training.
Misc. common functions used in training NN models.
"""
import functools
import itertools
import json
import operator
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, TypeVar
import flax
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
def cross_entropy_loss(log_softmax_logits,
                       labels):
  """Returns the cross-entropy classification loss.

  Args:
    log_softmax_logits: The log of the softmax of the logits for the mini-batch,
      e.g. as output by jax.nn.log_softmax(logits).
    labels: The labels for the mini-batch.
  """
  num_classes = log_softmax_logits.shape[-1]
  one_hot = common_utils.onehot(labels, num_classes)
  # Sum the log-probabilities of the true classes and average over the batch.
  total_log_prob = jnp.sum(one_hot * log_softmax_logits)
  return -total_log_prob / labels.size
def compute_metrics(logits,
                    labels):
  """Computes the classification loss and accuracy for a mini-batch.

  Note: must run inside a pmap with axis_name='batch' (uses jax.lax.pmean).

  Args:
    logits: NN model's logit outputs for the mini-batch.
    labels: The classification labels for the mini-batch.

  Returns:
    Metrics dictionary where 'loss' the mini-batch loss and 'accuracy' is
    the classification accuracy.

  Raises:
    ValueError: If the given logits array is not of the correct shape.
  """
  if logits.ndim != 2:
    raise ValueError(
        'Expected an array of (BATCHSIZE, NUM_CLASSES), but got {}'.format(
            logits.shape))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': cross_entropy_loss(logits, labels),
      'accuracy': accuracy,
  }
  # Average the metrics across all replicas in the pmap.
  return jax.lax.pmean(metrics, 'batch')
def _np_converter(obj):
"""Explicitly cast Numpy types not recognized by JSON serializer."""
if isinstance(obj, jnp.integer) or isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, jnp.floating) or isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, jnp.ndarray) or isinstance(obj, np.ndarray):
return obj.tolist()
def dump_dict_json(data_dict, path):
  """Dumps a dictionary to a JSON file, ensuring Numpy types are cast correctly.

  Args:
    data_dict: A metrics dictionary.
    path: Path of the JSON file to save.
  """
  # Serialize first, then write, so a serialization failure leaves no
  # partially-written file content behind.
  serialized = json.dumps(data_dict, default=_np_converter)
  with open(path, 'w') as json_file:
    json_file.write(serialized)
def count_param(model,
                param_names):
  """Counts the number of parameters in the given model.

  Args:
    model: The model for which to count the parameters.
    param_names: The parameters in each layer which should be accounted for.

  Returns:
    The total number of parameters of the given names in the model.
  """
  def _name_matches(path, _):
    # Select parameters whose path contains any of the requested names.
    return any(param_name in path for param_name in param_names)

  param_traversal = flax.optim.ModelParamTraversal(_name_matches)  # pytype: disable=module-attr
  return sum(param.size for param in param_traversal.iterate(model))
@jax.jit
def cosine_similarity(a, b):
  """Calculates the cosine similarity between two tensors of same shape."""
  a_vec = jnp.ravel(a)
  b_vec = jnp.ravel(b)
  norm_product = jnp.linalg.norm(a_vec) * jnp.linalg.norm(b_vec)
  return jnp.dot(a_vec, b_vec) / norm_product
def param_as_array(params):
  """Returns a Flax parameter pytree as a single numpy weight vector."""
  # Flatten every leaf and join them into one 1-D vector, in pytree order.
  leaves = jax.tree_util.tree_leaves(params)
  return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
def cosine_similarity_model(initial_model,
                            current_model):
  """Calculates the cosine similarity between two model's parameters."""
  return cosine_similarity(
      param_as_array(initial_model.params),
      param_as_array(current_model.params))
def vector_difference_norm_model(initial_model,
                                 current_model):
  """Calculates norm of the difference between two model's parameter vectors."""
  # L2 norm of (current - initial) over the flattened parameter vectors.
  delta = (param_as_array(current_model.params) -
           param_as_array(initial_model.params))
  return jnp.linalg.norm(delta)
# Use typevar to hint that we expect unspecified types to match.
T = TypeVar('T')
def pairwise_longest(iterable):
  """Yields (current, next) pairs, ending with the final element and None.

  This is different from the itertools pairwise recipe in that the last
  element is also yielded, paired with None.

  Args:
    iterable: An Iterable of any type.

  Returns:
    An iterable which returns the current and next items in the iterable, or
    None if there is no next. For example, for an iterator over the list
    (1, 2, 3, 4), this would return an iterator as
    ((1, 2), (2, 3), (3, 4), (4, None)).
  """
  current_iter, lookahead_iter = itertools.tee(iterable)
  # Advance the lookahead iterator by one so it leads by a single element.
  next(lookahead_iter, None)
  return itertools.zip_longest(current_iter, lookahead_iter)
================================================
FILE: rigl/experimental/jax/utils/utils_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.nn.nn_functions."""
import functools
import json
import operator
import tempfile
from typing import Optional, Sequence, TypeVar
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
import numpy as np
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils
class TwoLayerDense(flax.deprecated.nn.Module):
  """Two-layer Dense Network."""
  NUM_FEATURES: Sequence[int] = (32, 64)

  def apply(self, inputs):
    # Flatten image-shaped inputs to (batch, features) before the dense stack.
    batch_size = inputs.shape[0]
    flat_inputs = inputs.reshape(batch_size, -1)
    hidden = flax.deprecated.nn.Dense(flat_inputs, features=self.NUM_FEATURES[0])
    return flax.deprecated.nn.Dense(hidden, features=self.NUM_FEATURES[1])
class UtilsTest(parameterized.TestCase):
  """Test functions for NN convenience functions."""

  def setUp(self):
    """Common setup for test cases."""
    super().setUp()
    self._batch_size = 2
    self._num_classes = 10
    # Logit value assigned to the "chosen" class in synthetic logits.
    self._true_logit = 0.5
    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
    self._input = jnp.ones(*self._input_shape)
    self._rng = jax.random.PRNGKey(42)
    _, initial_params = TwoLayerDense.init_by_shape(self._rng,
                                                    (self._input_shape,))
    self._model = flax.deprecated.nn.Model(TwoLayerDense, initial_params)
    # NOTE(review): this second initialization reuses the same RNG key, so
    # _model_diff_init presumably has identical initial parameters to _model —
    # confirm whether a different key was intended.
    _, initial_params = TwoLayerDense.init_by_shape(self._rng,
                                                    (self._input_shape,))
    self._model_diff_init = flax.deprecated.nn.Model(TwoLayerDense,
                                                     initial_params)

  def _create_logits_labels(self, correct):
    """Creates a set of logits/labels resulting from correct classification.

    Args:
      correct: If true, creates labels for a correct classification, otherwise
        creates labels for an incorrect classification.

    Returns:
      A tuple of logits, labels.
    """
    # All classes share the residual probability mass; the "true" class for
    # each batch row is then overwritten with the larger _true_logit value.
    logits = np.full((self._batch_size, self._num_classes),
                     (1.0 - self._true_logit) / self._num_classes,
                     dtype=np.float32)
    # Diagonal over batch will be true.
    for i in range(self._batch_size):
      logits[i, i % self._num_classes] = self._true_logit
    labels = np.zeros(self._batch_size, dtype=jnp.int32)
    # Diagonal over batch will be true.
    for i in range(self._batch_size):
      # Offsetting the label by one makes every classification incorrect.
      labels[i] = (i if correct else i + 1) % self._num_classes
    return jnp.array(logits), jnp.array(labels)

  def test_compute_metrics_correct(self):
    """Tests output when logit outputs indicate correct classification."""
    logits, labels_correct = self._create_logits_labels(True)
    # Shard across devices so the pmapped metrics function can consume them.
    logits = training._shard_batch(logits)
    labels_correct = training._shard_batch(labels_correct)
    p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')
    metrics = p_compute_metrics(logits, labels_correct)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    with self.subTest(name='loss_type'):
      self.assertIsInstance(loss, jnp.ndarray)
    with self.subTest(name='loss_len'):
      self.assertEqual(loss.size, 1)
    with self.subTest(name='loss_values'):
      self.assertGreaterEqual(loss.all(), 0)
    with self.subTest(name='accuracy_type'):
      self.assertIsInstance(accuracy, jnp.ndarray)
    with self.subTest(name='accuracy_Len'):
      self.assertEqual(accuracy.size, 1)
    with self.subTest(name='accuracy_values'):
      self.assertAlmostEqual(accuracy.all(), 1.0)

  def test_compute_metrics_incorrect(self):
    """Tests output when logit outputs indicate incorrect classification."""
    logits, labels_incorrect = self._create_logits_labels(False)
    logits = training._shard_batch(logits)
    labels_incorrect = training._shard_batch(labels_incorrect)
    p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')
    metrics = p_compute_metrics(logits, labels_incorrect)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    with self.subTest(name='loss_type'):
      self.assertIsInstance(loss, jnp.ndarray)
    with self.subTest(name='loss_len'):
      self.assertEqual(loss.size, 1)
    with self.subTest(name='loss_values'):
      self.assertGreaterEqual(loss.all(), 0)
    with self.subTest(name='accuracy_type'):
      self.assertIsInstance(accuracy, jnp.ndarray)
    with self.subTest(name='accuracy_len'):
      self.assertEqual(accuracy.size, 1)
    with self.subTest(name='accuracy_values'):
      self.assertAlmostEqual(accuracy.all(), 0.0)

  def test_compute_metrics_equal_logits(self):
    """Tests output when the logit outputs are equal for all classes."""
    # NOTE(review): this reuses _create_logits_labels(True), so the logits are
    # not actually equal across classes — the body duplicates
    # test_compute_metrics_correct; confirm the intended fixture.
    logits, labels_correct = self._create_logits_labels(True)
    logits = training._shard_batch(logits)
    labels_correct = training._shard_batch(labels_correct)
    p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')
    metrics = p_compute_metrics(logits, labels_correct)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    with self.subTest(name='loss_type'):
      self.assertIsInstance(loss, jnp.ndarray)
    with self.subTest(name='loss_len'):
      self.assertEqual(loss.size, 1)
    with self.subTest(name='loss_values'):
      self.assertGreaterEqual(loss.all(), 0)
    with self.subTest(name='accuracy_type'):
      self.assertIsInstance(accuracy, jnp.ndarray)
    with self.subTest(name='accuracy_len'):
      self.assertEqual(accuracy.size, 1)
    with self.subTest(name='accuracy_values'):
      self.assertAlmostEqual(accuracy.all(), 1.0)

  def test_dump_dict_json(self):
    """Tests JSON dumping function."""
    # Covers scalar and 0-d array values of both numpy and jax.numpy types.
    data_dict = {
        'np_float': np.dtype('float32').type(1.0),
        'jnp_float': jnp.dtype('float32').type(1.0),
        'np_int': np.dtype('int32').type(1),
        'jnp_int': jnp.dtype('int32').type(1),
        'np_array': np.array(1.0, dtype=np.float32),
        'jnp_array': jnp.array(1.0, dtype=jnp.float32),
    }
    # Expected round-trip result: each value converted by _np_converter.
    converted_dict = {
        key: utils._np_converter(value) for key, value in data_dict.items()
    }
    json_path = tempfile.NamedTemporaryFile()
    utils.dump_dict_json(data_dict, json_path.name)
    with open(json_path.name, 'r') as input_file:
      loaded_dict = json.load(input_file)
    self.assertDictEqual(loaded_dict, converted_dict)

  def test_count_param_two_layer_dense(self):
    """Tests model parameter counting on small FC model."""
    count = utils.count_param(self._model, ('kernel',))
    # Kernel-only count: (flattened input features x first layer width) +
    # (first layer width x second layer width); biases are excluded.
    self.assertEqual(
        count,
        self._input.size / self._batch_size * TwoLayerDense.NUM_FEATURES[0] +
        TwoLayerDense.NUM_FEATURES[0] * TwoLayerDense.NUM_FEATURES[1])

  def test_count_invalid_param(self):
    """Tests model parameter counting for a non-existent parameter name."""
    count = utils.count_param(self._model, ('not_kernel',))
    self.assertEqual(count, 0)

  def test_model_param_as_array(self):
    """Tests method for returning single parameter vector for model."""
    param_array = utils.param_as_array(self._model.params)
    with self.subTest(name='test_param_is_vector'):
      self.assertLen(param_array.shape, 1)
    # The flattened vector must account for every leaf in the param pytree.
    param_sizes = [param.size for param in jax.tree_leaves(self._model.params)]
    model_size = functools.reduce(operator.add, param_sizes)
    with self.subTest(name='test_param_size'):
      self.assertEqual(param_array.size, model_size)

  def test_cosine_similarity_random(self):
    """Tests cosine similarity for two random weight matrices."""
    # NOTE(review): both matrices use the same RNG key, so a and b are
    # identical here — presumably independent samples were intended.
    a = jax.random.normal(self._rng, (3, 4))
    b = jax.random.normal(self._rng, (3, 4))
    cosine_similarity = utils.cosine_similarity(a, b)
    with self.subTest(name='test_cosine_distance_range'):
      self.assertBetween(cosine_similarity, 0., 1.)

  def test_cosine_similarity_same(self):
    """Tests cosine similarity for the same weight matrix."""
    a = jax.random.normal(self._rng, (3, 4))
    cosine_similarity = utils.cosine_similarity(a, a)
    with self.subTest(name='test_cosine_distance_range'):
      self.assertAlmostEqual(cosine_similarity, 1., places=5)

  def test_cosine_similarity_same_model(self):
    """Tests cosine similarity for the same model."""
    cosine_dist = utils.cosine_similarity_model(self._model, self._model)
    self.assertAlmostEqual(cosine_dist, 1., places=5)

  def test_vector_difference_norm_diff_model(self):
    """Tests vector difference norm for different models."""
    vector_diff_norm = utils.vector_difference_norm_model(
        self._model, self._model_diff_init)
    self.assertGreaterEqual(vector_diff_norm, 0.)

  def test_vector_difference_norm_same_model(self):
    """Tests vector difference norm for the same model."""
    vector_diff_norm = utils.vector_difference_norm_model(
        self._model, self._model)
    self.assertAlmostEqual(vector_diff_norm, 0., places=5)

  # TypeVar hinting that the element types of the input and output sequences
  # of the parameterized test below are expected to match.
  T = TypeVar('T')

  @parameterized.parameters(
      # Tests pairwise longest iterator convenience function with list.
      ((1, 2, 3, 4), ((1, 2), (2, 3), (3, 4), (4, None))),
      # Tests pairwise longest iterator with empty input iterator.
      (iter(()), ()),
      # Tests pairwise longest iterator with single element iterator.
      ((1,), ((1, None),))
  )
  def test_pairwise_longest_list_iterator(
      self, input_sequence,
      output_sequence):
    """Tests pairwise longest iterator with list iterators."""
    output = list(utils.pairwise_longest(iter(input_sequence)))
    self.assertSequenceEqual(output, output_sequence)
# Standard absltest entry point.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "e5O1UdsY202_"
},
"source": [
"##### Copyright 2020 Google LLC.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wtx39-f76KsC"
},
"outputs": [],
"source": [
"# Download necessary libraries.\n",
"%%bash \n",
"test -d rigl || git clone https://github.com/google-research/rigl rigl_repo \u0026\u0026 mv rigl_repo/rigl ./ \n",
"test -d gresearch || git clone https://github.com/google-research/google-research google_research"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i25HTaVl6LAI"
},
"source": [
"## Parameter and FLOPs Counting for MobileNetv1 "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gAkFMbjrNCww"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from google_research.micronet_challenge import counting\n",
"from rigl import sparse_utils\n",
"tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 34
},
"executionInfo": {
"elapsed": 2458,
"status": "ok",
"timestamp": 1593006846761,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "dYm9k-Q47PXe",
"outputId": "db7fc195-6e0b-4c04-b695-5670128503d7"
},
"outputs": [
{
"data": {
"text/plain": [
"\u003ctf.Tensor 'mobilenet_1.00_224/act_softmax/Softmax:0' shape=(2, 1000) dtype=float32\u003e"
]
},
"execution_count": 2,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"tf.compat.v1.reset_default_graph()\n",
"model=tf.keras.applications.MobileNet(input_shape=(224,224,3), weights=None)\n",
"model(tf.ones((2,224,224,3)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RNS1s5Wm7U8-"
},
"outputs": [],
"source": [
"masked_layers = []\n",
"dw_layers = []\n",
"for layer in model.layers:\n",
" if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense, tf.keras.layers.DepthwiseConv2D)): \n",
" masked_layers.append(layer)\n",
" if 'conv_dw' in layer.name:\n",
" dw_layers.append(layer)\n",
" # print(layer.name, sparse_utils._get_kernel(layer).shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QtD03TrBSDzV"
},
"outputs": [],
"source": [
"PARAM_SIZE=32\n",
"import functools\n",
"\n",
"get_stats = functools.partial(\n",
" sparse_utils.get_stats, first_layer_name='conv1',\n",
" last_layer_name='conv_preds', param_size=PARAM_SIZE)\n",
"\n",
"def print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi',\n",
" custom_sparsities=None, is_debug=False, width=1.):\n",
" print('Method: %s, Sparsity: %f' % (method, default_sparsity))\n",
" total_flops, total_param_bits, sparsity = get_stats(\n",
" masked_layers, default_sparsity=default_sparsity, method=method,\n",
" custom_sparsities=custom_sparsities, is_debug=is_debug, width=width)\n",
" print('Total Flops: %.3f MFlops' % (total_flops/1e6))\n",
" print('Total Size: %.3f Mbytes' % (total_param_bits/8e6))\n",
" print('Real Sparsity: %.3f' % (sparsity))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FvqtfXePpgdb"
},
"source": [
"### Printing sparse network stats"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 218
},
"executionInfo": {
"elapsed": 548,
"status": "ok",
"timestamp": 1593006940695,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "qupDcQOlTxDk",
"outputId": "f59b39d2-eedb-4e45-db93-f52958f24a45"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Method: erdos_renyi_kernel, Sparsity: 0.750000\n",
"Total Flops: 599.144 MFlops\n",
"Total Size: 4.888 Mbytes\n",
"Real Sparsity: 0.742\n",
"Method: random, Sparsity: 0.750000\n",
"Total Flops: 330.769 MFlops\n",
"Total Size: 4.894 Mbytes\n",
"Real Sparsity: 0.742\n",
"Method: random, Sparsity: 0.000000\n",
"Total Flops: 1141.544 MFlops\n",
"Total Size: 16.864 Mbytes\n",
"Real Sparsity: 0.000\n"
]
}
],
"source": [
"c_sparsities = {'%s/depthwise_kernel:0' % l.name: 0. for l in dw_layers}\n",
"c_sparsities_uniform = c_sparsities.copy()\n",
"c_sparsities_uniform['conv1/kernel:0'] = 0.\n",
"# c_sparsities_uniform['conv_preds/kernel:0'] = 0.\n",
"# First layer has sparsity 0 by default.\n",
"print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n",
"print_stats(masked_layers, 0.75, 'random', c_sparsities_uniform, is_debug=False)\n",
"print_stats(masked_layers, 0, 'random', is_debug=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 151
},
"executionInfo": {
"elapsed": 529,
"status": "ok",
"timestamp": 1593028091210,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "qvagZCnX31yP",
"outputId": "542832bb-7b59-4f43-d216-73260a9a3a56"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Method: erdos_renyi_kernel, Sparsity: 0.850000\n",
"Total Flops: 439.152 MFlops\n",
"Total Size: 3.224 Mbytes\n",
"Real Sparsity: 0.841\n",
"Method: random, Sparsity: 0.850000\n",
"Total Flops: 222.666 MFlops\n",
"Total Size: 3.229 Mbytes\n",
"Real Sparsity: 0.841\n"
]
}
],
"source": [
"print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n",
"print_stats(masked_layers, 0.85, 'random', c_sparsities_uniform, is_debug=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 151
},
"executionInfo": {
"elapsed": 840,
"status": "ok",
"timestamp": 1593006957962,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "t3L8WlYJOhku",
"outputId": "e5d4709b-984e-4e6d-ded4-8bdd81071267"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Method: erdos_renyi_kernel, Sparsity: 0.900000\n",
"Total Flops: 334.134 MFlops\n",
"Total Size: 2.392 Mbytes\n",
"Real Sparsity: 0.890\n",
"Method: random, Sparsity: 0.900000\n",
"Total Flops: 168.614 MFlops\n",
"Total Size: 2.396 Mbytes\n",
"Real Sparsity: 0.890\n"
]
}
],
"source": [
"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n",
"print_stats(masked_layers, 0.9, 'random', c_sparsities_uniform, is_debug=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 153
},
"executionInfo": {
"elapsed": 567,
"status": "ok",
"timestamp": 1582843606223,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 480
},
"id": "Ge1Ct0YjUME1",
"outputId": "7144ccdc-eae9-47d8-8a5c-b74aad94187c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Method: erdos_renyi_kernel, Sparsity: 0.950000\n",
"Total Flops: 205.281 MFlops\n",
"Total Size: 1.560 Mbytes\n",
"Real Sparsity: 0.940\n",
"Method: random, Sparsity: 0.950000\n",
"Total Flops: 114.563 MFlops\n",
"Total Size: 1.563 Mbytes\n",
"Real Sparsity: 0.940\n"
]
}
],
"source": [
"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n",
"print_stats(masked_layers, 0.95, 'random', c_sparsities_uniform, is_debug=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2RnZ9BCDVJ2P"
},
"source": [
"## Finding the width Multiplier for small dense model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 173
},
"executionInfo": {
"elapsed": 536,
"status": "ok",
"timestamp": 1569942238017,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "-qQMOoNqURfs",
"outputId": "4edf8c57-c3ab-45a1-f19d-13be5da23368"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9933069386323201\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 266.539 MFlops\n",
"Total Size: 4.789 Mbytes\n",
"Real Sparsity: 0.000\n",
"Method: erdos_renyi_kernel, Sparsity: 0.750000\n",
"Total Flops: 588.355 MFlops\n",
"Total Size: 4.757 Mbytes\n",
"Real Sparsity: 0.750\n"
]
}
],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.47)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.47)\n",
"print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 173
},
"executionInfo": {
"elapsed": 536,
"status": "ok",
"timestamp": 1569942242149,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "P5mS-6h3ZChX",
"outputId": "b722e40b-2797-454e-a2bb-91cdaef4a79d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9998127484496482\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 154.770 MFlops\n",
"Total Size: 3.076 Mbytes\n",
"Real Sparsity: 0.000\n",
"Method: erdos_renyi_kernel, Sparsity: 0.850000\n",
"Total Flops: 422.419 MFlops\n",
"Total Size: 3.075 Mbytes\n",
"Real Sparsity: 0.850\n"
]
}
],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.353)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.353)\n",
"print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 168
},
"executionInfo": {
"elapsed": 656,
"status": "ok",
"timestamp": 1569028742267,
"user": {
"displayName": "Utku Evci",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64",
"userId": "01088181649958641579"
},
"user_tz": 240
},
"id": "wY2Uc8RlVkRb",
"outputId": "03535606-8b6f-4eb9-ca48-ef235d69994f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9996546850118981\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 103.825 MFlops\n",
"Total Size: 2.236 Mbytes\n",
"Real Sparsity: 0.000\n",
"Method: erdos_renyi_kernel, Sparsity: 0.900000\n",
"Total Flops: 312.956 MFlops\n",
"Total Size: 2.235 Mbytes\n",
"Real Sparsity: 0.900\n"
]
}
],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.285)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.285)\n",
"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 168
},
"executionInfo": {
"elapsed": 574,
"status": "ok",
"timestamp": 1569089855290,
"user": {
"displayName": "Utku Evci",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64",
"userId": "01088181649958641579"
},
"user_tz": 240
},
"id": "TUfPAjO5Cryq",
"outputId": "c528942a-f531-48df-a46e-d94d5dae0a89"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9982463429660301\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 56.617 MFlops\n",
"Total Size: 1.396 Mbytes\n",
"Real Sparsity: 0.000\n",
"Method: erdos_renyi_kernel, Sparsity: 0.950000\n",
"Total Flops: 180.359 MFlops\n",
"Total Size: 1.393 Mbytes\n",
"Real Sparsity: 0.950\n"
]
}
],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.204)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.204)\n",
"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f8sqZWZYpoqa"
},
"source": [
"### Big-Sparse Networks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 242
},
"executionInfo": {
"elapsed": 631,
"status": "ok",
"timestamp": 1569285091631,
"user": {
"displayName": "Utku Evci",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64",
"userId": "01088181649958641579"
},
"user_tz": 240
},
"id": "f-eD8zoFY_-U",
"outputId": "0341ebde-cff6-497e-afaf-65e4a39ac438"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0084815029856933\n",
"Method: erdos_renyi_kernel, Sparsity: 0.750000\n",
"Total Flops: 2180.140 MFlops\n",
"Total Size: 16.723 Mbytes\n",
"Real Sparsity: 0.742\n",
"Method: random, Sparsity: 0.750000\n",
"Total Flops: 1122.572 MFlops\n",
"Total Size: 15.863 Mbytes\n",
"Real Sparsity: 0.757\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 1141.544 MFlops\n",
"Total Size: 16.864 Mbytes\n",
"Real Sparsity: 0.000\n"
]
}
],
"source": [
"# BIGGER\n",
"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=1.98)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1.98)\n",
"print_stats(masked_layers, 0.75, 'random', {'conv_preds/kernel:0':0.8, 'conv1/kernel:0':0.}, is_debug=False, width=1.98)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 168
},
"executionInfo": {
"elapsed": 581,
"status": "ok",
"timestamp": 1569029822060,
"user": {
"displayName": "Utku Evci",
"photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64",
"userId": "01088181649958641579"
},
"user_tz": 240
},
"id": "z_rW4hO0ZwIG",
"outputId": "efe0e3cd-4ed1-49eb-db6b-d673b01cc020"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0032864697591513\n",
"Method: erdos_renyi_kernel, Sparsity: 0.850000\n",
"Total Flops: 2442.726 MFlops\n",
"Total Size: 16.809 Mbytes\n",
"Real Sparsity: 0.846\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 1141.544 MFlops\n",
"Total Size: 16.864 Mbytes\n",
"Real Sparsity: 0.000\n"
]
}
],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=2.52)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=2.52)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 242
},
"executionInfo": {
"elapsed": 558,
"status": "ok",
"timestamp": 1569939161351,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "MHhuiXGlaQEi",
"outputId": "74db692f-bc1d-4f42-acc9-3848f4b2d21c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0120353164650686\n",
"Method: erdos_renyi_kernel, Sparsity: 0.900000\n",
"Total Flops: 2452.785 MFlops\n",
"Total Size: 16.664 Mbytes\n",
"Real Sparsity: 0.899\n",
"Method: random, Sparsity: 0.900000\n",
"Total Flops: 1058.478 MFlops\n",
"Total Size: 17.833 Mbytes\n",
"Real Sparsity: 0.890\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 1141.544 MFlops\n",
"Total Size: 16.864 Mbytes\n",
"Real Sparsity: 0.000\n"
]
}
],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=3.)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=3.)\n",
"print_stats(masked_layers, 0.9, 'random', {'conv_preds/kernel:0':0.8, 'conv1/kernel:0':0.}, is_debug=False, width=3.)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"height": 173
},
"executionInfo": {
"elapsed": 523,
"status": "ok",
"timestamp": 1569939157037,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": 240
},
"id": "wENtmNUGaXwj",
"outputId": "dab1f1c2-b647-4a67-b486-5ec5dcfcf4af"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0031304863290271\n",
"Method: erdos_renyi_kernel, Sparsity: 0.950000\n",
"Total Flops: 2132.954 MFlops\n",
"Total Size: 16.812 Mbytes\n",
"Real Sparsity: 0.954\n",
"Method: erdos_renyi_kernel, Sparsity: 0.000000\n",
"Total Flops: 1141.544 MFlops\n",
"Total Size: 16.864 Mbytes\n",
"Real Sparsity: 0.000\n"
]
}
],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=3.98)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=3.98)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "klQNdBJIqm3E"
},
"outputs": [],
"source": [
""
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/brain/python/client:colab_notebook",
"kind": "private"
},
"name": "MobileNet v1: Param/Flops Counting [OPEN_SOURCE].ipynb"
},
"kernelspec": {
"display_name": "Python 2",
"name": "python2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "e5O1UdsY202_"
},
"source": [
"##### Copyright 2020 Google LLC.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P5p1fkA3rgL_"
},
"outputs": [],
"source": [
"# Download the official ResNet50 implementation and other libraries.\n",
"# the ResNet50 module s.t. we can use the model builders for our counting.\n",
"%%bash \n",
"test -d tpu || git clone https://github.com/tensorflow/tpu tpu \u0026\u0026 mv tpu/models/experimental/resnet50_keras ./ \n",
"test -d rigl || git clone https://github.com/google-research/rigl rigl_repo \u0026\u0026 mv rigl_repo/rigl ./ \n",
"test -d gresearch || git clone https://github.com/google-research/google-research google_research"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tmr3djWe1rKj"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from micronet_challenge import counting\n",
"from resnet50_keras import resnet_model as resnet_keras\n",
"from rigl import sparse_utils\n",
"tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dYm9k-Q47PXe"
},
"outputs": [],
"source": [
"tf.compat.v1.reset_default_graph()\n",
"model = resnet_keras.ResNet50(1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RNS1s5Wm7U8-"
},
"outputs": [],
"source": [
"masked_layers = []\n",
"for layer in model.layers:\n",
" if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):\n",
" masked_layers.append(layer)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QtD03TrBSDzV"
},
"outputs": [],
"source": [
"PARAM_SIZE=32 # bits\n",
"import functools\n",
"get_stats = functools.partial(\n",
" sparse_utils.get_stats, first_layer_name='conv1', last_layer_name='fc1000',\n",
" param_size=PARAM_SIZE)\n",
"def print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi',\n",
" custom_sparsities={}, is_debug=False, width=1., **kwargs):\n",
" print('Method: %s, Sparsity: %f' % (method, default_sparsity))\n",
" total_flops, total_param_bits, sparsity = get_stats(\n",
" masked_layers, default_sparsity=default_sparsity, method=method,\n",
" custom_sparsities=custom_sparsities, is_debug=is_debug, width=width, **kwargs)\n",
" print('Total Flops: %.3f MFlops' % (total_flops/1e6))\n",
" print('Total Size: %.3f Mbytes' % (total_param_bits/8e6))\n",
" print('Real Sparsity: %.3f' % (sparsity))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "C_2kH9dsrUqu"
},
"source": [
"# Pruning FLOPs\n",
"We calculate theoratical FLOPs for pruning, which means we will start counting sparse FLOPs when the pruning starts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yHmbXdMyT2c8"
},
"outputs": [],
"source": [
"p_start, p_end, p_freq = 10000,25000,1000\n",
"target_sparsity = 0.8\n",
"total_flops = []\n",
"for i in range(0,32001,1000):\n",
" if i \u003c p_start:\n",
" sparsity = 0.\n",
" elif p_end \u003c i:\n",
" sparsity = target_sparsity\n",
" else:\n",
" sparsity = (1-(1-(i-p_start)/float(p_end-p_start))**3)*target_sparsity\n",
" # print(i, sparsity)\n",
" c_flops, _, _ = get_stats(\n",
" masked_layers, default_sparsity=sparsity, method='random', custom_sparsities={'conv1/kernel:0':0, 'fc1000/kernel:0':0.8})\n",
" # print(i, c_flops, sparsity)\n",
" total_flops.append(c_flops)\n",
"avg_flops = sum(total_flops) / len(total_flops)\n",
"print('Average Flops: ', avg_flops, avg_flops/total_flops[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xUU10hxxsZX-"
},
"source": [
"### Printing sparse network stats."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qupDcQOlTxDk"
},
"outputs": [],
"source": [
"print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=True, erk_power_scale=0.2)\n",
"print_stats(masked_layers, 0.8, 'erdos_renyi')\n",
"print_stats(masked_layers, 0.8, 'random', {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8}, is_debug=False)\n",
"print_stats(masked_layers, 0, 'random', is_debug=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AI1HIlLrzuED"
},
"outputs": [],
"source": [
"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False)\n",
"print_stats(masked_layers, 0.9, 'erdos_renyi')\n",
"print_stats(masked_layers, 0.9, 'random', {'conv1/kernel:0':0., 'fc1000/kernel:0':0.9}, is_debug=False)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oX5klsS4_vy-"
},
"outputs": [],
"source": [
"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', is_debug=False)\n",
"print_stats(masked_layers, 0.95, 'erdos_renyi')\n",
"print_stats(masked_layers, 0.95, 'random', {'conv1/kernel:0':0}, is_debug=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fe2FHmPfzS7S"
},
"outputs": [],
"source": [
"print_stats(masked_layers, 0.965, 'erdos_renyi_kernel', {'conv1/kernel:0':0}, is_debug=False)\n",
"print_stats(masked_layers, 0.965, 'erdos_renyi')\n",
"print_stats(masked_layers, 0.965, 'random', {'conv1/kernel:0':0}, is_debug=False)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yc2EeP_YWUfA"
},
"source": [
"## Finding the width Multiplier for small dense model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p8NJFEo9Se2S"
},
"outputs": [],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.8, 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.465)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.465)\n",
"print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gjk8Z2g2TOKq"
},
"outputs": [],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.34)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.34)\n",
"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sa1zoC-bT-Qk"
},
"outputs": [],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.26)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.26)\n",
"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "f_IugJP5URFa"
},
"outputs": [],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0.965, 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.231)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.231)\n",
"print_stats(masked_layers, 0.965, 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fXd4Mx90sc9Q"
},
"source": [
"### Printing the Big-Sparse Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BtpJ3LvKYCNn"
},
"outputs": [],
"source": [
"# BIGGER\n",
"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0.8, 'erdos_renyi_kernel', width=2.1)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=False, width=2.1)\n",
"print_stats(masked_layers, 0.8, 'random', {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8},\n",
" is_debug=False, width=2.1)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=2.1)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kRcOlrf4YG7K"
},
"outputs": [],
"source": [
"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n",
"_, bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', width=2.8)\n",
"print(sparse_bits/bits)\n",
"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False, width=2.8)\n",
"print_stats(masked_layers, 0.9, 'random', {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8}, is_debug=False, width=2.8)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=2.8)\n",
"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BN8qfasQWva2"
},
"source": [
"## [BONUS] DSR FLOPs\n",
"Obtained from figure https://arxiv.org/abs/1902.05967; exact values are probably slightly different.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RwI5aRe-SH0n"
},
"outputs": [],
"source": [
"resnet_layers=['conv1/kernel:0',\n",
"'res2a_branch2a/kernel:0',\n",
"'res2a_branch2b/kernel:0',\n",
"'res2a_branch2c/kernel:0',\n",
"'res2a_branch1/kernel:0',\n",
"'res2b_branch2a/kernel:0',\n",
"'res2b_branch2b/kernel:0',\n",
"'res2b_branch2c/kernel:0',\n",
"'res2c_branch2a/kernel:0',\n",
"'res2c_branch2b/kernel:0',\n",
"'res2c_branch2c/kernel:0',\n",
"'res3a_branch2a/kernel:0',\n",
"'res3a_branch2b/kernel:0',\n",
"'res3a_branch2c/kernel:0',\n",
"'res3a_branch1/kernel:0',\n",
"'res3b_branch2a/kernel:0',\n",
"'res3b_branch2b/kernel:0',\n",
"'res3b_branch2c/kernel:0',\n",
"'res3c_branch2a/kernel:0',\n",
"'res3c_branch2b/kernel:0',\n",
"'res3c_branch2c/kernel:0',\n",
"'res3d_branch2a/kernel:0',\n",
"'res3d_branch2b/kernel:0',\n",
"'res3d_branch2c/kernel:0',\n",
"'res4a_branch2a/kernel:0',\n",
"'res4a_branch2b/kernel:0',\n",
"'res4a_branch2c/kernel:0',\n",
"'res4a_branch1/kernel:0',\n",
"'res4b_branch2a/kernel:0',\n",
"'res4b_branch2b/kernel:0',\n",
"'res4b_branch2c/kernel:0',\n",
"'res4c_branch2a/kernel:0',\n",
"'res4c_branch2b/kernel:0',\n",
"'res4c_branch2c/kernel:0',\n",
"'res4d_branch2a/kernel:0',\n",
"'res4d_branch2b/kernel:0',\n",
"'res4d_branch2c/kernel:0',\n",
"'res4e_branch2a/kernel:0',\n",
"'res4e_branch2b/kernel:0',\n",
"'res4e_branch2c/kernel:0',\n",
"'res4f_branch2a/kernel:0',\n",
"'res4f_branch2b/kernel:0',\n",
"'res4f_branch2c/kernel:0',\n",
"'res5a_branch2a/kernel:0',\n",
"'res5a_branch2b/kernel:0',\n",
"'res5a_branch2c/kernel:0',\n",
"'res5a_branch1/kernel:0',\n",
"'res5b_branch2a/kernel:0',\n",
"'res5b_branch2b/kernel:0',\n",
"'res5b_branch2c/kernel:0',\n",
"'res5c_branch2a/kernel:0',\n",
"'res5c_branch2b/kernel:0',\n",
"'res5c_branch2c/kernel:0',\n",
"'fc1000/kernel:0']\n",
"dsr_sparsities8=[0,\n",
" 0., .15, .5, .425, .575, .55, .425, .32, .44, .15,\n",
" 0., .15, .55, .6, .8, .65, .75, .65, .65, .65, .55, .65, .7,\n",
" 0., .35, .65, .85, .9, .8, .85, .85, .8, .85, .85, .85, .85, .8, .8, .9, .75, .8, .85,\n",
" 0., .65, .85, .95, .85, .8, .9, .65, .9, .8,\n",
" .8]\n",
"dsr_sparsities9=[0,\n",
" 0., .4, .6, .65, .65, .6, .6, .5, .6, .45,\n",
" 0., .4, .7, .8, .9, .8, .85, .8, .75, .8, .7, .8, .8,\n",
" 0., .6, .8, .95, .95, .9, .95, .9, .9, .95, .9, .9, .95, .9, .9, .95, .85, .85, .9,\n",
" 0., 0.8, .95, .95, .9, .9, .95, .8, .95, .9,\n",
" .9] "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "P6i-jjz6OLBH"
},
"outputs": [],
"source": [
"dsr_map = dict(zip(resnet_layers, dsr_sparsities8))\n",
"print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xeGqdHtYYlZT"
},
"outputs": [],
"source": [
"dsr_map = dict(zip(resnet_layers, dsr_sparsities9))\n",
"print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pf3qqLKrG67e"
},
"source": [
"# [BONUS] STR FLOPs\n",
"Layerwise sparsities are obtained from the [STR paper](https://arxiv.org/abs/2002.03231)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MIwBmu0NHOuI"
},
"outputs": [],
"source": [
"str_sparsities = \"\"\"\n",
"Layer 1 - conv1 9408 118013952 51.46 51.40 63.02 59.80 59.83 64.87 67.36 66.96 72.11 69.46 73.29 73.47 72.05 75.12 76.12 77.75\n",
"Layer 2 - layer1.0.conv1 4096 12845056 69.36 73.24 87.57 83.28 85.18 89.60 91.41 91.11 92.38 91.75 94.46 94.51 94.60 95.95 96.53 96.51\n",
"Layer 3 - layer1.0.conv2 36864 115605504 77.85 76.26 90.87 89.48 87.31 94.79 94.27 95.04 95.69 96.07 97.36 97.77 98.35 98.51 98.59 98.84\n",
"Layer 4 - layer1.0.conv3 16384 51380224 74.81 74.65 86.52 85.80 85.25 91.85 92.78 93.67 94.13 94.69 96.61 97.03 97.37 98.04 98.21 98.47\n",
"Layer 5 - layer1.0.downsample.0 16384 51380224 70.95 72.96 83.53 83.34 82.56 89.13 90.62 90.17 91.83 92.69 95.48 94.89 95.68 96.98 97.56 97.72\n",
"Layer 6 - layer1.1.conv1 16384 51380224 80.27 79.58 89.82 89.89 88.51 94.56 96.64 95.78 95.81 96.81 98.79 98.90 98.98 99.13 99.62 99.47\n",
"Layer 7 - layer1.1.conv2 36864 115605504 81.36 80.95 91.75 90.60 89.61 94.70 95.78 96.18 96.42 97.26 98.65 99.07 99.40 99.11 99.31 99.56\n",
"Layer 8 - layer1.1.conv3 16384 51380224 84.45 80.11 91.22 91.70 90.21 95.17 97.05 95.81 96.34 97.23 98.68 98.76 98.90 99.16 99.57 99.46\n",
"Layer 9 - layer1.2.conv1 16384 51380224 78.23 79.79 90.12 88.07 89.36 94.62 95.94 94.74 96.23 96.75 97.96 98.41 98.72 99.38 99.35 99.46\n",
"Layer 10 - layer1.2.conv2 36864 115605504 76.01 81.53 91.06 87.03 88.27 93.90 95.63 94.26 96.24 96.11 97.54 98.27 98.44 99.32 99.19 99.39\n",
"Layer 11 - layer1.2.conv3 16384 51380224 84.47 83.28 94.95 90.99 92.64 95.76 96.95 96.01 96.87 97.31 98.38 98.60 98.72 99.38 99.27 99.51\n",
"Layer 12 - layer2.0.conv1 32768 102760448 73.74 73.96 86.78 85.95 85.90 92.32 94.79 93.86 94.62 95.64 97.19 98.22 98.52 98.48 98.84 98.92\n",
"Layer 13 - layer2.0.conv2 147456 115605504 82.56 85.70 91.31 93.91 94.03 97.54 97.43 97.65 98.38 98.62 99.24 99.23 99.40 99.61 99.67 99.63\n",
"Layer 14 - layer2.0.conv3 65536 51380224 84.70 83.55 93.04 93.13 92.13 96.61 97.37 97.21 97.59 98.14 98.80 98.95 99.18 99.29 99.47 99.43\n",
"Layer 15 - layer2.0.downsample.0 131072 102760448 85.10 87.66 92.78 94.96 95.13 98.07 97.97 98.15 98.70 98.88 99.37 99.35 99.40 99.69 99.68 99.71\n",
"Layer 16 - layer2.1.conv1 65536 51380224 85.42 85.79 94.04 95.31 94.94 97.92 98.53 98.21 98.84 99.06 99.46 99.53 99.72 99.78 99.81 99.80\n",
"Layer 17 - layer2.1.conv2 147456 115605504 76.95 82.75 87.63 91.50 91.76 95.59 97.22 96.07 97.32 97.80 98.24 98.24 98.60 99.24 99.66 99.33\n",
"Layer 18 - layer2.1.conv3 65536 51380224 84.76 84.71 93.10 93.66 93.23 97.00 98.18 97.35 98.06 98.41 98.96 99.21 99.32 99.55 99.58 99.59\n",
"Layer 19 - layer2.2.conv1 65536 51380224 84.30 85.34 92.70 94.61 94.76 97.72 97.91 98.21 98.54 98.98 99.24 99.35 99.50 99.62 99.63 99.77\n",
"Layer 20 - layer2.2.conv2 147456 115605504 84.28 85.43 92.99 94.86 94.90 97.52 97.21 98.11 98.19 99.04 99.28 99.37 99.46 99.63 99.59 99.72\n",
"Layer 21 - layer2.2.conv3 65536 51380224 82.19 84.21 91.12 93.38 93.53 96.89 97.14 97.59 97.77 98.66 98.96 99.15 99.25 99.49 99.51 99.57\n",
"Layer 22 - layer2.3.conv1 65536 51380224 83.37 84.41 90.46 93.26 93.50 96.71 97.89 96.99 98.14 98.36 99.10 99.23 99.33 99.53 99.75 99.60\n",
"Layer 23 - layer2.3.conv2 147456 115605504 82.83 84.03 91.44 93.21 93.25 96.83 98.02 96.96 98.45 98.30 98.97 99.06 99.26 99.31 99.81 99.68\n",
"Layer 24 - layer2.3.conv3 65536 51380224 82.93 85.65 91.02 94.14 93.56 97.20 97.97 97.04 98.16 98.36 98.88 98.97 99.20 99.32 99.67 99.62\n",
"Layer 25 - layer3.0.conv1 131072 102760448 76.63 77.98 85.99 88.85 88.60 94.26 95.07 94.97 96.21 96.59 97.75 98.04 98.30 98.72 99.11 99.06\n",
"Layer 26 - layer3.0.conv2 589824 115605504 87.35 88.68 94.39 96.14 96.19 98.51 98.77 98.72 99.11 99.23 99.53 99.59 99.64 99.73 99.80 99.81\n",
"Layer 27 - layer3.0.conv3 262144 51380224 81.22 83.22 90.58 93.19 93.05 96.82 97.38 97.32 97.98 98.28 98.88 99.03 99.16 99.39 99.55 99.53\n",
"Layer 28 - layer3.0.downsample.0 524288 102760448 89.75 90.99 96.05 97.20 97.16 98.96 99.21 99.20 99.50 99.58 99.78 99.82 99.86 99.91 99.94 99.93\n",
"Layer 29 - layer3.1.conv1 262144 51380224 85.88 87.35 93.43 95.36 96.12 98.64 98.77 98.87 99.22 99.33 99.64 99.67 99.72 99.82 99.88 99.84\n",
"Layer 30 - layer3.1.conv2 589824 115605504 85.06 86.24 92.74 95.06 95.30 98.09 98.28 98.36 98.75 99.08 99.46 99.48 99.54 99.69 99.76 99.76\n",
"Layer 31 - layer3.1.conv3 262144 51380224 84.34 86.79 92.15 94.84 94.90 97.75 98.15 98.11 98.56 98.94 99.30 99.36 99.45 99.65 99.79 99.70\n",
"Layer 32 - layer3.2.conv1 262144 51380224 87.51 89.15 94.15 96.77 96.46 98.81 98.83 98.96 99.19 99.44 99.67 99.71 99.74 99.82 99.85 99.89\n",
"Layer 33 - layer3.2.conv2 589824 115605504 87.15 88.67 94.09 95.59 96.14 98.86 98.69 98.91 99.21 99.20 99.64 99.72 99.76 99.85 99.84 99.90\n",
"Layer 34 - layer3.2.conv3 262144 51380224 84.86 86.90 92.40 94.99 94.99 98.19 98.19 98.42 98.76 98.97 99.42 99.56 99.62 99.76 99.75 99.88\n",
"Layer 35 - layer3.3.conv1 262144 51380224 86.62 89.46 94.06 96.08 95.88 98.70 98.71 98.77 99.01 99.27 99.58 99.66 99.69 99.83 99.87 99.87\n",
"Layer 36 - layer3.3.conv2 589824 115605504 86.52 87.97 93.56 96.10 96.11 98.70 98.82 98.89 99.19 99.31 99.68 99.73 99.77 99.88 99.87 99.93\n",
"Layer 37 - layer3.3.conv3 262144 51380224 84.19 86.81 92.32 94.94 94.91 98.20 98.37 98.43 98.82 99.00 99.51 99.57 99.64 99.81 99.81 99.87\n",
"Layer 38 - layer3.4.conv1 262144 51380224 85.85 88.40 93.55 95.49 95.86 98.35 98.44 98.55 98.79 98.96 99.54 99.59 99.60 99.82 99.86 99.87\n",
"Layer 39 - layer3.4.conv2 589824 115605504 85.96 87.38 93.27 95.66 95.63 98.41 98.58 98.56 99.19 99.26 99.64 99.69 99.67 99.87 99.90 99.92\n",
"Layer 40 - layer3.4.conv3 262144 51380224 83.45 85.76 91.75 94.49 94.35 97.67 98.09 97.99 98.65 98.94 99.49 99.52 99.48 99.77 99.86 99.85\n",
"Layer 41 - layer3.5.conv1 262144 51380224 83.33 85.77 91.79 95.09 94.24 97.46 97.89 97.92 98.71 98.90 99.35 99.52 99.58 99.76 99.79 99.83\n",
"Layer 42 - layer3.5.conv2 589824 115605504 84.98 86.67 92.48 94.92 95.13 97.88 98.14 98.32 98.91 99.00 99.44 99.58 99.69 99.80 99.83 99.87\n",
"Layer 43 - layer3.5.conv3 262144 51380224 79.78 82.23 89.39 93.14 92.76 96.59 97.04 97.30 98.10 98.41 99.03 99.25 99.44 99.61 99.71 99.75\n",
"Layer 44 - layer4.0.conv1 524288 102760448 77.83 79.61 87.11 90.32 90.64 95.39 95.84 95.92 97.17 97.35 98.36 98.60 98.83 99.20 99.37 99.42\n",
"Layer 45 - layer4.0.conv2 2359296 115605504 86.18 88.00 93.53 95.66 95.78 98.31 98.47 98.55 99.08 99.16 99.54 99.63 99.69 99.81 99.85 99.86\n",
"Layer 46 - layer4.0.conv3 1048576 51380224 78.43 80.48 87.85 91.14 91.27 96.00 96.40 96.47 97.53 97.92 98.81 99.00 99.15 99.45 99.57 99.61\n",
"Layer 47 - layer4.0.downsample.0 2097152 102760448 88.49 89.98 95.03 96.79 96.90 98.91 99.06 99.11 99.45 99.51 99.77 99.82 99.85 99.92 99.94 99.94\n",
"Layer 48 - layer4.1.conv1 1048576 51380224 82.07 84.02 90.34 93.69 93.72 97.15 97.56 97.76 98.45 98.75 99.27 99.36 99.54 99.67 99.76 99.80\n",
"Layer 49 - layer4.1.conv2 2359296 115605504 83.42 85.23 91.16 93.98 93.93 97.26 97.58 97.71 98.36 98.67 99.25 99.34 99.50 99.68 99.76 99.80\n",
"Layer 50 - layer4.1.conv3 1048576 51380224 78.08 79.96 86.66 90.48 90.22 95.22 95.76 95.89 96.88 97.65 98.70 98.85 99.13 99.45 99.58 99.66\n",
"Layer 51 - layer4.2.conv1 1048576 51380224 76.34 77.93 84.98 87.57 88.47 93.90 93.87 94.16 95.55 95.91 97.66 97.97 98.15 98.88 99.08 99.22\n",
"Layer 52 - layer4.2.conv2 2359296 115605504 73.57 74.97 82.32 84.37 86.01 91.92 91.66 92.22 94.02 94.16 96.65 97.13 97.29 98.44 98.74 99.00\n",
"Layer 53 - layer4.2.conv3 1048576 51380224 68.78 70.38 78.11 80.29 81.73 89.64 89.43 89.65 91.40 92.65 96.02 96.72 96.93 98.47 98.83 99.15\n",
"Layer 54 - fc 2048000 2048000 50.65 52.46 60.48 64.50 65.12 75.20 75.73 75.80 78.57 80.69 85.96 87.26 88.03 91.11 92.15 92.87\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gSFw1eH1G8zh"
},
"outputs": [],
"source": [
"resnet_layers=['conv1/kernel:0',\n",
"'res2a_branch2a/kernel:0',\n",
"'res2a_branch2b/kernel:0',\n",
"'res2a_branch2c/kernel:0',\n",
"'res2a_branch1/kernel:0',\n",
"'res2b_branch2a/kernel:0',\n",
"'res2b_branch2b/kernel:0',\n",
"'res2b_branch2c/kernel:0',\n",
"'res2c_branch2a/kernel:0',\n",
"'res2c_branch2b/kernel:0',\n",
"'res2c_branch2c/kernel:0',\n",
"'res3a_branch2a/kernel:0',\n",
"'res3a_branch2b/kernel:0',\n",
"'res3a_branch2c/kernel:0',\n",
"'res3a_branch1/kernel:0',\n",
"'res3b_branch2a/kernel:0',\n",
"'res3b_branch2b/kernel:0',\n",
"'res3b_branch2c/kernel:0',\n",
"'res3c_branch2a/kernel:0',\n",
"'res3c_branch2b/kernel:0',\n",
"'res3c_branch2c/kernel:0',\n",
"'res3d_branch2a/kernel:0',\n",
"'res3d_branch2b/kernel:0',\n",
"'res3d_branch2c/kernel:0',\n",
"'res4a_branch2a/kernel:0',\n",
"'res4a_branch2b/kernel:0',\n",
"'res4a_branch2c/kernel:0',\n",
"'res4a_branch1/kernel:0',\n",
"'res4b_branch2a/kernel:0',\n",
"'res4b_branch2b/kernel:0',\n",
"'res4b_branch2c/kernel:0',\n",
"'res4c_branch2a/kernel:0',\n",
"'res4c_branch2b/kernel:0',\n",
"'res4c_branch2c/kernel:0',\n",
"'res4d_branch2a/kernel:0',\n",
"'res4d_branch2b/kernel:0',\n",
"'res4d_branch2c/kernel:0',\n",
"'res4e_branch2a/kernel:0',\n",
"'res4e_branch2b/kernel:0',\n",
"'res4e_branch2c/kernel:0',\n",
"'res4f_branch2a/kernel:0',\n",
"'res4f_branch2b/kernel:0',\n",
"'res4f_branch2c/kernel:0',\n",
"'res5a_branch2a/kernel:0',\n",
"'res5a_branch2b/kernel:0',\n",
"'res5a_branch2c/kernel:0',\n",
"'res5a_branch1/kernel:0',\n",
"'res5b_branch2a/kernel:0',\n",
"'res5b_branch2b/kernel:0',\n",
"'res5b_branch2c/kernel:0',\n",
"'res5c_branch2a/kernel:0',\n",
"'res5c_branch2b/kernel:0',\n",
"'res5c_branch2c/kernel:0',\n",
"'fc1000/kernel:0']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "31sg-lNhHN7D"
},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"str_sparsities_parsed = defaultdict(list)\n",
"for j, l in enumerate(str_sparsities.strip().split('\\n')):\n",
" l = l.split('-')[1].strip().split(' ')\n",
" if l[0] == 'Overall':\n",
" overall_sparsities = map(float, l[3:])\n",
" else:\n",
" for i, ls in enumerate(l[3:]):\n",
" s = overall_sparsities[i]\n",
"      # Accuracies are between 0 and 1, so divide by 100.\n",
" str_sparsities_parsed[s].append(float(ls) / 100.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xrjtum-4HgAT"
},
"outputs": [],
"source": [
"for k in str_sparsities_parsed:\n",
" print(k)\n",
" dsr_map = dict(zip(resnet_layers, str_sparsities_parsed[k]))\n",
" print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//research/colab/notebook:notebook_backend",
"kind": "private"
},
"name": "Resnet-50: Param/Flops Counting [OpenSource].ipynb"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: rigl/imagenet_resnet/imagenet_train_eval.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""This script trains a ResNet model that implements various pruning methods.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
from absl import app
from absl import flags
from absl import logging
from rigl import sparse_optimizers
from rigl import sparse_utils
from rigl.imagenet_resnet import mobilenetv1_model
from rigl.imagenet_resnet import mobilenetv2_model
from rigl.imagenet_resnet import resnet_model
from rigl.imagenet_resnet import utils
from rigl.imagenet_resnet import vgg
from official.resnet import imagenet_input
from tensorflow.contrib import estimator as contrib_estimator
from tensorflow.contrib import tpu as contrib_tpu
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.training.python.training import evaluation
from tensorflow_estimator.python.estimator import estimator
# Dynamic sparse training (DST) methods: these update the sparsity mask
# during training rather than keeping a fixed mask.
DST_METHODS = [
    'set',
    'momentum',
    'rigl',
    'static'
]
ALL_METHODS = tuple(['scratch', 'baseline', 'snip', 'dnw'] + DST_METHODS)
# Methods that do not need an explicit mask initialization (they either derive
# their own mask or train densely).
NO_MASK_INIT_METHODS = ('snip', 'dnw', 'baseline')

# ---------------------------------------------------------------------------
# Training-infrastructure flags.
# ---------------------------------------------------------------------------
flags.DEFINE_string(
    'precision',
    default='float32',
    help=('Precision to use; one of: {bfloat16, float32}'))
flags.DEFINE_integer('num_workers', 1, 'Number of training workers.')
flags.DEFINE_float(
    'base_learning_rate',
    default=0.1,
    help=('Base learning rate when train batch size is 256.'))
flags.DEFINE_float(
    'momentum',
    default=0.9,
    help=('Momentum parameter used in the MomentumOptimizer.'))
flags.DEFINE_integer('ps_task', 0,
                     'Task id of the replica running the training.')
flags.DEFINE_float(
    'weight_decay',
    default=1e-4,
    help=('Weight decay coefficiant for l2 regularization.'))
flags.DEFINE_string('master', '', 'Master job.')
flags.DEFINE_string('tpu_job_name', None, 'For complicated TensorFlowFlock')
flags.DEFINE_integer(
    'steps_per_checkpoint',
    default=1000,
    help=('Controls how often checkpoints are generated. More steps per '
          'checkpoint = higher utilization of TPU and generally higher '
          'steps/sec'))
flags.DEFINE_integer(
    'keep_checkpoint_max', default=0, help=('Number of checkpoints to hold.'))
flags.DEFINE_integer(
    'seed', default=0, help=('Sets the random seed.'))
flags.DEFINE_string(
    'data_directory', None, 'The location of the sstable used for training.')
flags.DEFINE_string('eval_once_ckpt_prefix', '',
                    'File name of the eval chekpoint used for evaluation.')
flags.DEFINE_string(
    'data_format',
    default='channels_last',
    help=('A flag to override the data format used in the model. The value'
          ' is either channels_first or channels_last. To run the network on'
          ' CPU or TPU, channels_last should be used. For GPU, channels_first'
          ' will improve performance.'))
flags.DEFINE_bool(
    'transpose_input',
    default=False,
    help='Use TPU double transpose optimization')
flags.DEFINE_bool(
    'log_mask_imgs_each_iteration',
    default=False,
    help='Use to log few masks as images. Be careful when using. This is'
    ' very likely to slow down your training and create huge logs.')
flags.DEFINE_string(
    'mask_init_method',
    default='',
    help='If not empty string and mask is not loaded from a checkpoint, '
    'indicates the method used for mask initialization. One of the following: '
    '`random`, `erdos_renyi`.')
# ---------------------------------------------------------------------------
# Model / dataset flags.
# ---------------------------------------------------------------------------
flags.DEFINE_integer(
    'resnet_depth',
    default=50,
    help=('Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,'
          ' 200}. ResNet-18 and 34 use the pre-activation residual blocks'
          ' without bottleneck layers. The other models use pre-activation'
          ' bottleneck layers. Deeper models require more training time and'
          ' more memory and may require reducing --train_batch_size to prevent'
          ' running out of memory.'))
flags.DEFINE_float('label_smoothing', 0.1,
                   'Relax confidence in the labels by (1-label_smoothing).')
flags.DEFINE_float(
    'erk_power_scale', 1.0,
    'Softens the ERK distribituion. Value 0 means uniform.'
    '1 means regular ERK.')
flags.DEFINE_integer(
    'train_steps',
    default=2,
    help=('The number of steps to use for training. Default is 112590 steps'
          ' which is approximately 90 epochs at batch size 1024. This flag'
          ' should be adjusted according to the --train_batch_size flag.'))
flags.DEFINE_integer(
    'train_batch_size', default=1024, help='Batch size for training.')
flags.DEFINE_integer(
    'eval_batch_size', default=1000, help='Batch size for evaluation.')
flags.DEFINE_integer(
    'num_train_images', default=1281167, help='Size of training data set.')
flags.DEFINE_integer(
    'num_eval_images', default=50000, help='Size of evaluation data set.')
flags.DEFINE_integer(
    'num_label_classes', default=1000, help='Number of classes, at least 2')
flags.DEFINE_integer(
    'steps_per_eval',
    default=1251,
    help=('Controls how often evaluation is performed. Since evaluation is'
          ' fairly expensive, it is advised to evaluate as infrequently as'
          ' possible (i.e. up to --train_steps, which evaluates the model only'
          ' after finishing the entire training regime).'))
flags.DEFINE_bool(
    'use_tpu',
    default=False,
    help=('Use TPU to execute the model for training and evaluation. If'
          ' --use_tpu=false, will use whatever devices are available to'
          ' TensorFlow by default (e.g. CPU and GPU)'))
flags.DEFINE_integer(
    'iterations_per_loop',
    default=1251,
    help=('Number of steps to run on TPU before outfeeding metrics to the CPU.'
          ' If the number of iterations in the loop would exceed the number of'
          ' train steps, the loop will exit before reaching'
          ' --iterations_per_loop. The larger this value is, the higher the'
          ' utilization on the TPU.'))
flags.DEFINE_integer(
    'num_parallel_calls',
    default=64,
    help=('Number of parallel threads in CPU for the input pipeline'))
flags.DEFINE_integer(
    'num_cores',
    default=8,
    help=('Number of TPU cores. For a single TPU device, this is 8 because each'
          ' TPU has 4 chips each with 2 cores.'))
flags.DEFINE_string('output_dir', '/tmp/imagenet/',
                    'Directory where to write event logs and checkpoint.')
flags.DEFINE_bool('use_folder_stub', True,
                  'If True the output_dir is extended with some parameters.')
flags.DEFINE_bool('use_batch_statistics', False,
                  'If True the forward pass is made in training mode. ')
flags.DEFINE_bool('eval_on_train', False,
                  'If True the evaluation is made on training set.')
flags.DEFINE_enum(
    'mode', 'train', ('train_and_eval', 'train', 'eval', 'eval_once'),
    'One of {"train_and_eval", "train", "eval"}.')
flags.DEFINE_integer('export_model_freq', 2502,
                     'The rate at which estimator exports the model.')
# ---------------------------------------------------------------------------
# Sparse-training flags.
# ---------------------------------------------------------------------------
flags.DEFINE_enum(
    'training_method', 'scratch', ALL_METHODS,
    'Method used for training sparse network. `scratch` means initial mask is '
    'kept during training. `set` is for sparse evalutionary training and '
    '`baseline` is for dense baseline.')
flags.DEFINE_enum(
    'init_method', 'baseline', ('baseline', 'sparse'),
    'Method for initialization. If sparse and training_method=scratch, then '
    'use initializers that take into account starting sparsity.')
# flags.DEFINE_enum(
#     'mask_init_method', 'baseline', ('default'),
#     'Method for initializating masks. If not default, end_sparsities are used'
#     ' to define the layer wise random sparse connectivity.')
flags.DEFINE_bool(
    'is_warm_up',
    default=True,
    help=('Boolean for whether to scale weight of regularizer.'))
flags.DEFINE_float(
    'width', -1., 'Multiplier for the number of channels in each layer')
# first and last layer are somewhat special. First layer has almost no
# parameters, but 3% of the total flops. Last layer has only .05% of the total
# flops but 10% of the total parameters. Depending on whether the goal is max
# compression or max acceleration, pruning goals will be different.
flags.DEFINE_bool('use_adam', False,
                  'Whether to use Adam or not')
flags.DEFINE_bool('use_sgdr', False,
                  'Whether to use SGDR for learning rate schedule.')
flags.DEFINE_float('sgdr_decay_step', 5, 'Initial cycle length for SGDR.')
flags.DEFINE_float('sgdr_t_mul', 1.5, 'Cycle length multiplier for SGDR')
flags.DEFINE_float('sgdr_m_mul', .5,
                   'Learning rate drop at each restart cycle.')
flags.DEFINE_float('end_sparsity', 0.9,
                   'Target sparsity desired by end of training.')
flags.DEFINE_float('drop_fraction', 0.3,
                   'When changing mask dynamically, this fraction decides how '
                   'much of the ')
flags.DEFINE_string('drop_fraction_anneal', 'constant',
                    'If not empty the drop fraction is annealed during sparse'
                    ' training. One of the following: `constant`, `cosine` or '
                    '`exponential_(\\d*\\.?\\d*)$`. For example: '
                    '`exponential_3`, `exponential_.3`, `exponential_0.3`. '
                    'The number after `exponential` defines the exponent.')
flags.DEFINE_string('grow_init', 'zeros',
                    'Passed to the SparseInitializer, one of: zeros, '
                    'initial_value, random_normal, random_uniform.')
flags.DEFINE_float('s_momentum', 0.9,
                   'Momentum values for exponential moving average of '
                   'gradients. Used when training_method="momentum".')
flags.DEFINE_float('rigl_acc_scale', 0.,
                   'Used to scale initial accumulated gradients for new '
                   'connections.')
flags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin pruning at.')
flags.DEFINE_integer('maskupdate_end_step', 25000, 'Step to end pruning at.')
flags.DEFINE_integer('maskupdate_frequency', 100,
                     'Step interval between pruning.')
flags.DEFINE_float(
    'first_layer_sparsity', 0.,
    'Sparsity to use for the first layer. Overrides default end_sparsity '
    'if greater than 0. If -1, default sparsity is applied. If 0, layer is not'
    'pruned or masked.')
flags.DEFINE_float(
    'last_layer_sparsity', -1,
    'Sparsity to use for the last layer. Overrides default end_sparsity '
    'if greater than 0. If -1, default sparsity is applied. If 0, layer is not'
    'pruned or masked.')
flags.DEFINE_string(
    'load_mask_dir', '',
    'Directory of a trained model from which to load only the mask')
flags.DEFINE_string(
    'initial_value_checkpoint', '',
    'Directory of a model from which to load only the parameters')
flags.DEFINE_string(
    'model_architecture', 'resnet',
    'Which architecture to use. Options: resnet, mobilenet_v1, mobilenet_v2.'
    'vgg_16, vgg_a, vgg_19.')
flags.DEFINE_float('expansion_factor', 6.,
                   'how much to expand filters before depthwise conv')
flags.DEFINE_float('training_steps_multiplier', 1.0,
                   'Training schedule is shortened or extended with the '
                   'multiplier, if it is not 1.')
flags.DEFINE_integer('block_width', 1, 'width of block')
flags.DEFINE_integer('block_height', 1, 'height of block')
FLAGS = flags.FLAGS
# Populated by set_lr_schedule(); list of (lr multiplier, start epoch) tuples.
LR_SCHEDULE = []
# Variable-name suffixes used to identify trainable parameters vs. masks.
PARAM_SUFFIXES = ('gamma', 'beta', 'weights', 'biases')
MASK_SUFFIX = 'mask'
# Learning rate schedule (multiplier, epoch to start) tuples
def set_lr_schedule():
  """Populates the global LR_SCHEDULE for the selected architecture.

  Chooses the (multiplier, start-epoch) step schedule from
  FLAGS.model_architecture. If FLAGS.training_steps_multiplier is not 1,
  the epoch boundaries and the step-count flags (train_steps,
  maskupdate_begin_step, maskupdate_end_step) are rescaled in place.

  Raises:
    ValueError: if the architecture flag is not a known model family.
  """
  global LR_SCHEDULE
  arch = FLAGS.model_architecture
  if arch in ('mobilenet_v1', 'mobilenet_v2'):
    LR_SCHEDULE = [(1.0, 8), (0.1, 40), (0.01, 75), (0.001, 95), (.0003, 120)]
  elif arch == 'resnet' or arch.startswith('vgg'):
    LR_SCHEDULE = [(1.0, 0), (0.1, 30), (0.01, 70), (0.001, 90), (.0001, 120)]
  else:
    raise ValueError('Unknown architecture ' + arch)
  scale = FLAGS.training_steps_multiplier
  if scale != 1.0:
    # Stretch (or shrink) every epoch boundary and step flag by the same
    # factor so the schedule keeps its shape.
    LR_SCHEDULE = [(mult, epoch * scale) for mult, epoch in LR_SCHEDULE]
    FLAGS.train_steps = int(FLAGS.train_steps * scale)
    FLAGS.maskupdate_begin_step = int(FLAGS.maskupdate_begin_step * scale)
    FLAGS.maskupdate_end_step = int(FLAGS.maskupdate_end_step * scale)
    tf.logging.info(
        'Training schedule is updated with multiplier: %.2f' % scale)
  tf.logging.info('LR schedule: %s' % LR_SCHEDULE)
  tf.logging.info('Training Steps: %d' % FLAGS.train_steps)
# The input tensor is in the range of [0, 255], we need to scale them to the
# range of [0, 1]
# Per-channel ImageNet normalization constants, expressed on the [0, 255]
# pixel scale.
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# Maps layer-name prefixes to sparsity overrides; filled in lazily by
# set_custom_sparsity_map() from the first/last-layer sparsity flags.
CUSTOM_SPARSITY_MAP = {}
def set_custom_sparsity_map():
  """Records per-layer sparsity overrides in CUSTOM_SPARSITY_MAP.

  Only strictly positive flag values produce an override; 0 and -1 keep the
  layer on the default behavior.
  """
  overrides = (
      ('resnet_model/initial_conv', FLAGS.first_layer_sparsity),
      ('resnet_model/final_dense', FLAGS.last_layer_sparsity),
  )
  for layer_name, sparsity in overrides:
    if sparsity > 0.:
      CUSTOM_SPARSITY_MAP[layer_name] = sparsity
def lr_schedule(current_epoch):
  """Computes learning rate schedule."""
  # Linear-scaling rule: scale the base LR with the batch size relative to 256.
  scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
  if FLAGS.use_sgdr:
    # Cosine decay with warm restarts (SGDR), driven by the epoch counter.
    decay_rate = tf.train.cosine_decay_restarts(
        scaled_lr, current_epoch, FLAGS.sgdr_decay_step,
        t_mul=FLAGS.sgdr_t_mul, m_mul=FLAGS.sgdr_m_mul)
  else:
    # Linear warmup towards the first schedule boundary.
    # NOTE(review): LR_SCHEDULE[0][1] is 0 for the resnet/vgg schedules, so
    # this division produces inf/nan; the tf.where loop below overwrites the
    # value for current_epoch >= 0, but confirm there are no NaN-propagation
    # issues through tf.where's gradient.
    decay_rate = (
        scaled_lr * LR_SCHEDULE[0][0] * current_epoch / LR_SCHEDULE[0][1])
  # Step down to each (multiplier, start_epoch) plateau once it is reached.
  for mult, start_epoch in LR_SCHEDULE:
    decay_rate = tf.where(current_epoch < start_epoch, decay_rate,
                          scaled_lr * mult)
  return decay_rate
def train_function(training_method, loss, cross_loss, reg_loss, output_dir,
                   use_tpu):
  """Training script for resnet model.

  Builds the base optimizer, optionally wraps it with a sparse-training
  optimizer that also updates the pruning masks, and assembles the train op
  plus the summary host_call.

  Args:
    training_method: string indicating pruning method used to compress model.
    loss: tensor float32 of the cross entropy + regularization losses.
    cross_loss: tensor, only cross entropy loss, passed for logging.
    reg_loss: tensor, only regularization loss, passed for logging.
    output_dir: string tensor indicating the directory to save summaries.
    use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.

  Raises:
    ValueError: if `training_method` is unknown, or if no variable receives
      a gradient.
  """
  global_step = tf.train.get_global_step()
  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
  learning_rate = lr_schedule(current_epoch)
  if FLAGS.use_adam:
    # We don't use step decrease for the learning rate.
    learning_rate = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  else:
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if use_tpu:
    # use CrossShardOptimizer when using TPU.
    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        stateless_seed_offset=FLAGS.seed)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        stateless_seed_offset=FLAGS.seed)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init, stateless_seed_offset=FLAGS.seed,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=use_tpu)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction, stateless_seed_offset=FLAGS.seed,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=use_tpu)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer, mask_init_method=FLAGS.mask_init_method,
        custom_sparsity_map=CUSTOM_SPARSITY_MAP,
        default_sparsity=FLAGS.end_sparsity, use_tpu=use_tpu)
  elif training_method == 'dnw':
    optimizer = sparse_optimizers.SparseDNWOptimizer(
        optimizer,
        mask_init_method=FLAGS.mask_init_method,
        custom_sparsity_map=CUSTOM_SPARSITY_MAP,
        default_sparsity=FLAGS.end_sparsity,
        use_tpu=use_tpu)
  elif training_method in ('scratch', 'baseline'):
    # Dense or fixed-mask training: the base optimizer is used unchanged.
    pass
  else:
    # Bug fix: report the `training_method` argument that this function
    # actually dispatched on, not FLAGS.training_method (they can differ since
    # the method is passed in through `params`).
    raise ValueError('Unsupported pruning method: %s' % training_method)
  # UPDATE_OPS needs to be added as a dependency due to batch norm
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops), tf.name_scope('train'):
    grads_and_vars = optimizer.compute_gradients(loss)
    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          'No gradients provided for any variable, check your graph for ops'
          ' that do not support gradients, between variables %s and loss %s.' %
          ([str(v) for _, v in grads_and_vars], loss))
    train_op = optimizer.apply_gradients(
        grads_and_vars, global_step=global_step)

  metrics = {
      'global_step': tf.train.get_or_create_global_step(),
      'loss': loss,
      'cross_loss': cross_loss,
      'reg_loss': reg_loss,
      'learning_rate': learning_rate,
      'current_epoch': current_epoch,
  }

  # Logging drop_fraction if dynamic sparse training.
  is_dst_method = training_method in DST_METHODS
  if is_dst_method:
    metrics['drop_fraction'] = optimizer.drop_fraction

  def flatten_list_of_vars(var_list):
    # Concatenates all variables into a single rank-1 tensor, for norms.
    flat_vars = [tf.reshape(v, [-1]) for v in var_list]
    return tf.concat(flat_vars, axis=-1)

  if use_tpu:
    # Aggregate gradients across replicas so the logged norm matches the
    # effective update.
    reduced_grads = [tf.tpu.cross_replica_sum(g) for g, _ in grads_and_vars]
  else:
    reduced_grads = [g for g, _ in grads_and_vars]
  metrics['grad_norm'] = tf.norm(flatten_list_of_vars(reduced_grads))
  metrics['var_norm'] = tf.norm(
      flatten_list_of_vars([v for _, v in grads_and_vars]))
  # Let's log some statistics from a single parameter-mask couple.
  # This is useful for debugging.
  test_var = pruning.get_weights()[0]
  test_var_mask = pruning.get_masks()[0]
  metrics.update({
      'fw_nz_weight': tf.count_nonzero(test_var),
      'fw_nz_mask': tf.count_nonzero(test_var_mask),
      'fw_l1_weight': tf.reduce_sum(tf.abs(test_var))
  })

  masks = pruning.get_masks()
  global_sparsity = sparse_utils.calculate_sparsity(masks)
  metrics['global_sparsity'] = global_sparsity
  metrics.update(
      utils.mask_summaries(masks, with_img=FLAGS.log_mask_imgs_each_iteration))

  host_call = (functools.partial(utils.host_call_fn, output_dir),
               utils.format_tensors(metrics))

  return host_call, train_op
def resnet_model_fn_w_pruning(features, labels, mode, params):
  """The model_fn for ResNet-50 with pruning.

  Args:
    features: A float32 batch of images.
    labels: A int32 batch of labels.
    mode: Specifies whether training or evaluation.
    params: Dictionary of parameters passed to the model.

  Returns:
    A TPUEstimatorSpec for the model

  Raises:
    ValueError: if FLAGS.model_architecture is not recognized.
  """
  width = 1. if FLAGS.width <= 0 else FLAGS.width
  if isinstance(features, dict):
    features = features['feature']

  if FLAGS.data_format == 'channels_first':
    assert not FLAGS.transpose_input  # channels_first only for GPU
    features = tf.transpose(features, [0, 3, 1, 2])

  if FLAGS.transpose_input and mode != tf_estimator.ModeKeys.PREDICT:
    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

  # Normalize the image to zero mean and unit variance.
  # NOTE(review): the [1, 1, 3] constants broadcast over the trailing (channel)
  # axis, which assumes channels_last layout at this point — confirm this is
  # intended for the channels_first branch above.
  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)
  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)

  training_method = params['training_method']
  use_tpu = params['use_tpu']

  def build_network():
    """Construct the network in the graph."""
    if FLAGS.model_architecture == 'mobilenet_v2':
      network_func = functools.partial(
          mobilenetv2_model.mobilenet_v2,
          expansion_factor=FLAGS.expansion_factor)
    elif FLAGS.model_architecture == 'mobilenet_v1':
      network_func = functools.partial(mobilenetv1_model.mobilenet_v1)
    elif FLAGS.model_architecture == 'resnet':
      prune_first_layer = FLAGS.first_layer_sparsity != 0.
      network_func = functools.partial(
          resnet_model.resnet_v1_,
          resnet_depth=FLAGS.resnet_depth,
          init_method=FLAGS.init_method,
          end_sparsity=FLAGS.end_sparsity,
          prune_first_layer=prune_first_layer)
    elif FLAGS.model_architecture.startswith('vgg'):
      network_func = functools.partial(
          vgg.vgg,
          vgg_type=FLAGS.model_architecture,
          init_method=FLAGS.init_method,
          end_sparsity=FLAGS.end_sparsity)
    else:
      # Bug fix: this previously read the nonexistent `FLAGS.archiecture`
      # (typo), which raised AttributeError instead of the intended ValueError.
      raise ValueError('Unknown architecture ' + FLAGS.model_architecture)
    prune_last_layer = FLAGS.last_layer_sparsity != 0.
    network = network_func(
        num_classes=FLAGS.num_label_classes,
        # TODO remove the pruning_method option.
        pruning_method='threshold',
        width=width,
        prune_last_layer=prune_last_layer,
        data_format=FLAGS.data_format,
        weight_decay=FLAGS.weight_decay)
    is_training = (mode == tf_estimator.ModeKeys.TRAIN)
    if FLAGS.use_batch_statistics:
      is_training = True
    return network(inputs=features, is_training=is_training)

  if FLAGS.precision == 'bfloat16':
    with contrib_tpu.bfloat16_scope():
      logits = build_network()
    logits = tf.cast(logits, tf.float32)
  elif FLAGS.precision == 'float32':
    logits = build_network()

  if mode == tf_estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf_estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf_estimator.export.PredictOutput(predictions)
        })

  output_dir = params['output_dir']
  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)

  # make sure we reuse the same label smoothing parameter if we're doing
  # scratch / lottery ticket experiments.
  label_smoothing = FLAGS.label_smoothing
  if FLAGS.training_method == 'scratch' and FLAGS.load_mask_dir:
    # NOTE(review): the smoothing value is recovered from a fixed position in
    # the mask-directory path; assumes the directory layout produced by
    # use_folder_stub — confirm index 15 matches that layout.
    scratch_stripped = FLAGS.load_mask_dir.replace('/scratch', '')
    label_smoothing = float(scratch_stripped.split('/')[15])
    tf.logging.info('LABEL SMOOTHING USED: %.2f' % label_smoothing)
  cross_loss = tf.losses.softmax_cross_entropy(
      logits=logits,
      onehot_labels=one_hot_labels,
      label_smoothing=label_smoothing)
  # Add regularization loss term
  reg_loss = tf.losses.get_regularization_loss()
  loss = cross_loss + reg_loss

  host_call = None
  if mode == tf_estimator.ModeKeys.TRAIN:
    host_call, train_op = train_function(training_method, loss, cross_loss,
                                         reg_loss, output_dir, use_tpu)
  else:
    train_op = None

  eval_metrics = None
  if mode == tf_estimator.ModeKeys.EVAL:

    def metric_fn(labels, logits, cross_loss, reg_loss):
      """Calculate eval metrics."""
      logging.info('In metric function')
      eval_metrics = {}
      predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
      eval_metrics['cross_loss'] = tf.metrics.mean(cross_loss)
      eval_metrics['reg_loss'] = tf.metrics.mean(reg_loss)
      eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
          labels=labels, predictions=predictions)
      # If evaluating once lets also calculate sparsities.
      if FLAGS.mode == 'eval_once':
        sparsity_summaries = utils.mask_summaries(pruning.get_masks())
        # We call mean on a scalar to create tensor, update_op pairs.
        sparsity_summaries = {k: tf.metrics.mean(v) for k, v
                              in sparsity_summaries.items()}
        eval_metrics.update(sparsity_summaries)
      return eval_metrics

    # Scalar losses are broadcast to batch shape: TPU eval requires all
    # metric_fn inputs to share the same leading (batch) dimension.
    tensors = [labels, logits,
               tf.broadcast_to(cross_loss, tf.shape(labels)),
               tf.broadcast_to(reg_loss, tf.shape(labels))]

    eval_metrics = (metric_fn, tensors)

  if (FLAGS.load_mask_dir and
      FLAGS.training_method not in NO_MASK_INIT_METHODS):
    # Masks (and optionally parameters) are restored from an existing run.

    def scaffold_fn():
      """For initialization, passed to the estimator."""
      utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir,
                                            FLAGS.output_dir, MASK_SUFFIX)
      if FLAGS.initial_value_checkpoint:
        utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,
                                              FLAGS.output_dir, PARAM_SUFFIXES)
      return tf.train.Scaffold()
  elif (FLAGS.mask_init_method and
        FLAGS.training_method not in NO_MASK_INIT_METHODS):
    # Masks are freshly initialized according to the chosen init method.

    def scaffold_fn():
      """For initialization, passed to the estimator."""
      if FLAGS.initial_value_checkpoint:
        utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,
                                              FLAGS.output_dir, PARAM_SUFFIXES)
      all_masks = pruning.get_masks()
      assigner = sparse_utils.get_mask_init_fn(
          all_masks,
          FLAGS.mask_init_method,
          FLAGS.end_sparsity,
          CUSTOM_SPARSITY_MAP,
          erk_power_scale=FLAGS.erk_power_scale)

      def init_fn(scaffold, session):
        """A callable for restoring variable from a checkpoint."""
        del scaffold  # Unused.
        session.run(assigner)

      return tf.train.Scaffold(init_fn=init_fn)
  else:
    assert FLAGS.training_method in NO_MASK_INIT_METHODS
    scaffold_fn = None
    tf.logging.info('No mask is set, starting dense.')

  return contrib_tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=eval_metrics,
      scaffold_fn=scaffold_fn)
class ExportModelHook(tf.train.SessionRunHook):
  """Train hooks called after each session run for exporting the model."""

  def __init__(self, classifier, export_dir):
    """Creates the hook.

    Args:
      classifier: estimator used for export_all_saved_models (a CPU-side
        estimator, since exports run outside the TPU loop).
      export_dir: base directory; each export is written under
        <export_dir>/<global_step>.
    """
    self.classifier = classifier
    # Resolved lazily in begin(), once the graph exists.
    self.global_step = None
    self.export_dir = export_dir
    # Step of the most recent export; exports are rate-limited by
    # FLAGS.export_model_freq steps.
    self.last_export = 0
    # Receiver for the supervised (EVAL) signature: 224x224x3 float images
    # plus int32 labels.
    self.supervised_input_receiver_fn = (
        contrib_estimator.build_raw_supervised_input_receiver_fn(
            {
                'feature':
                    tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3])
            }, tf.placeholder(dtype=tf.int32, shape=[None])))

  def begin(self):
    # Cache the global-step tensor for use in after_run.
    self.global_step = tf.train.get_or_create_global_step()

  def after_run(self, run_context, run_values):
    # export saved model
    global_step = run_context.session.run(self.global_step)
    if global_step - self.last_export >= FLAGS.export_model_freq:
      tf.logging.info(
          'Export model for prediction (step={}) ...'.format(global_step))

      self.last_export = global_step
      contrib_estimator.export_all_saved_models(
          self.classifier, os.path.join(self.export_dir, str(global_step)), {
              tf_estimator.ModeKeys.EVAL:
                  self.supervised_input_receiver_fn,
              tf_estimator.ModeKeys.PREDICT:
                  imagenet_input.image_serving_input_fn
          })
def main(argv):
  """Builds the estimators and dispatches on FLAGS.mode.

  Supported modes: 'train', 'eval', 'eval_once', 'train_and_eval'.

  Args:
    argv: unused command-line arguments.

  Raises:
    ValueError: if eval_batch_size does not evenly divide num_eval_images.
  """
  del argv  # Unused.
  tf.enable_resource_variables()
  tf.set_random_seed(FLAGS.seed)
  set_lr_schedule()
  set_custom_sparsity_map()
  folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity),
                             str(FLAGS.maskupdate_begin_step),
                             str(FLAGS.maskupdate_end_step),
                             str(FLAGS.maskupdate_frequency),
                             str(FLAGS.drop_fraction),
                             str(FLAGS.label_smoothing),
                             str(FLAGS.weight_decay))

  output_dir = FLAGS.output_dir
  if FLAGS.use_folder_stub:
    output_dir = os.path.join(output_dir, folder_stub)

  export_dir = os.path.join(output_dir, 'export_dir')

  # we pass the updated eval and train string to the params dictionary.
  params = {}
  params['output_dir'] = output_dir
  params['training_method'] = FLAGS.training_method
  params['use_tpu'] = FLAGS.use_tpu

  dataset_func = functools.partial(
      imagenet_input.ImageNetInput, data_dir=FLAGS.data_directory,
      transpose_input=False, num_parallel_calls=FLAGS.num_parallel_calls,
      use_bfloat16=False)
  imagenet_train, imagenet_eval = [dataset_func(is_training=is_training)
                                   for is_training in [True, False]]

  run_config = tpu_config.RunConfig(
      master=FLAGS.master,
      model_dir=output_dir,
      save_checkpoints_steps=FLAGS.steps_per_checkpoint,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=False),
      tpu_config=tpu_config.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_cores,
          tpu_job_name=FLAGS.tpu_job_name))

  classifier = tpu_estimator.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=resnet_model_fn_w_pruning,
      params=params,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size)

  # Separate estimator with export_to_tpu=False, used only for exporting
  # CPU-servable saved models from the ExportModelHook.
  cpu_classifier = tpu_estimator.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=resnet_model_fn_w_pruning,
      params=params,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      export_to_tpu=False,
      eval_batch_size=FLAGS.eval_batch_size)

  if FLAGS.num_eval_images % FLAGS.eval_batch_size != 0:
    raise ValueError(
        'eval_batch_size (%d) must evenly divide num_eval_images(%d)!' %
        (FLAGS.eval_batch_size, FLAGS.num_eval_images))

  eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size

  if FLAGS.mode == 'eval_once':
    ckpt_path = os.path.join(output_dir, FLAGS.eval_once_ckpt_prefix)
    dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval
    classifier.evaluate(
        input_fn=dataset.input_fn,
        steps=eval_steps,
        checkpoint_path=ckpt_path,
        name='{0}'.format(FLAGS.eval_once_ckpt_prefix))
  elif FLAGS.mode == 'eval':
    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(output_dir):
      tf.logging.info('Starting to evaluate.')
      try:
        dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval
        classifier.evaluate(
            input_fn=dataset.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt,
            name='eval')
        # Terminate eval job when final checkpoint is reached
        global_step = int(os.path.basename(ckpt).split('-')[1])
        if global_step >= FLAGS.train_steps:
          tf.logging.info(
              'Evaluation finished after training step %d' % global_step)
          break
      except tf.errors.NotFoundError:
        # Bug fix: this previously called the `logging` module object as a
        # function, which raises TypeError. Use tf.logging.info like the rest
        # of the file.
        tf.logging.info('Checkpoint no longer exists, skipping checkpoint.')
  else:
    global_step = estimator._load_global_step_from_checkpoint_dir(output_dir)
    # Session run hooks to export model for prediction
    export_hook = ExportModelHook(cpu_classifier, export_dir)
    hooks = [export_hook]

    if FLAGS.mode == 'train':
      tf.logging.info('start training...')
      classifier.train(
          input_fn=imagenet_train.input_fn,
          hooks=hooks,
          max_steps=FLAGS.train_steps)
    else:
      assert FLAGS.mode == 'train_and_eval'
      tf.logging.info('start training and eval...')
      while global_step < FLAGS.train_steps:
        next_checkpoint = min(global_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
        classifier.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        global_step = next_checkpoint
        # Bug fix: `logging(...)` called the module object (TypeError); log
        # via tf.logging.info with a format string instead.
        tf.logging.info('Completed training up to step : %d', global_step)
        classifier.evaluate(input_fn=imagenet_eval.input_fn, steps=eval_steps)


if __name__ == '__main__':
  app.run(main)
================================================
FILE: rigl/imagenet_resnet/mobilenetv1_model.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Straightforward MobileNet v1 for inputs of size 224x224."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import flags
from rigl.imagenet_resnet import resnet_model
from rigl.imagenet_resnet.pruning_layers import sparse_conv2d
from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
FLAGS = flags.FLAGS
def _make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
def depthwise_conv2d_fixed_padding(inputs,
                                   kernel_size,
                                   stride,
                                   data_format='channels_first',
                                   name=None):
  """Depthwise Strided 2-D convolution with explicit padding.

  The padding is consistent and is based only on `kernel_size`, not on the
  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,
      width].
    kernel_size: Int designating size of kernel to be used in the convolution.
    stride: Int specifying the stride. If stride >1, the input is downsampled.
    data_format: String that specifies either "channels_first" for [batch,
      channels, height,width] or "channels_last" for [batch, height, width,
      channels].
    name: String that specifies name for model layer.

  Returns:
    The output activation tensor of size [batch, filters, height_out, width_out]

  Raises:
    ValueError: If the data_format provided is not a valid string.
  """
  # Strided convolutions pad explicitly (independent of input size) and then
  # run VALID; stride-1 convolutions use plain SAME padding.
  strided = stride > 1
  if strided:
    inputs = resnet_model.fixed_padding(
        inputs, kernel_size, data_format=data_format)
  conv_padding = 'VALID' if strided else 'SAME'

  try:
    data_format_channels = {
        'channels_last': 'NHWC',
        'channels_first': 'NCHW',
    }[data_format]
  except KeyError:
    raise ValueError('Not a valid channel string:', data_format)

  # num_outputs=None makes this a pure depthwise conv (no pointwise stage);
  # batch norm follows, so biases are disabled.
  return contrib_layers.separable_conv2d(
      inputs=inputs,
      num_outputs=None,
      kernel_size=kernel_size,
      stride=stride,
      padding=conv_padding,
      data_format=data_format_channels,
      activation_fn=None,
      weights_regularizer=None,
      biases_initializer=None,
      biases_regularizer=None,
      scope=name)
def conv2d_fixed_padding(inputs,
                         filters,
                         kernel_size,
                         strides,
                         pruning_method='baseline',
                         data_format='channels_first',
                         weight_decay=0.,
                         name=None):
  """Strided 2-D convolution with explicit padding.

  The padding is consistent and is based only on `kernel_size`, not on the
  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,
      width].
    filters: Int specifying number of filters for the first two convolutions.
    kernel_size: Int designating size of kernel to be used in the convolution.
    strides: Int specifying the stride. If stride >1, the input is downsampled.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    data_format: String that specifies either "channels_first" for [batch,
      channels, height,width] or "channels_last" for [batch, height, width,
      channels].
    weight_decay: Weight for the l2 regularization loss.
    name: String that specifies name for model layer.

  Returns:
    The output activation tensor of size [batch, filters, height_out, width_out]

  Raises:
    ValueError: If the data_format provided is not a valid string.
  """
  # When downsampling, pad explicitly (based on kernel_size only) and use
  # VALID so the output size is independent of the input's exact dimensions.
  if strides > 1:
    inputs = resnet_model.fixed_padding(
        inputs, kernel_size, data_format=data_format)
    padding = 'VALID'
  else:
    padding = 'SAME'

  kernel_initializer = tf.variance_scaling_initializer()
  kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
  # Prunable convolution; use_bias=False since batch norm follows.
  return sparse_conv2d(
      x=inputs,
      units=filters,
      activation=None,
      kernel_size=[kernel_size, kernel_size],
      use_bias=False,
      kernel_initializer=kernel_initializer,
      kernel_regularizer=kernel_regularizer,
      bias_initializer=None,
      biases_regularizer=None,
      sparsity_technique=pruning_method,
      normalizer_fn=None,
      strides=[strides, strides],
      padding=padding,
      data_format=data_format,
      name=name)
def mbv1_block_(inputs,
                filters,
                is_training,
                stride,
                width=1.,
                block_id=0,
                pruning_method='baseline',
                data_format='channels_first',
                weight_decay=0.):
  """Standard building block for mobilenetv1 networks.

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,
      width].
    filters: Int specifying number of filters for the first two convolutions.
    is_training: Boolean specifying whether the model is training.
    stride: Int specifying the stride. If stride >1, the input is downsampled.
    width: multiplier for channel dimensions
    block_id: which block this is
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    data_format: String that specifies either "channels_first" for [batch,
      channels, height,width] or "channels_last" for [batch, height, width,
      channels].
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    The output activation tensor.
  """
  # separable_conv_2d followed by contracting 1x1 conv.
  end_point = 'depthwise_nxn_%s' % block_id

  # Depthwise
  depthwise_out = depthwise_conv2d_fixed_padding(
      inputs=inputs,
      kernel_size=3,
      stride=stride,
      data_format=data_format,
      name=end_point)
  depthwise_out = resnet_model.batch_norm_relu(
      depthwise_out, is_training, relu=True, data_format=data_format)

  # Contraction
  end_point = 'contraction_1x1_%s' % block_id
  # The first block keeps its exact width-scaled channel count; subsequent
  # blocks round it to a multiple of 8.
  divisible_by = 8
  if block_id == 0:
    divisible_by = 1
  out_filters = _make_divisible(int(width * filters), divisor=divisible_by)
  contraction_out = conv2d_fixed_padding(
      inputs=depthwise_out,
      filters=out_filters,
      kernel_size=1,
      strides=1,
      pruning_method=pruning_method,
      data_format=data_format,
      weight_decay=weight_decay,
      name=end_point)
  contraction_out = resnet_model.batch_norm_relu(
      contraction_out, is_training, relu=True, data_format=data_format)
  output = contraction_out
  return output
def mobilenet_v1_generator(num_classes=1000,
                           pruning_method='baseline',
                           width=1.,
                           prune_last_layer=False,
                           data_format='channels_first',
                           weight_decay=0.,
                           name=None):
  """Generator for mobilenet v1 models.

  Args:
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    width: Float that scales the number of filters in each layer.
    prune_last_layer: Whether or not to prune the last layer.
    data_format: String either "channels_first" for `[batch, channels, height,
      width]` or "channels_last for `[batch, height, width, channels]`.
    weight_decay: Weight for the l2 regularization loss.
    name: String that specifies name for model layer.

  Returns:
    Model `function` that takes in `inputs` and `is_training` and returns the
    output `Tensor` of the model.
  """

  def model(inputs, is_training):
    """Creation of the model graph."""
    # NOTE(review): the default scope name 'resnet_model' appears intentional
    # so the variable names match what the training script expects (e.g. the
    # 'resnet_model/initial_conv' sparsity override) — confirm before renaming.
    with tf.variable_scope(name, 'resnet_model'):
      # Dense (non-prunable) stem conv: explicit fixed padding + VALID.
      inputs = resnet_model.fixed_padding(
          inputs, kernel_size=3, data_format=data_format)
      padding = 'VALID'
      kernel_initializer = tf.variance_scaling_initializer()
      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
      inputs = tf.layers.conv2d(
          inputs=inputs,
          filters=_make_divisible(32 * width),
          kernel_size=3,
          strides=2,
          padding=padding,
          use_bias=False,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          data_format=data_format,
          name='initial_conv')
      inputs = tf.identity(inputs, 'initial_conv')
      inputs = resnet_model.batch_norm_relu(
          inputs, is_training, data_format=data_format)

      # Standard MobileNet v1 body: 13 depthwise-separable blocks.
      mb_block = functools.partial(
          mbv1_block_,
          is_training=is_training,
          width=width,
          pruning_method=pruning_method,
          data_format=data_format,
          weight_decay=weight_decay)
      inputs = mb_block(inputs, filters=64, stride=1, block_id=0)
      inputs = mb_block(inputs, filters=128, stride=2, block_id=1)
      inputs = mb_block(inputs, filters=128, stride=1, block_id=2)
      inputs = mb_block(inputs, filters=256, stride=2, block_id=3)
      inputs = mb_block(inputs, filters=256, stride=1, block_id=4)
      inputs = mb_block(inputs, filters=512, stride=2, block_id=5)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=6)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=7)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=8)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=9)
      inputs = mb_block(inputs, filters=512, stride=1, block_id=10)
      inputs = mb_block(inputs, filters=1024, stride=2, block_id=11)
      inputs = mb_block(inputs, filters=1024, stride=1, block_id=12)

      last_block_filters = _make_divisible(int(1024 * width), 8)

      # Global average pool over the full spatial extent.
      if data_format == 'channels_last':
        pool_size = (inputs.shape[1], inputs.shape[2])
      elif data_format == 'channels_first':
        pool_size = (inputs.shape[2], inputs.shape[3])
      inputs = tf.layers.average_pooling2d(
          inputs=inputs,
          pool_size=pool_size,
          strides=1,
          padding='VALID',
          data_format=data_format,
          name='final_avg_pool')
      inputs = tf.identity(inputs, 'final_avg_pool')
      inputs = tf.reshape(inputs, [-1, last_block_filters])

      kernel_initializer = tf.variance_scaling_initializer()
      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
      if prune_last_layer:
        # Simplified: the former `pruning_method if prune_last_layer else
        # 'baseline'` was dead code inside this branch, where prune_last_layer
        # is always True.
        inputs = sparse_fully_connected(
            x=inputs,
            units=num_classes,
            sparsity_technique=pruning_method,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            name='final_dense')
      else:
        inputs = tf.layers.dense(
            inputs=inputs,
            units=num_classes,
            activation=None,
            use_bias=True,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            name='final_dense')
      inputs = tf.identity(inputs, 'final_dense')
    return inputs

  model.default_image_size = 224
  return model
def mobilenet_v1(num_classes,
                 pruning_method='baseline',
                 width=1.,
                 prune_last_layer=True,
                 data_format='channels_first',
                 weight_decay=0.):
  """Returns the mobilenet_V1 model for a given size and number of output classes.

  Thin wrapper around `mobilenet_v1_generator`.

  Args:
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    width: Float multiplier of the number of filters in each layer.
    prune_last_layer: Whether or not to prune the last layer.
    data_format: String specifying either "channels_first" for `[batch,
      channels, height, width]` or "channels_last for `[batch, height, width,
      channels]`.
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    Model `function` taking `inputs` and `is_training`, returning the logits.
  """
  return mobilenet_v1_generator(
      num_classes=num_classes,
      pruning_method=pruning_method,
      width=width,
      prune_last_layer=prune_last_layer,
      data_format=data_format,
      weight_decay=weight_decay)
================================================
FILE: rigl/imagenet_resnet/mobilenetv2_model.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Straightforward MobileNet v2 for inputs of size 224x224."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import flags
from rigl.imagenet_resnet import resnet_model
from rigl.imagenet_resnet.pruning_layers import sparse_conv2d
from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
FLAGS = flags.FLAGS
def _make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
def depthwise_conv2d_fixed_padding(inputs,
                                   kernel_size,
                                   stride,
                                   data_format='channels_first',
                                   name=None):
  """Depthwise Strided 2-D convolution with explicit padding.

  The padding is consistent and is based only on `kernel_size`, not on the
  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,
      width].
    kernel_size: Int designating size of kernel to be used in the convolution.
    stride: Int specifying the stride. If stride >1, the input is downsampled.
    data_format: String that specifies either "channels_first" for [batch,
      channels, height,width] or "channels_last" for [batch, height, width,
      channels].
    name: String that specifies name for model layer.

  Returns:
    The output activation tensor of size [batch, filters, height_out, width_out]

  Raises:
    ValueError: If the data_format provided is not a valid string.
  """
  # Strided case: pad explicitly based on kernel_size only, then use VALID so
  # the output shape does not depend on the exact input size.
  if stride > 1:
    inputs = resnet_model.fixed_padding(
        inputs, kernel_size, data_format=data_format)
  padding = 'SAME' if stride == 1 else 'VALID'

  if data_format == 'channels_last':
    data_format_channels = 'NHWC'
  elif data_format == 'channels_first':
    data_format_channels = 'NCHW'
  else:
    raise ValueError('Not a valid channel string:', data_format)

  # num_outputs=None restricts separable_conv2d to the depthwise stage only;
  # batch norm follows, so biases are disabled.
  return contrib_layers.separable_conv2d(
      inputs=inputs,
      num_outputs=None,
      kernel_size=kernel_size,
      stride=stride,
      padding=padding,
      data_format=data_format_channels,
      activation_fn=None,
      weights_regularizer=None,
      biases_initializer=None,
      biases_regularizer=None,
      scope=name)
def conv2d_fixed_padding(inputs,
                         filters,
                         kernel_size,
                         strides,
                         pruning_method='baseline',
                         data_format='channels_first',
                         weight_decay=0.,
                         name=None):
  """Strided 2-D convolution with explicit padding and optional pruning.

  Padding depends only on `kernel_size`, not on the spatial dimensions of
  `inputs` (as opposed to relying on `tf.layers.conv2d` padding alone).

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels,
      height, width].
    filters: Int number of output filters.
    kernel_size: Int size of the convolution kernel.
    strides: Int stride; when > 1 the input is downsampled.
    pruning_method: String selecting the pruning technique used to decide
      which weights to remove (passed to `sparse_conv2d`).
    data_format: "channels_first" ([batch, channels, height, width]) or
      "channels_last" ([batch, height, width, channels]).
    weight_decay: Float weight for the l2 regularization loss.
    name: Optional string name for the layer.

  Returns:
    The output activation tensor of size [batch, filters, height_out,
    width_out].

  Raises:
    ValueError: If the data_format provided is not a valid string.
  """
  if strides > 1:
    # Pad explicitly so a 'VALID' conv reproduces 'SAME'-style geometry.
    inputs = resnet_model.fixed_padding(
        inputs, kernel_size, data_format=data_format)
  conv_padding = 'SAME' if strides == 1 else 'VALID'
  initializer = tf.variance_scaling_initializer()
  regularizer = contrib_layers.l2_regularizer(weight_decay)
  return sparse_conv2d(
      x=inputs,
      units=filters,
      activation=None,
      kernel_size=[kernel_size, kernel_size],
      use_bias=False,
      kernel_initializer=initializer,
      kernel_regularizer=regularizer,
      bias_initializer=None,
      biases_regularizer=None,
      sparsity_technique=pruning_method,
      normalizer_fn=None,
      strides=[strides, strides],
      padding=conv_padding,
      data_format=data_format,
      name=name)
def inverted_res_block_(inputs,
                        filters,
                        is_training,
                        stride,
                        width=1.,
                        expansion_factor=6.,
                        block_id=0,
                        pruning_method='baseline',
                        data_format='channels_first',
                        weight_decay=0.,):
  """Standard building block for mobilenetv2 networks.

  Structure: 1x1 expansion conv -> 3x3 depthwise conv -> 1x1 contraction
  conv, with a residual shortcut when shapes allow. The expansion stage is
  skipped for block 0.

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,
      width].
    filters: Int specifying number of filters for the first two convolutions.
    is_training: Boolean specifying whether the model is training.
    stride: Int specifying the stride. If stride >1, the input is downsampled.
    width: Float multiplier for channel dimensions.
    expansion_factor: How much to increase the filters before the depthwise
      conv.
    block_id: Int index of this block within the network; block 0 skips
      expansion and is not rounded to a multiple of 8.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    data_format: String that specifies either "channels_first" for [batch,
      channels, height, width] or "channels_last" for [batch, height, width,
      channels].
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    The output activation tensor.

  Raises:
    ValueError: If `data_format` is not a recognized string.
  """
  # 1x1 expanded conv, followed by separable_conv_2d followed by
  # contracting 1x1 conv.
  shortcut = inputs
  # Channel axis position depends on the layout.
  if data_format == 'channels_first':
    prev_depth = inputs.get_shape().as_list()[1]
  elif data_format == 'channels_last':
    prev_depth = inputs.get_shape().as_list()[3]
  else:
    raise ValueError('Unknown data_format ' + data_format)
  # Expand
  multiplier = expansion_factor if block_id > 0 else 1
  # skip the expansion if this is the first block
  if block_id:
    end_point = 'expand_1x1_%s' % block_id
    inputs = conv2d_fixed_padding(
        inputs=inputs,
        filters=int(multiplier * prev_depth),
        kernel_size=1,
        strides=1,
        pruning_method=pruning_method,
        data_format=data_format,
        weight_decay=weight_decay,
        name=end_point)
    inputs = resnet_model.batch_norm_relu(
        inputs, is_training, relu=True, data_format=data_format)
  end_point = 'depthwise_nxn_%s' % block_id
  # Depthwise 3x3 stage; this is where any spatial downsampling happens.
  depthwise_out = depthwise_conv2d_fixed_padding(
      inputs=inputs,
      kernel_size=3,
      stride=stride,
      data_format=data_format,
      name=end_point)
  depthwise_out = resnet_model.batch_norm_relu(
      depthwise_out, is_training, relu=True, data_format=data_format)
  # Contraction
  end_point = 'contraction_1x1_%s' % block_id
  # Block 0 keeps the raw filter count; later blocks round to a multiple of 8.
  divisible_by = 8
  if block_id == 0:
    divisible_by = 1
  out_filters = _make_divisible(int(width * filters), divisor=divisible_by)
  contraction_out = conv2d_fixed_padding(
      inputs=depthwise_out,
      filters=out_filters,
      kernel_size=1,
      strides=1,
      pruning_method=pruning_method,
      data_format=data_format,
      weight_decay=weight_decay,
      name=end_point)
  # No ReLU after contraction (linear bottleneck).
  contraction_out = resnet_model.batch_norm_relu(
      contraction_out, is_training, relu=False, data_format=data_format)
  output = contraction_out
  # Residual connection only when shapes match exactly.
  if prev_depth == out_filters and stride == 1:
    output += shortcut
  return output
def mobilenet_v2_generator(num_classes=1000,
                           pruning_method='baseline',
                           width=1.,
                           expansion_factor=6.,
                           prune_last_layer=False,
                           data_format='channels_first',
                           weight_decay=0.,
                           name=None):
  """Generator for mobilenet v2 models.

  Args:
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    width: Float that scales the number of filters in each layer.
    expansion_factor: How much to expand the input filters for the depthwise
      conv.
    prune_last_layer: Whether or not to prune the last (dense) layer.
    data_format: String either "channels_first" for `[batch, channels, height,
      width]` or "channels_last" for `[batch, height, width, channels]`.
    weight_decay: Weight for the l2 regularization loss.
    name: String that specifies name for model layer.

  Returns:
    Model `function` that takes in `inputs` and `is_training` and returns the
    output `Tensor` of the model.
  """

  def model(inputs, is_training):
    """Creation of the model graph."""
    # NOTE(review): default scope name 'resnet_model' looks copy-pasted from
    # the ResNet generator; kept as-is since changing it would rename every
    # variable and break existing checkpoints.
    with tf.variable_scope(name, 'resnet_model'):
      # Stem: explicit padding plus a strided 3x3 conv (never pruned).
      inputs = resnet_model.fixed_padding(
          inputs, kernel_size=3, data_format=data_format)
      padding = 'VALID'
      kernel_initializer = tf.variance_scaling_initializer()
      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
      inputs = tf.layers.conv2d(
          inputs=inputs,
          filters=_make_divisible(32 * width),
          kernel_size=3,
          strides=2,
          padding=padding,
          use_bias=False,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          data_format=data_format,
          name='initial_conv')
      inputs = tf.identity(inputs, 'initial_conv')
      inputs = resnet_model.batch_norm_relu(
          inputs, is_training, data_format=data_format)
      # All inverted residual blocks share the same width/pruning settings.
      inverted_res_block = functools.partial(
          inverted_res_block_,
          is_training=is_training,
          width=width,
          expansion_factor=expansion_factor,
          pruning_method=pruning_method,
          data_format=data_format,
          weight_decay=weight_decay)
      # Standard MobileNetV2 stack of 17 inverted residual blocks.
      inputs = inverted_res_block(inputs, filters=16, stride=1, block_id=0)
      inputs = inverted_res_block(inputs, filters=24, stride=2, block_id=1)
      inputs = inverted_res_block(inputs, filters=24, stride=1, block_id=2)
      inputs = inverted_res_block(inputs, filters=32, stride=2, block_id=3)
      inputs = inverted_res_block(inputs, filters=32, stride=1, block_id=4)
      inputs = inverted_res_block(inputs, filters=32, stride=1, block_id=5)
      inputs = inverted_res_block(inputs, filters=64, stride=2, block_id=6)
      inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=7)
      inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=8)
      inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=9)
      inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=10)
      inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=11)
      inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=12)
      inputs = inverted_res_block(inputs, filters=160, stride=2, block_id=13)
      inputs = inverted_res_block(inputs, filters=160, stride=1, block_id=14)
      inputs = inverted_res_block(inputs, filters=160, stride=1, block_id=15)
      inputs = inverted_res_block(inputs, filters=320, stride=1, block_id=16)
      # Head 1x1 conv never shrinks below 1280 channels.
      last_block_filters = max(1280, _make_divisible(1280 * width, 8))
      inputs = conv2d_fixed_padding(
          inputs=inputs,
          filters=last_block_filters,
          kernel_size=1,
          strides=1,
          pruning_method=pruning_method,
          data_format=data_format,
          weight_decay=weight_decay,
          name='final_1x1_conv')
      inputs = resnet_model.batch_norm_relu(
          inputs, is_training, data_format=data_format)
      # NOTE(review): there is no `else` raising for an invalid data_format
      # here, so a bad value would surface as a NameError on pool_size.
      if data_format == 'channels_last':
        pool_size = (inputs.shape[1], inputs.shape[2])
      elif data_format == 'channels_first':
        pool_size = (inputs.shape[2], inputs.shape[3])
      # Global average pool down to a [batch, last_block_filters] vector.
      inputs = tf.layers.average_pooling2d(
          inputs=inputs,
          pool_size=pool_size,
          strides=1,
          padding='VALID',
          data_format=data_format,
          name='final_avg_pool')
      inputs = tf.identity(inputs, 'final_avg_pool')
      inputs = tf.reshape(inputs, [-1, last_block_filters])
      kernel_initializer = tf.variance_scaling_initializer()
      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
      if prune_last_layer:
        # NOTE(review): the inline conditional below is redundant —
        # prune_last_layer is always True on this branch.
        inputs = sparse_fully_connected(
            x=inputs,
            units=num_classes,
            sparsity_technique=pruning_method
            if prune_last_layer else 'baseline',
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            name='final_dense')
      else:
        inputs = tf.layers.dense(
            inputs=inputs,
            units=num_classes,
            activation=None,
            use_bias=True,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            name='final_dense')
      inputs = tf.identity(inputs, 'final_dense')
    return inputs

  model.default_image_size = 224
  return model
def mobilenet_v2(num_classes,
                 pruning_method='baseline',
                 width=1.,
                 expansion_factor=6.,
                 prune_last_layer=True,
                 data_format='channels_first',
                 weight_decay=0.,):
  """Builds a MobileNetV2 model fn for the given width and class count.

  Args:
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    width: Float multiplier of the number of filters in each layer.
    expansion_factor: How much to increase the number of filters before the
      depthwise conv.
    prune_last_layer: Whether or not to prune the last layer.
    data_format: String specifying either "channels_first" for `[batch,
      channels, height, width]` or "channels_last" for `[batch, height, width,
      channels]`.
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    A model `function` that maps `(inputs, is_training)` to output logits.
  """
  return mobilenet_v2_generator(
      num_classes=num_classes,
      pruning_method=pruning_method,
      width=width,
      expansion_factor=expansion_factor,
      prune_last_layer=prune_last_layer,
      data_format=data_format,
      weight_decay=weight_decay)
================================================
FILE: rigl/imagenet_resnet/pruning_layers.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow layers with parameters for implementing pruning."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.python.ops import init_ops
def get_model_variables(getter,
                        name,
                        shape=None,
                        dtype=None,
                        initializer=None,
                        regularizer=None,
                        trainable=True,
                        collections=None,
                        caching_device=None,
                        partitioner=None,
                        rename=None,
                        use_resource=None,
                        **_):
  """Retrieves model variables in a consistent way for core layers.

  If `rename` maps the final path component of `name` to a new leaf, the
  variable is created/fetched under the renamed path. Delegates variable
  creation to `variables.model_variable` with `getter` as the custom getter.
  """
  if rename:
    *scope_parts, leaf = name.split('/')
    if leaf in rename:
      name = '/'.join(scope_parts + [rename[leaf]])
  return variables.model_variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      collections=collections,
      trainable=trainable,
      caching_device=caching_device,
      partitioner=partitioner,
      custom_getter=getter,
      use_resource=use_resource)
def variable_getter(rename=None):
  """Builds a custom getter that injects `rename` into variable retrieval."""

  def _custom_getter(getter, *args, **kwargs):
    # Force our rename map, overriding any value the caller passed.
    merged = dict(kwargs, rename=rename)
    return get_model_variables(getter, *args, **merged)

  return _custom_getter
def sparse_conv2d(x,
                  units,
                  kernel_size,
                  activation=None,
                  use_bias=False,
                  kernel_initializer=None,
                  kernel_regularizer=None,
                  bias_initializer=None,
                  biases_regularizer=None,
                  sparsity_technique='baseline',
                  normalizer_fn=None,
                  strides=(1, 1),
                  padding='SAME',
                  data_format='channels_last',
                  name=None):
  """Constructs a conv2d layer with the desired pruning method.

  Dispatches to `layers.masked_conv2d` (which adds a pruning mask variable)
  when `sparsity_technique` is 'threshold', and to a plain
  `tf.layers.conv2d` for 'baseline'.

  Args:
    x: Input, float32 tensor of rank 4.
    units: Int representing number of output filters.
    kernel_size: Sequence of ints, e.g. [k, k]. Only `kernel_size[0]` is
      used on the masked path, so that path assumes a square kernel.
    activation: If None, a linear activation is used.
    use_bias: Boolean specifying whether a bias vector should be used
      (baseline path only; the masked path uses `bias_initializer`).
    kernel_initializer: Initializer for the convolution weights.
    kernel_regularizer: Regularization method for the convolution weights.
    bias_initializer: Initializer of the bias vector.
    biases_regularizer: Optional regularizer for the bias vector.
    sparsity_technique: Method used to introduce sparsity, one of
      ['threshold', 'baseline'].
    normalizer_fn: Function used to transform the output activations
      (masked path only).
    strides: Sequence of ints; only `strides[0]` is used on the masked path.
    padding: May be populated as 'VALID' or 'SAME'.
    data_format: Either 'channels_last' or 'channels_first'.
    name: String specifying name scope of layer in network.

  Returns:
    Output: activations.

  Raises:
    ValueError: If `data_format` is invalid, the input rank is not 4, the
      last input dimension is undefined, or `sparsity_technique` is not
      supported.
  """
  if data_format == 'channels_last':
    data_format_channels = 'NHWC'
  elif data_format == 'channels_first':
    data_format_channels = 'NCHW'
  else:
    raise ValueError('Not a valid channel string:', data_format)
  # Rename core-layer variable leaves ('kernel'/'bias') to the contrib-style
  # names ('weights'/'biases') so both code paths share variable names.
  layer_variable_getter = variable_getter({
      'bias': 'biases',
      'kernel': 'weights',
  })
  input_rank = x.get_shape().ndims
  if input_rank != 4:
    raise ValueError('Rank not supported {}'.format(input_rank))
  with tf.variable_scope(
      name, 'Conv', [x], custom_getter=layer_variable_getter) as sc:
    input_shape = x.get_shape().as_list()
    if input_shape[-1] is None:
      raise ValueError('The last dimension of the inputs to `Convolution` '
                       'should be defined. Found `None`.')
    pruning_methods = ['threshold']
    if sparsity_technique in pruning_methods:
      return layers.masked_conv2d(
          inputs=x,
          num_outputs=units,
          kernel_size=kernel_size[0],
          stride=strides[0],
          padding=padding,
          data_format=data_format_channels,
          rate=1,
          activation_fn=activation,
          weights_initializer=kernel_initializer,
          weights_regularizer=kernel_regularizer,
          normalizer_fn=normalizer_fn,
          normalizer_params=None,
          biases_initializer=bias_initializer,
          biases_regularizer=biases_regularizer,
          outputs_collections=None,
          trainable=True,
          scope=sc)
    elif sparsity_technique == 'baseline':
      # NOTE(review): name=name is passed while already inside
      # tf.variable_scope(name, ...), which nests the scope; confirm this
      # matches the intended checkpoint variable naming before changing.
      return tf.layers.conv2d(
          inputs=x,
          filters=units,
          kernel_size=kernel_size,
          strides=strides,
          padding=padding,
          use_bias=use_bias,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          data_format=data_format,
          name=name)
    else:
      raise ValueError(
          'Unsupported sparsity technique {}'.format(sparsity_technique))
def sparse_fully_connected(x,
                           units,
                           activation=None,
                           use_bias=True,
                           kernel_initializer=None,
                           kernel_regularizer=None,
                           bias_initializer=init_ops.zeros_initializer(),
                           biases_regularizer=None,
                           sparsity_technique='baseline',
                           name=None):
  """Constructs a fully-connected layer with the desired pruning method.

  Dispatches to `layers.masked_fully_connected` (which adds a pruning mask
  variable) when `sparsity_technique` is 'threshold', and to a plain
  `tf.layers.dense` for 'baseline'.

  Args:
    x: Input, float32 tensor.
    units: Int representing size of output tensor.
    activation: If None, a linear activation is used.
    use_bias: Boolean specifying whether a bias vector should be used
      (baseline path only; the masked path uses `bias_initializer`).
    kernel_initializer: Initializer for the weight matrix.
    kernel_regularizer: Regularization method for the weight matrix.
    bias_initializer: Initializer of the bias vector.
    biases_regularizer: Optional regularizer for the bias vector.
    sparsity_technique: Method used to introduce sparsity, one of
      ['baseline', 'threshold'].
    name: String specifying name scope of layer in network.

  Returns:
    Output: activations.

  Raises:
    ValueError: If the last dimension of the input is undefined, or if
      `sparsity_technique` is not supported.
  """
  # Rename core-layer variable leaves ('kernel'/'bias') to the contrib-style
  # names ('weights'/'biases') so both code paths share variable names.
  layer_variable_getter = variable_getter({
      'bias': 'biases',
      'kernel': 'weights',
  })
  with tf.variable_scope(
      name, 'Dense', [x], custom_getter=layer_variable_getter) as sc:
    input_shape = x.get_shape().as_list()
    if input_shape[-1] is None:
      raise ValueError('The last dimension of the inputs to `Dense` '
                       'should be defined. Found `None`.')
    pruning_methods = ['threshold']
    if sparsity_technique in pruning_methods:
      return layers.masked_fully_connected(
          inputs=x,
          num_outputs=units,
          activation_fn=activation,
          weights_initializer=kernel_initializer,
          weights_regularizer=kernel_regularizer,
          biases_initializer=bias_initializer,
          biases_regularizer=biases_regularizer,
          outputs_collections=None,
          trainable=True,
          scope=sc)
    elif sparsity_technique == 'baseline':
      # NOTE(review): name=name is passed while already inside
      # tf.variable_scope(name, ...), which nests the scope; confirm this
      # matches the intended checkpoint variable naming before changing.
      return tf.layers.dense(
          inputs=x,
          units=units,
          activation=activation,
          use_bias=use_bias,
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_initializer=bias_initializer,
          bias_regularizer=biases_regularizer,
          name=name)
    else:
      raise ValueError(
          'Unsupported sparsity technique {}'.format(sparsity_technique))
================================================
FILE: rigl/imagenet_resnet/resnet_model.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet modified to including pruning layers if specified.
Residual networks (ResNets) were proposed in:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from absl import flags
from rigl.imagenet_resnet.pruning_layers import sparse_conv2d
from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.python.ops import init_ops
FLAGS = flags.FLAGS
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
def batch_norm_relu(inputs, is_training, relu=True, init_zero=False,
                    data_format='channels_first'):
  """Applies batch normalization, optionally followed by a ReLU.

  Args:
    inputs: `Tensor` of shape `[batch, channels, ...]`.
    is_training: `bool` for whether the model is training.
    relu: `bool`; if False, the ReLU is skipped.
    init_zero: `bool`; if True, initializes the batch-norm scale (gamma)
      with 0 instead of the default 1.
    data_format: `str` either "channels_first" for `[batch, channels, height,
      width]` or "channels_last" for `[batch, height, width, channels]`.

  Returns:
    A normalized (and possibly rectified) `Tensor` in the same `data_format`.
  """
  gamma_initializer = (
      tf.zeros_initializer() if init_zero else tf.ones_initializer())
  # Normalize over the channel axis, whose position depends on the layout.
  channel_axis = 1 if data_format == 'channels_first' else 3
  normalized = tf.layers.batch_normalization(
      inputs=normalized_inputs if False else inputs,
      axis=channel_axis,
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON,
      center=True,
      scale=True,
      training=is_training,
      fused=True,
      gamma_initializer=gamma_initializer)
  return tf.nn.relu(normalized) if relu else normalized
def fixed_padding(inputs, kernel_size, data_format='channels_first'):
  """Pads the spatial dimensions of `inputs` independently of input size.

  Args:
    inputs: `Tensor` of size `[batch, channels, height, width]` or
      `[batch, height, width, channels]` depending on `data_format`.
    kernel_size: Positive `int` kernel size for the following `conv2d` or
      `max_pool2d` operation.
    data_format: `str` either "channels_first" for `[batch, channels, height,
      width]` or "channels_last" for `[batch, height, width, channels]`.

  Returns:
    A padded `Tensor` with the same `data_format`, unchanged when
    `kernel_size == 1` and padded when `kernel_size > 1`.
  """
  total = kernel_size - 1
  before = total // 2
  # Odd totals put the extra pad element on the trailing side.
  spatial = [before, total - before]
  if data_format == 'channels_first':
    paddings = [[0, 0], [0, 0], spatial, spatial]
  else:
    paddings = [[0, 0], spatial, spatial, [0, 0]]
  return tf.pad(inputs, paddings)
class RandomSparseInitializer(init_ops.Initializer):
  """Initializer that zeroes out a random fraction of dense init values."""

  def __init__(self, sparsity, seed=None, dtype=tf.float32):
    if sparsity < 0. or sparsity > 1.:
      raise ValueError('sparsity must be in the range [0., 1.].')
    self.kernel_initializer = tf.variance_scaling_initializer(seed=seed,
                                                              dtype=dtype)
    self.seed = seed
    self.dtype = dtype
    self.sparsity = float(sparsity)

  def __call__(self, *args, **kwargs):
    # Draw dense values, then independently zero each position with
    # probability `sparsity`.
    dense_values = self.kernel_initializer(*args, **kwargs)
    draws = tf.random_uniform(tf.shape(dense_values))
    cutoff = tf.constant(self.sparsity)
    return tf.where(draws < cutoff, tf.zeros_like(draws), dense_values)

  def get_config(self):
    return {
        'seed': self.seed,
        'dtype': self.dtype.name,
        'sparsity': self.sparsity
    }
class SparseConvVarianceScalingInitializer(init_ops.Initializer):
  """Variance-scaling initializer for a convolution that is already sparse.

  Scales the He-style fan-in by the expected fraction of non-zero weights so
  that a sparse layer keeps its activation variance comparable to a dense one.
  """

  def __init__(self, sparsity, seed=None, dtype=tf.float32):
    """Creates the initializer.

    Args:
      sparsity: Float in [0., 1.) fraction of weights expected to be zero.
      seed: Optional int random seed.
      dtype: Dtype used when `__call__` is invoked without an explicit dtype.

    Raises:
      ValueError: If `sparsity` is outside [0., 1.).
    """
    if sparsity < 0. or sparsity >= 1.:
      raise ValueError('sparsity must be in the range [0., 1.).')
    self.sparsity = sparsity
    self.seed = seed
    # Bug fix: `dtype` was previously discarded, so `__call__` (when invoked
    # with dtype=None) and `get_config` raised AttributeError on self.dtype.
    self.dtype = dtype

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns a random-normal tensor scaled for the expected sparsity."""
    if partition_info is not None:
      raise ValueError('partition_info not supported.')
    if dtype is None:
      dtype = self.dtype
    # Expected number of non-zero weights in the kernel.
    nnz = 1.
    for d in shape:
      nnz *= d
    nnz *= (1. - self.sparsity)
    input_channels = shape[-2]
    n = nnz / input_channels
    # NOTE(review): (2/n)**.5 is the He standard deviation, despite being
    # named `variance`; preserved as-is to keep behavior unchanged.
    variance = (2. / n)**.5
    return tf.random_normal(shape, 0, variance, dtype, seed=self.seed)

  def get_config(self):
    return {
        'seed': self.seed,
        'dtype': self.dtype.name,
    }
class SparseFCVarianceScalingInitializer(init_ops.Initializer):
  """Glorot-style initializer for a fully-connected layer that is sparse.

  Scales the uniform Glorot limit by the expected fraction of non-zero
  weights so a sparse layer keeps activation variance comparable to dense.
  """

  def __init__(self, sparsity, seed=None, dtype=tf.float32):
    """Creates the initializer.

    Args:
      sparsity: Float in [0., 1.) fraction of weights expected to be zero.
      seed: Optional int random seed.
      dtype: Dtype used when `__call__` is invoked without an explicit dtype.

    Raises:
      ValueError: If `sparsity` is outside [0., 1.).
    """
    if sparsity < 0. or sparsity >= 1.:
      raise ValueError('sparsity must be in the range [0., 1.).')
    self.sparsity = sparsity
    self.seed = seed
    # Bug fix: `dtype` was previously discarded, so `__call__` (when invoked
    # with dtype=None) and `get_config` raised AttributeError on self.dtype.
    self.dtype = dtype

  def __call__(self, shape, dtype=None, partition_info=None):
    """Returns a random-uniform tensor with a sparsity-adjusted limit."""
    if partition_info is not None:
      raise ValueError('partition_info not supported.')
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 2:
      raise ValueError('Weights must be 2-dimensional.')
    fan_in = shape[0]
    fan_out = shape[1]
    # Expected number of non-zero weights in the matrix.
    nnz = 1.
    for d in shape:
      nnz *= d
    nnz *= (1. - self.sparsity)
    # Glorot-uniform limit computed from the sparse effective fan sizes.
    limit = math.sqrt(6. / (nnz / fan_out + nnz / fan_in))
    return tf.random_uniform(shape, -limit, limit, dtype, seed=self.seed)

  def get_config(self):
    return {
        'seed': self.seed,
        'dtype': self.dtype.name,
    }
def _pick_initializer(kernel_initializer, init_method, pruning_method,
end_sparsity):
"""Updates the initializer selected, if necessary."""
if init_method == 'sparse':
if pruning_method != 'threshold':
raise ValueError(
'Unsupported combination of flags, pruning_method must be threshold'
' if init_method is `sparse`.')
else:
kernel_initializer = SparseFCVarianceScalingInitializer(end_sparsity)
elif init_method == 'random_zeros':
if pruning_method != 'baseline':
raise ValueError(
'Unsupported combination of flags, pruning_method must be '
'baseline if init_method is `random_zeros`.')
else:
kernel_initializer = RandomSparseInitializer(end_sparsity)
return kernel_initializer
def conv2d_fixed_padding(inputs,
                         filters,
                         kernel_size,
                         strides,
                         pruning_method='baseline',
                         init_method='baseline',
                         data_format='channels_first',
                         end_sparsity=0.,
                         weight_decay=0.,
                         init_scale=1.0,
                         name=None):
  """Strided 2-D convolution with explicit padding and optional pruning.

  Padding depends only on `kernel_size`, not on the spatial dimensions of
  `inputs` (as opposed to relying on `tf.layers.conv2d` padding alone).

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels,
      height, width].
    filters: Int number of output filters.
    kernel_size: Int size of the convolution kernel.
    strides: Int stride; when > 1 the input is downsampled.
    pruning_method: String selecting the pruning technique used to decide
      which weights to remove.
    init_method: One of ('baseline', 'sparse', 'random_zeros'). 'baseline'
      uses standard initialization; 'sparse' accounts for the layer's
      expected end sparsity (threshold pruning only); 'random_zeros' zeroes
      random weights per `end_sparsity` (baseline only).
    data_format: "channels_first" ([batch, channels, height, width]) or
      "channels_last" ([batch, height, width, channels]).
    end_sparsity: Desired sparsity at the end of training; needed to
      initialize an already-sparse network.
    weight_decay: Float weight for the l2 regularization loss.
    init_scale: Float, passed to the variance-scaling initializer.
    name: Optional string name for the layer.

  Returns:
    The output activation tensor of size [batch, filters, height_out,
    width_out].

  Raises:
    ValueError: If the data_format provided is not a valid string.
  """
  if strides > 1:
    # Pad explicitly so a 'VALID' conv reproduces 'SAME'-style geometry.
    inputs = fixed_padding(inputs, kernel_size, data_format=data_format)
    conv_padding = 'VALID'
  else:
    conv_padding = 'SAME'
  initializer = _pick_initializer(
      tf.variance_scaling_initializer(scale=init_scale), init_method,
      pruning_method, end_sparsity)
  regularizer = contrib_layers.l2_regularizer(weight_decay)
  return sparse_conv2d(
      x=inputs,
      units=filters,
      activation=None,
      kernel_size=[kernel_size, kernel_size],
      use_bias=False,
      kernel_initializer=initializer,
      kernel_regularizer=regularizer,
      bias_initializer=None,
      biases_regularizer=None,
      sparsity_technique=pruning_method,
      normalizer_fn=None,
      strides=[strides, strides],
      padding=conv_padding,
      data_format=data_format,
      name=name)
def residual_block_(inputs,
                    filters,
                    is_training,
                    strides,
                    use_projection=False,
                    pruning_method='baseline',
                    init_method='baseline',
                    data_format='channels_first',
                    end_sparsity=0.,
                    weight_decay=0.,
                    name=''):
  """Standard building block for residual networks with BN after convolutions.

  Two 3x3 convolutions plus an (optionally projected) shortcut, with the
  ReLU applied after adding the shortcut.

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,
      width].
    filters: Int specifying number of filters for the first two convolutions.
    is_training: Boolean specifying whether the model is training.
    strides: Int specifying the stride. If stride >1, the input is downsampled.
    use_projection: Boolean for whether the layer should use a projection
      shortcut. Often, use_projection=True for the first block of a block
      group.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard
      initialization or initialization that takes into account the existing
      sparsity of the layer. 'sparse' only makes sense when combined with
      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero
      using the end_sparsity parameter and is used with the 'baseline' method.
    data_format: String that specifies either "channels_first" for [batch,
      channels, height, width] or "channels_last" for [batch, height, width,
      channels].
    end_sparsity: Desired sparsity at the end of training. Necessary to
      initialize an already sparse network.
    weight_decay: Weight for the l2 regularization loss.
    name: String that specifies name for model layer.

  Returns:
    The output activation tensor.
  """
  shortcut = inputs
  if use_projection:
    # Projection shortcut in first layer to match filters and strides
    end_point = 'residual_projection_%s' % name
    shortcut = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters,
        kernel_size=1,
        strides=strides,
        pruning_method=pruning_method,
        init_method=init_method,
        data_format=data_format,
        end_sparsity=end_sparsity,
        weight_decay=weight_decay,
        name=end_point)
    shortcut = batch_norm_relu(
        shortcut, is_training, relu=False, data_format=data_format)
  # First 3x3 conv carries the block's stride (spatial downsampling).
  end_point = 'residual_1_%s' % name
  inputs = conv2d_fixed_padding(
      inputs=inputs,
      filters=filters,
      kernel_size=3,
      strides=strides,
      pruning_method=pruning_method,
      init_method=init_method,
      data_format=data_format,
      end_sparsity=end_sparsity,
      weight_decay=weight_decay,
      name=end_point)
  inputs = batch_norm_relu(
      inputs, is_training, data_format=data_format)
  # Second 3x3 conv is always stride 1.
  end_point = 'residual_2_%s' % name
  inputs = conv2d_fixed_padding(
      inputs=inputs,
      filters=filters,
      kernel_size=3,
      strides=1,
      pruning_method=pruning_method,
      init_method=init_method,
      data_format=data_format,
      end_sparsity=end_sparsity,
      weight_decay=weight_decay,
      name=end_point)
  # Zero-init gamma makes the residual branch start at zero, so the block
  # initially behaves as an identity mapping.
  inputs = batch_norm_relu(
      inputs, is_training, relu=False, init_zero=True, data_format=data_format)
  return tf.nn.relu(inputs + shortcut)
def bottleneck_block_(inputs,
                      filters,
                      is_training,
                      strides,
                      use_projection=False,
                      pruning_method='baseline',
                      init_method='baseline',
                      data_format='channels_first',
                      end_sparsity=0.,
                      weight_decay=0.,
                      name=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  1x1 reduce -> 3x3 -> 1x1 expand (to 4x `filters`), plus an (optionally
  projected) shortcut, with the ReLU applied after the addition.

  Args:
    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,
      width].
    filters: Int specifying number of filters for the first two convolutions;
      the final convolution emits 4 * filters.
    is_training: Boolean specifying whether the model is training.
    strides: Int specifying the stride. If stride >1, the input is downsampled.
    use_projection: Boolean for whether the layer should use a projection
      shortcut. Often, use_projection=True for the first block of a block
      group.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard
      initialization or initialization that takes into account the existing
      sparsity of the layer. 'sparse' only makes sense when combined with
      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero
      using the end_sparsity parameter and is used with the 'baseline' method.
    data_format: String that specifies either "channels_first" for [batch,
      channels, height, width] or "channels_last" for [batch, height, width,
      channels].
    end_sparsity: Desired sparsity at the end of training. Necessary to
      initialize an already sparse network.
    weight_decay: Weight for the l2 regularization loss.
    name: String that specifies name for model layer.

  Returns:
    The output activation tensor.
  """
  shortcut = inputs
  if use_projection:
    # Projection shortcut only in first block within a group. Bottleneck blocks
    # end with 4 times the number of filters.
    filters_out = 4 * filters
    end_point = 'bottleneck_projection_%s' % name
    shortcut = conv2d_fixed_padding(
        inputs=inputs,
        filters=filters_out,
        kernel_size=1,
        strides=strides,
        pruning_method=pruning_method,
        init_method=init_method,
        data_format=data_format,
        end_sparsity=end_sparsity,
        weight_decay=weight_decay,
        name=end_point)
    shortcut = batch_norm_relu(
        shortcut, is_training, relu=False, data_format=data_format)
  # 1x1 reduction conv.
  end_point = 'bottleneck_1_%s' % name
  inputs = conv2d_fixed_padding(
      inputs=inputs,
      filters=filters,
      kernel_size=1,
      strides=1,
      pruning_method=pruning_method,
      init_method=init_method,
      data_format=data_format,
      end_sparsity=end_sparsity,
      weight_decay=weight_decay,
      name=end_point)
  inputs = batch_norm_relu(
      inputs, is_training, data_format=data_format)
  # 3x3 conv carries the block's stride (spatial downsampling).
  end_point = 'bottleneck_2_%s' % name
  inputs = conv2d_fixed_padding(
      inputs=inputs,
      filters=filters,
      kernel_size=3,
      strides=strides,
      pruning_method=pruning_method,
      init_method=init_method,
      data_format=data_format,
      end_sparsity=end_sparsity,
      weight_decay=weight_decay,
      name=end_point)
  inputs = batch_norm_relu(
      inputs, is_training, data_format=data_format)
  # 1x1 expansion conv back up to 4x filters.
  end_point = 'bottleneck_3_%s' % name
  inputs = conv2d_fixed_padding(
      inputs=inputs,
      filters=4 * filters,
      kernel_size=1,
      strides=1,
      pruning_method=pruning_method,
      init_method=init_method,
      data_format=data_format,
      end_sparsity=end_sparsity,
      weight_decay=weight_decay,
      name=end_point)
  # Zero-init gamma makes the residual branch start at zero, so the block
  # initially behaves as an identity mapping.
  inputs = batch_norm_relu(
      inputs, is_training, relu=False, init_zero=True, data_format=data_format)
  return tf.nn.relu(inputs + shortcut)
def block_group(inputs,
                filters,
                block_fn,
                blocks,
                strides,
                is_training,
                name,
                pruning_method='baseline',
                init_method='baseline',
                data_format='channels_first',
                end_sparsity=0.,
                weight_decay=0.):
  """Creates one group of blocks for the ResNet model.

  Args:
    inputs: `Tensor` of size `[batch, channels, height, width]`.
    filters: `int` number of filters for the first convolution of the layer.
    block_fn: `function` for the block to use within the model.
    blocks: `int` number of blocks contained in the layer.
    strides: `int` stride to use for the first convolution of the layer. If
      greater than 1, this layer will downsample the input.
    is_training: `bool` for whether the model is training.
    name: String specifying the Tensor output of the block layer.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard
      initialization or initialization that takes into account the existing
      sparsity of the layer. 'sparse' only makes sense when combined with
      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero
      using the end_sparsity parameter and is used with the 'baseline' method.
    data_format: `str` either "channels_first" for `[batch, channels, height,
      width]` or "channels_last" for `[batch, height, width, channels]`.
    end_sparsity: Desired sparsity at the end of training. Necessary to
      initialize an already sparse network.
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    The output `Tensor` of the block layer.
  """
  with tf.name_scope(name):
    end_point = 'block_group_projection_%s' % name
    # Only the first block per block_group uses projection shortcut and strides.
    inputs = block_fn(
        inputs,
        filters,
        is_training,
        strides,
        use_projection=True,
        pruning_method=pruning_method,
        init_method=init_method,
        data_format=data_format,
        end_sparsity=end_sparsity,
        weight_decay=weight_decay,
        name=end_point)
    # Remaining blocks keep the resolution (stride 1) and reuse the running
    # channel count, so no projection shortcut is needed.
    for n in range(1, blocks):
      with tf.name_scope('block_group_%d' % n):
        end_point = '%s_%d_1' % (name, n)
        inputs = block_fn(
            inputs,
            filters,
            is_training,
            1,
            pruning_method=pruning_method,
            init_method=init_method,
            data_format=data_format,
            end_sparsity=end_sparsity,
            weight_decay=weight_decay,
            name=end_point)
  return tf.identity(inputs, name)
def resnet_v1_generator(block_fn,
                        num_blocks,
                        num_classes,
                        pruning_method='baseline',
                        init_method='baseline',
                        width=1.,
                        prune_first_layer=True,
                        prune_last_layer=True,
                        data_format='channels_first',
                        end_sparsity=0.,
                        weight_decay=0.,
                        name=None):
  """Generator for ResNet v1 models.

  Args:
    block_fn: String that defines whether to use a `residual_block` or
      `bottleneck_block`.
    num_blocks: list of Ints that denotes number of blocks to include in each
      block group. Each group consists of blocks that take inputs of the same
      resolution.
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard
      initialization or initialization that takes into account the existing
      sparsity of the layer. 'sparse' only makes sense when combined with
      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero
      using the end_sparsity parameter and is used with the 'baseline' method.
    width: Float that scales the number of filters in each layer.
    prune_first_layer: Whether or not to prune the first layer.
    prune_last_layer: Whether or not to prune the last layer.
    data_format: String either "channels_first" for `[batch, channels, height,
      width]` or "channels_last" for `[batch, height, width, channels]`.
    end_sparsity: Desired sparsity at the end of training. Necessary to
      initialize an already sparse network.
    weight_decay: Weight for the l2 regularization loss.
    name: String that specifies name for model layer.

  Returns:
    Model `function` that takes in `inputs` and `is_training` and returns the
    output `Tensor` of the ResNet model.
  """

  def model(inputs, is_training):
    """Creation of the model graph."""
    with tf.variable_scope(name, 'resnet_model'):
      # Stem: 7x7/2 conv. When prune_first_layer is False this layer is kept
      # dense ('baseline') regardless of the global pruning settings.
      inputs = conv2d_fixed_padding(
          inputs=inputs,
          filters=int(64 * width),
          kernel_size=7,
          strides=2,
          pruning_method=pruning_method if prune_first_layer else 'baseline',
          init_method=init_method if prune_first_layer else 'baseline',
          data_format=data_format,
          end_sparsity=end_sparsity,
          weight_decay=weight_decay,
          name='initial_conv')
      inputs = tf.identity(inputs, 'initial_conv')
      inputs = batch_norm_relu(
          inputs, is_training, data_format=data_format)
      inputs = tf.layers.max_pooling2d(
          inputs=inputs,
          pool_size=3,
          strides=2,
          padding='SAME',
          data_format=data_format,
          name='initial_max_pool')
      inputs = tf.identity(inputs, 'initial_max_pool')
      # Four stages; filters double and resolution halves at each stage after
      # the first.
      inputs = block_group(
          inputs=inputs,
          filters=int(64 * width),
          block_fn=block_fn,
          blocks=num_blocks[0],
          strides=1,
          is_training=is_training,
          name='block_group1',
          pruning_method=pruning_method,
          init_method=init_method,
          data_format=data_format,
          end_sparsity=end_sparsity,
          weight_decay=weight_decay)
      inputs = block_group(
          inputs=inputs,
          filters=int(128 * width),
          block_fn=block_fn,
          blocks=num_blocks[1],
          strides=2,
          is_training=is_training,
          name='block_group2',
          pruning_method=pruning_method,
          init_method=init_method,
          data_format=data_format,
          end_sparsity=end_sparsity,
          weight_decay=weight_decay)
      inputs = block_group(
          inputs=inputs,
          filters=int(256 * width),
          block_fn=block_fn,
          blocks=num_blocks[2],
          strides=2,
          is_training=is_training,
          name='block_group3',
          pruning_method=pruning_method,
          init_method=init_method,
          data_format=data_format,
          end_sparsity=end_sparsity,
          weight_decay=weight_decay)
      inputs = block_group(
          inputs=inputs,
          filters=int(512 * width),
          block_fn=block_fn,
          blocks=num_blocks[3],
          strides=2,
          is_training=is_training,
          name='block_group4',
          pruning_method=pruning_method,
          init_method=init_method,
          data_format=data_format,
          end_sparsity=end_sparsity,
          weight_decay=weight_decay)
      # NOTE(review): pool_size uses dims [1] and [2], which are the spatial
      # dims only under channels_last; under the default channels_first these
      # would be (channels, height) — confirm intended data_format at callers.
      pool_size = (inputs.shape[1], inputs.shape[2])
      inputs = tf.layers.average_pooling2d(
          inputs=inputs,
          pool_size=pool_size,
          strides=1,
          padding='VALID',
          data_format=data_format,
          name='final_avg_pool')
      inputs = tf.identity(inputs, 'final_avg_pool')
      # Bottleneck blocks expand the final stage to 4x filters.
      multiplier = 4 if block_fn is bottleneck_block_ else 1
      fc_units = multiplier * int(512 * width)
      inputs = tf.reshape(inputs, [-1, fc_units])
      kernel_initializer = tf.random_normal_initializer(stddev=.01)
      # If init_method==sparse and not pruning, skip.
      if init_method != 'sparse' or prune_last_layer:
        kernel_initializer = _pick_initializer(kernel_initializer, init_method,
                                               pruning_method, end_sparsity)
      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)
      inputs = sparse_fully_connected(
          x=inputs,
          units=num_classes,
          sparsity_technique=pruning_method if prune_last_layer else 'baseline',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          name='final_dense')
      inputs = tf.identity(inputs, 'final_dense')
    return inputs

  model.default_image_size = 224
  return model
def resnet_v1_(resnet_depth,
               num_classes,
               pruning_method='baseline',
               init_method='baseline',
               width=1.,
               prune_first_layer=True,
               prune_last_layer=True,
               data_format='channels_first',
               end_sparsity=0.,
               weight_decay=0.,
               name=None):
  """Returns the ResNet model for a given size and number of output classes.

  Args:
    resnet_depth: Int depth of the architecture (18, 34, 50, 101, 152 or 200).
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard
      initialization or initialization that takes into account the existing
      sparsity of the layer. 'sparse' only makes sense when combined with
      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero
      using the end_sparsity parameter and is used with the 'baseline' method.
    width: Float multiplier of the number of filters in each layer.
    prune_first_layer: Whether or not to prune the first layer.
    prune_last_layer: Whether or not to prune the last layer.
    data_format: String specifying either "channels_first" for `[batch,
      channels, height, width]` or "channels_last" for `[batch, height, width,
      channels]`.
    end_sparsity: Desired sparsity at the end of training. Necessary to
      initialize an already sparse network.
    weight_decay: Weight for the l2 regularization loss.
    name: String that specifies the prefix for the scope.

  Returns:
    Model `function` that takes in `inputs` and `is_training` and returns the
    output `Tensor` of the ResNet model.

  Raises:
    ValueError: If the resnet_depth int is not in the model_params dictionary.
  """
  # Block type and per-stage block counts for each supported depth.
  model_params = {
      18: {
          'block': residual_block_,
          'layers': [2, 2, 2, 2]
      },
      34: {
          'block': residual_block_,
          'layers': [3, 4, 6, 3]
      },
      50: {
          'block': bottleneck_block_,
          'layers': [3, 4, 6, 3]
      },
      101: {
          'block': bottleneck_block_,
          'layers': [3, 4, 23, 3]
      },
      152: {
          'block': bottleneck_block_,
          'layers': [3, 8, 36, 3]
      },
      200: {
          'block': bottleneck_block_,
          'layers': [3, 24, 36, 3]
      }
  }

  if resnet_depth not in model_params:
    # Fix: format the message into a single string instead of passing a
    # (message, value) tuple to ValueError.
    raise ValueError('Not a valid resnet_depth: {}'.format(resnet_depth))

  params = model_params[resnet_depth]
  return resnet_v1_generator(
      params['block'], params['layers'], num_classes, pruning_method,
      init_method, width, prune_first_layer, prune_last_layer, data_format,
      end_sparsity, weight_decay, name)
================================================
FILE: rigl/imagenet_resnet/train_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tests for the data_helper input pipeline and the training process.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import absl.testing.parameterized as parameterized
from rigl.imagenet_resnet.imagenet_train_eval import resnet_model_fn_w_pruning
from rigl.imagenet_resnet.imagenet_train_eval import set_lr_schedule
import tensorflow.compat.v1 as tf # tf
from official.resnet import imagenet_input
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
FLAGS = flags.FLAGS
class DataInputTest(tf.test.TestCase, parameterized.TestCase):
  """Smoke test: one training step per sparse training method on CPU."""

  def _retrieve_data(self, is_training, data_dir):
    # Builds the ImageNet input pipeline without input transposition or
    # bfloat16 so the graph can run outside a TPU.
    dataset = imagenet_input.ImageNetInput(
        is_training=is_training,
        data_dir=data_dir,
        transpose_input=False,
        num_parallel_calls=8,
        use_bfloat16=False)
    return dataset

  @parameterized.parameters('snip', 'set', 'rigl', 'scratch')
  def testTrainingPipeline(self, training_method):
    """Runs a single training step with the given sparse training method."""
    output_directory = '/tmp/'
    g = tf.Graph()
    with g.as_default():
      # data_dir=False: presumably makes ImageNetInput fall back to synthetic
      # data so no real dataset is required — TODO confirm against
      # official.resnet.imagenet_input.
      dataset = self._retrieve_data(is_training=False, data_dir=False)
      # The model_fn reads its configuration from module-level FLAGS, so the
      # test configures them here (shared global state across test cases).
      FLAGS.transpose_input = False
      FLAGS.use_tpu = False
      FLAGS.mode = 'train'
      FLAGS.mask_init_method = 'random'
      FLAGS.precision = 'float32'
      FLAGS.train_steps = 1
      FLAGS.train_batch_size = 1
      FLAGS.eval_batch_size = 1
      FLAGS.steps_per_eval = 1
      FLAGS.model_architecture = 'resnet'
      params = {}
      params['output_dir'] = output_directory
      params['training_method'] = training_method
      params['use_tpu'] = False
      set_lr_schedule()
      # TPUEstimator with use_tpu=False executes the TPU model_fn on CPU.
      run_config = tpu_config.RunConfig(
          master=None,
          model_dir=None,
          save_checkpoints_steps=1,
          tpu_config=tpu_config.TPUConfig(iterations_per_loop=1, num_shards=1))
      classifier = tpu_estimator.TPUEstimator(
          use_tpu=False,
          model_fn=resnet_model_fn_w_pruning,
          params=params,
          config=run_config,
          train_batch_size=1,
          eval_batch_size=1)
      classifier.train(input_fn=dataset.input_fn, max_steps=1)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: rigl/imagenet_resnet/utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helped functions to concatenate subset of noisy images to batch."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compat.v2 import summary
IMG_SUMMARY_PREFIX = '_img_'
def format_tensors(*dicts):
  """Format metrics to be callable as tf.summary scalars on tpu's.

  Scalars (rank 0) get a leading batch dimension so that the TPU host call
  can index them with [0]; image metrics (names starting with
  IMG_SUMMARY_PREFIX) are reshaped to a 2d image layout.

  Args:
    *dicts: A set of metric dictionaries, containing metric name + value
      tensor.

  Returns:
    A single formatted dictionary that holds all tensors.

  Raises:
    ValueError: if any non-image tensor is not a scalar or shape [1].
  """
  merged_summaries = {}
  for d in dicts:
    for metric_name, value in d.items():
      shape = value.shape.as_list()
      if metric_name.startswith(IMG_SUMMARY_PREFIX):
        # If image, shape it into 2d: (batch=1, rows, last dim, channels=1).
        merged_summaries[metric_name] = tf.reshape(value,
                                                   (1, -1, value.shape[-1], 1))
      elif not shape:
        # Rank-0 scalar: add the batch dimension expected by the host call.
        merged_summaries[metric_name] = tf.expand_dims(value, axis=0)
      elif shape == [1]:
        # Already in the expected [1] shape; keep as-is.
        merged_summaries[metric_name] = value
      else:
        # Fix: corrected typo 'reconciliable' -> 'reconcilable'.
        raise ValueError(
            'Metric {} has value {} that is not reconcilable'.format(
                metric_name, value))
  return merged_summaries
def host_call_fn(model_dir, **kwargs):
  """host_call function used for creating training summaries when using TPU.

  Args:
    model_dir: String indicating the output_dir to save summaries in.
    **kwargs: Set of metric names and tensor values for all desired summaries.
      Must contain a 'global_step' entry; names starting with
      IMG_SUMMARY_PREFIX are written as image summaries, all others as
      scalars.

  Returns:
    Summary op to be passed to the host_call arg of the estimator function.
  """
  # global_step arrives batched with shape [1] (see format_tensors); take the
  # scalar value.
  gs = kwargs.pop('global_step')[0]
  with summary.create_file_writer(model_dir).as_default():
    # Always record summaries.
    with summary.record_if(True):
      for name, tensor in kwargs.items():
        if name.startswith(IMG_SUMMARY_PREFIX):
          # NOTE(review): the v2 summary.image API takes `max_outputs` and a
          # `step` argument; `max_images` is the TF1 name and no step is
          # passed here — confirm this call works under the TF version in use.
          summary.image(name.replace(IMG_SUMMARY_PREFIX, ''), tensor,
                        max_images=1)
        else:
          summary.scalar(name, tensor[0], step=gs)
  # Following function is under tf:1x, so we use it.
  return tf.summary.all_v2_summary_ops()
def mask_summaries(masks, with_img=False):
  """Builds a metrics dict with the sparsity of each mask.

  Args:
    masks: iterable of mask tensors.
    with_img: if True, also include each mask tensor itself under an
      image-summary key (IMG_SUMMARY_PREFIX).

  Returns:
    Dict mapping summary names to tensors.
  """
  summaries = {}
  for m in masks:
    sparsity_key = 'pruning/{}/sparsity'.format(m.op.name)
    summaries[sparsity_key] = tf.nn.zero_fraction(m)
    if with_img:
      summaries[IMG_SUMMARY_PREFIX + 'mask/' + m.op.name] = m
  return summaries
def initialize_parameters_from_ckpt(ckpt_path, model_dir, param_suffixes):
  """Load parameters from an existing checkpoint.

  No-op if `model_dir` already contains a checkpoint (training has started).

  Args:
    ckpt_path: str, loads the mask variables from this checkpoint.
    model_dir: str, if checkpoint exists in this folder no-op.
    param_suffixes: list or str, suffix of parameters to be load from
      checkpoint.
  """
  already_has_ckpt = model_dir and tf.train.latest_checkpoint(
      model_dir) is not None
  if already_has_ckpt:
    # Fix: added the missing space between 'from' and 'previously' in the
    # implicitly-concatenated log message.
    tf.logging.info(
        'Training already started on this model, not loading masks from '
        'previously trained model')
    return

  # Fix: str.endswith accepts a str or a *tuple* of strs; a list (as the
  # docstring allows) would raise TypeError, so coerce it.
  if isinstance(param_suffixes, list):
    param_suffixes = tuple(param_suffixes)

  reader = tf.train.NewCheckpointReader(ckpt_path)
  param_names = reader.get_variable_to_shape_map().keys()
  param_names = [x for x in param_names if x.endswith(param_suffixes)]

  # Map checkpoint variable names to the current graph's variables.
  variable_map = {}
  for var in tf.global_variables():
    var_name = var.name.split(':')[0]
    if var_name in param_names:
      tf.logging.info('Loading parameter variable from checkpoint: %s',
                      var_name)
      variable_map[var_name] = var
    elif var_name.endswith(param_suffixes):
      tf.logging.info(
          'Cannot find parameter variable in checkpoint, skipping: %s',
          var_name)
  tf.train.init_from_checkpoint(ckpt_path, variable_map)
================================================
FILE: rigl/imagenet_resnet/vgg.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains model definitions for versions of the Oxford VGG network.
These model definitions were introduced in the following technical report:
Very Deep Convolutional Networks For Large-Scale Image Recognition
Karen Simonyan and Andrew Zisserman
arXiv technical report, 2015
PDF: http://arxiv.org/pdf/1409.1556.pdf
ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
CC-BY-4.0
More information can be obtained from the VGG website:
www.robots.ox.ac.uk/~vgg/research/very_deep/
Usage:
with arg_scope(vgg.vgg_arg_scope()):
outputs, end_points = vgg.vgg_net(inputs,scope='vgg_19')
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from rigl.imagenet_resnet import resnet_model
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers
# Number of convolutions in each of the five conv blocks, per VGG variant.
network_cfg = {
    'vgg_a': [1, 1, 2, 2, 2],
    'vgg_16': [2, 2, 3, 3, 3],
    'vgg_19': [2, 2, 4, 4, 4],
}
def vgg_net(inputs,
            num_classes=1000,
            spatial_squeeze=True,
            name='vgg_a',
            global_pool=True,
            pruning_method='baseline',
            init_method='baseline',
            data_format='channels_last',
            width=1.,
            prune_last_layer=True,
            end_sparsity=0.,
            weight_decay=0.):
  """Oxford Net VGG.

  Note: All the fully_connected layers have been transformed to conv2d layers.
  To use in classification mode, resize input to 224x224.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes. If 0 or None, the logits layer is
      omitted and the input features to the logits layer are returned instead.
    spatial_squeeze: whether or not should squeeze the spatial dimensions of the
      outputs. Useful to remove unnecessary dimensions for classification.
    name: Optional scope for the variables; also selects the VGG variant from
      `network_cfg`.
    global_pool: Optional boolean flag. If True, the input to the classification
      layer is avgpooled to size 1x1, for any input size. (This is not part
      of the original VGG architecture.)
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard
      initialization or initialization that takes into account the existing
      sparsity of the layer. 'sparse' only makes sense when combined with
      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero
      using the end_sparsity parameter and is used with the 'baseline' method.
    data_format: String specifying either "channels_first" for `[batch,
      channels, height, width]` or "channels_last" for `[batch, height, width,
      channels]`.
    width: Float multiplier of the number of filters in each layer.
    prune_last_layer: Whether or not to prune the last layer.
    end_sparsity: Desired sparsity at the end of training. Necessary to
      initialize an already sparse network.
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    net: the output of the logits layer (if num_classes is a non-zero integer),
      or the (optionally global-pooled) input to the logits layer (if
      num_classes is 0 or None). Note: unlike slim's vgg, no end_points dict
      is returned.
  """
  # Per-block convolution counts for the chosen VGG variant.
  net_cfg = network_cfg[name]
  # Shared sparse conv settings for every conv in the network.
  sparse_conv2d = functools.partial(
      resnet_model.conv2d_fixed_padding,
      pruning_method=pruning_method,
      init_method=init_method,
      data_format=data_format,
      init_scale=2.0,  # Heinit
      end_sparsity=end_sparsity,
      weight_decay=weight_decay)

  def new_sparse_conv2d(*args, **kwargs):
    # Adapter so `layers.repeat` (which passes `scope`/`activation_fn`) can
    # drive `conv2d_fixed_padding` (which expects `name` and applies no
    # activation itself).
    kwargs['name'] = kwargs['scope']
    del kwargs['scope']
    activation_fn = 'relu'
    if 'activation_fn' in kwargs:
      activation_fn = kwargs['activation_fn']
      del kwargs['activation_fn']
    out = sparse_conv2d(*args, **kwargs)
    if activation_fn == 'relu':
      out = tf.nn.relu(out)
    return out

  with tf.variable_scope(name, name, values=[inputs]):
    net = layers.repeat(
        inputs,
        net_cfg[0],
        new_sparse_conv2d,
        int(64 * width),
        3,
        strides=1,
        scope='conv1')
    net = layers.max_pool2d(net, [2, 2], scope='pool1')
    net = layers.repeat(
        net,
        net_cfg[1],
        new_sparse_conv2d,
        int(128 * width),
        3,
        strides=1,
        scope='conv2')
    net = layers.max_pool2d(net, [2, 2], scope='pool2')
    net = layers.repeat(
        net,
        net_cfg[2],
        new_sparse_conv2d,
        int(256 * width),
        3,
        strides=1,
        scope='conv3')
    net = layers.max_pool2d(net, [2, 2], scope='pool3')
    net = layers.repeat(
        net,
        net_cfg[3],
        new_sparse_conv2d,
        int(512 * width),
        3,
        strides=1,
        scope='conv4')
    net = layers.max_pool2d(net, [2, 2], scope='pool4')
    net = layers.repeat(
        net,
        net_cfg[4],
        new_sparse_conv2d,
        int(512 * width),
        3,
        strides=1,
        scope='conv5')
    # # Use conv2d instead of fully_connected layers.
    # net = new_sparse_conv2d(net, 512, [7, 7], strides=1, scope='fc6')
    # # net = layers.dropout(net, dropout_keep_prob, is_training=is_training,
    # #                      scope='dropout6')
    # net = new_sparse_conv2d(net, 512, [1, 1], strides=1, scope='fc7')
    if global_pool:
      # NOTE(review): reducing axes [1, 2] assumes channels_last layout —
      # confirm before calling with data_format='channels_first'.
      net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
    if num_classes:
      # net = layers.dropout(net, dropout_keep_prob, is_training=is_training,
      #                      scope='dropout7')
      if prune_last_layer:
        net = new_sparse_conv2d(
            net, num_classes, 1, activation_fn=None, strides=1, scope='fc8')
      else:
        net = layers.conv2d(
            net, num_classes, [1, 1], activation_fn=None, scope='fc8')
      if spatial_squeeze:
        net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
  return net
def vgg(vgg_type,
        num_classes,
        pruning_method='baseline',
        init_method='baseline',
        width=1.,
        prune_last_layer=True,
        data_format='channels_last',
        end_sparsity=0.,
        weight_decay=0.):
  """Returns a VGG model function for a given variant and class count.

  Args:
    vgg_type: String key selecting the VGG variant (e.g. 'vgg_a', 'vgg_16',
      'vgg_19').
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard
      initialization or initialization that takes into account the existing
      sparsity of the layer. 'sparse' only makes sense when combined with
      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero
      using the end_sparsity parameter and is used with the 'baseline' method.
    width: Float multiplier of the number of filters in each layer.
    prune_last_layer: Whether or not to prune the last layer.
    data_format: String specifying either "channels_first" for `[batch,
      channels, height, width]` or "channels_last" for `[batch, height, width,
      channels]`.
    end_sparsity: Desired sparsity at the end of training. Necessary to
      initialize an already sparse network.
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    A `model_fn(inputs, is_training)` callable that builds the VGG graph and
    returns its output tensor.
  """
  # Capture all the network settings once; the closure forwards them verbatim.
  net_kwargs = dict(
      name=vgg_type,
      pruning_method=pruning_method,
      init_method=init_method,
      data_format=data_format,
      width=width,
      prune_last_layer=prune_last_layer,
      end_sparsity=end_sparsity,
      weight_decay=weight_decay)

  def model_fn(inputs, is_training):
    # The VGG graph built here does not depend on the training flag.
    del is_training
    return vgg_net(inputs, num_classes, **net_kwargs)

  return model_fn
================================================
FILE: rigl/mnist/mnist_train_eval.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""A configurable, multi-layer fully connected network trained on MNIST.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
import numpy as np
from rigl import sparse_optimizers
from rigl import sparse_utils
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data
# Dataset location.
flags.DEFINE_string('mnist', '/tmp/data', 'Location of the MNIST ' 'dataset.')
## optimizer hyperparameters
flags.DEFINE_integer('batch_size', 100, 'The number of samples in each batch')
flags.DEFINE_float('learning_rate', .2, 'Initial learning rate.')
flags.DEFINE_float('momentum', .9, 'Momentum.')
flags.DEFINE_boolean('use_nesterov', True, 'Use nesterov momentum.')
flags.DEFINE_integer('num_epochs', 200, 'Number of epochs to run.')
flags.DEFINE_integer('lr_drop_epoch', 75, 'The epoch to start dropping lr.')
flags.DEFINE_string('optimizer', 'momentum',
                    'Optimizer to use. sgd, momentum or adam')
flags.DEFINE_float('l2_scale', 1e-4, 'l2 loss scale')
# Network architecture and sparse-training configuration.
flags.DEFINE_string('network_type', 'fc',
                    'Type of the network. See below for available options.')
flags.DEFINE_enum(
    'training_method', 'baseline',
    ('scratch', 'set', 'baseline', 'momentum', 'rigl', 'static', 'snip',
     'prune'),
    'Method used for training sparse network. `scratch` means initial mask is '
    'kept during training. `set` is for sparse evalutionary training and '
    '`baseline` is for dense baseline.')
# NOTE(review): the help string below is truncated ('how much of the ').
flags.DEFINE_float('drop_fraction', 0.3,
                   'When changing mask dynamically, this fraction decides how '
                   'much of the ')
flags.DEFINE_string('drop_fraction_anneal', 'cosine',
                    'If not empty the drop fraction is annealed during sparse'
                    ' training. One of the following: `constant`, `cosine` or '
                    '`exponential_(\\d*\\.?\\d*)$`. For example: '
                    '`exponential_3`, `exponential_.3`, `exponential_0.3`. '
                    'The number after `exponential` defines the exponent.')
flags.DEFINE_string('grow_init', 'zeros',
                    'Passed to the SparseInitializer, one of: zeros, '
                    'initial_value, random_normal, random_uniform.')
flags.DEFINE_float('s_momentum', 0.9,
                   'Momentum values for exponential moving average of '
                   'gradients. Used when training_method="momentum".')
flags.DEFINE_string(
    'input_mask_path', '',
    'If given, uses the first mask of the checkpoint to mask '
    'the input. If all the outgoing connections are masked '
    'in the mask, we mask that dimension of the input.')
flags.DEFINE_float('sparsity_scale', 0.9, 'Relative sparsity of second layer.')
flags.DEFINE_float('rigl_acc_scale', 0.,
                   'Used to scale initial accumulated gradients for new '
                   'connections.')
# Mask-update schedule (for dynamic sparse training methods).
flags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.')
flags.DEFINE_integer('maskupdate_end_step', 50000, 'Step to end mask updates.')
flags.DEFINE_integer('maskupdate_frequency', 100,
                     'Step interval between mask updates.')
flags.DEFINE_integer('mask_record_frequency', 0,
                     'Step interval between mask logging.')
flags.DEFINE_string(
    'mask_init_method',
    default='random',
    help='If not empty string and mask is not loaded from a checkpoint, '
    'indicates the method used for mask initialization. One of the following: '
    '`random`, `erdos_renyi`.')
# Magnitude-pruning schedule (training_method == 'prune').
flags.DEFINE_integer('prune_begin_step', 2000, 'step to begin pruning')
flags.DEFINE_integer('prune_end_step', 30000, 'step to end pruning')
flags.DEFINE_float('end_sparsity', .98, 'desired sparsity of final model.')
flags.DEFINE_integer('pruning_frequency', 500, 'how often to prune.')
flags.DEFINE_float('threshold_decay', 0, 'threshold_decay for pruning.')
# Checkpointing and reproducibility.
flags.DEFINE_string('save_path', '', 'Where to save the model.')
flags.DEFINE_boolean('save_model', True, 'Whether to save model or not.')
flags.DEFINE_integer('seed', default=0, help=('Sets the random seed.'))
FLAGS = flags.FLAGS
# momentum = 0.9
# lr = 0.2
# batch = 100
# decay = 1e-4
def mnist_network_fc(input_batch, reuse=False, model_pruning=False):
  """Define a basic FC network.

  Builds a 784-300-100-10 fully connected classifier, either with plain dense
  layers or with model-pruning masked layers, and returns its loss/accuracy.

  Args:
    input_batch: pair of (images, int labels) tensors.
    reuse: whether to reuse the variables of an already-built copy.
    model_pruning: if True, use masked_fully_connected layers.

  Returns:
    Tuple of (cross_entropy loss incl. regularization, accuracy).
  """
  images, labels = input_batch[0], input_batch[1]
  regularizer = contrib_layers.l2_regularizer(scale=FLAGS.l2_scale)
  # (units, activation, variable scope) for each layer; the final layer
  # produces raw logits, hence no activation.
  layer_specs = [(300, tf.nn.relu, 'layer1'),
                 (100, tf.nn.relu, 'layer2'),
                 (10, None, 'layer3')]
  net = images
  for units, activation, scope_name in layer_specs:
    if model_pruning:
      net = layers.masked_fully_connected(
          inputs=net,
          num_outputs=units,
          activation_fn=activation,
          weights_regularizer=regularizer,
          reuse=reuse,
          scope=scope_name)
    else:
      net = tf.layers.dense(
          inputs=net,
          units=units,
          activation=activation,
          kernel_regularizer=regularizer,
          reuse=reuse,
          name=scope_name)
  logits = net
  cross_entropy = tf.losses.sparse_softmax_cross_entropy(
      labels=labels, logits=logits)
  cross_entropy += tf.losses.get_regularization_loss()
  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(labels, predictions), tf.float32))
  return cross_entropy, accuracy
def get_compressed_fc(masks):
  """Given the masks of a sparse network returns the compact network.

  Removes input pixels with no outgoing connections and hidden neurons with
  no incoming or no outgoing connections, then reports the sparsity and layer
  sizes of the resulting compact network. `masks` is modified in place.

  Args:
    masks: list of 2d numpy arrays, one weight mask per layer ([in, out]).

  Returns:
    Tuple (sparsities, sizes): per-layer zero fractions of the compressed
    masks, and the layer widths [input, hidden..., output].
  """
  # Drop input pixels whose entire outgoing row is zero.
  live_pixels = np.sum(masks[0], axis=1) != 0
  masks[0] = masks[0][live_pixels]

  compressed_masks = []
  last = len(masks) - 1
  for i, w in enumerate(masks):
    # A neuron survives only if it has at least one incoming edge...
    keep = np.sum(w, axis=0) != 0
    if i != last:
      # ...and, for hidden layers, at least one outgoing edge as well.
      has_outgoing = np.sum(masks[i + 1], axis=1) != 0
      keep = np.logical_and(keep, has_outgoing)
    compressed_masks.append(w[:, keep])
    if i != last:
      # Drop the next layer's rows that came from removed neurons.
      masks[i + 1] = masks[i + 1][keep]

  sparsities = [np.sum(m == 0) / float(np.size(m)) for m in compressed_masks]
  sizes = [compressed_masks[0].shape[0]]
  sizes += [m.shape[1] for m in compressed_masks]
  return sparsities, sizes
def main(unused_args):
  """Builds, trains and evaluates a (sparse) MNIST classifier.

  Constructs the TF1 graph (input pipeline, network, optimizer and mask
  updates according to FLAGS.training_method), then runs the training loop
  in a Session, logging test metrics once per epoch to stdout, a results
  text file under FLAGS.save_path and TensorBoard summaries.
  """
  # Seed both TF and numpy so runs are reproducible for a given FLAGS.seed.
  tf.set_random_seed(FLAGS.seed)
  tf.get_variable_scope().set_use_resource(True)
  np.random.seed(FLAGS.seed)
  # Load the MNIST data and set up an iterator.
  mnist_data = input_data.read_data_sets(
      FLAGS.mnist, one_hot=False, validation_size=0)
  train_images = mnist_data.train.images
  test_images = mnist_data.test.images
  if FLAGS.input_mask_path:
    # Drop input pixels that have no outgoing connection in the first-layer
    # mask loaded from the given checkpoint.
    reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
    input_mask = reader.get_tensor('layer1/mask')
    indices = np.sum(input_mask, axis=1) != 0
    train_images = train_images[:, indices]
    test_images = test_images[:, indices]
  dataset = tf.data.Dataset.from_tensor_slices(
      (train_images, mnist_data.train.labels.astype(np.int32)))
  num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
  dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
  batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
  iterator = batched_dataset.make_one_shot_iterator()

  # The whole test set is evaluated as one single batch.
  test_dataset = tf.data.Dataset.from_tensor_slices(
      (test_images, mnist_data.test.labels.astype(np.int32)))
  num_test_images = mnist_data.test.images.shape[0]
  test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
  test_iterator = test_dataset.make_one_shot_iterator()

  # Set up loss function.
  use_model_pruning = FLAGS.training_method != 'baseline'

  if FLAGS.network_type == 'fc':
    cross_entropy_train, _ = mnist_network_fc(
        iterator.get_next(), model_pruning=use_model_pruning)
    cross_entropy_test, accuracy_test = mnist_network_fc(
        test_iterator.get_next(), reuse=True, model_pruning=use_model_pruning)
  else:
    # NOTE(review): `FLAGS.network` is likely a typo for `FLAGS.network_type`
    # (no `network` flag is visible in this file) -- verify before relying
    # on this error path.
    raise RuntimeError(FLAGS.network + ' is an unknown network type.')

  # Remove extra added ones. Current implementation adds the variables twice
  # to the collection. Improve this hacky thing.
  # TODO test the following with the convnet or any other network.
  if use_model_pruning:
    for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
      # del tf.get_collection_ref(k)[2]
      # del tf.get_collection_ref(k)[2]
      collection = tf.get_collection_ref(k)
      # Keep only the first half of each collection; the second half is
      # presumably added while building the test graph with reuse=True --
      # TODO confirm.
      del collection[len(collection)//2:]
      print(tf.get_collection_ref(k))

  # Set up optimizer and update ops.
  global_step = tf.train.get_or_create_global_step()
  batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

  if FLAGS.optimizer != 'adam':
    # Step schedule: the learning rate is divided by 3 at each boundary.
    if not use_model_pruning:
      boundaries = [int(round(s * batch_per_epoch)) for s in [60, 70, 80]]
    else:
      boundaries = [int(round(s * batch_per_epoch)) for s
                    in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]]
    learning_rate = tf.train.piecewise_constant(
        global_step, boundaries,
        values=[FLAGS.learning_rate / (3. ** i)
                for i in range(len(boundaries) + 1)])
  else:
    # Adam runs with a constant learning rate.
    learning_rate = FLAGS.learning_rate

  if FLAGS.optimizer == 'adam':
    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
  elif FLAGS.optimizer == 'momentum':
    opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum,
                                     use_nesterov=FLAGS.use_nesterov)
  elif FLAGS.optimizer == 'sgd':
    opt = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    raise RuntimeError(FLAGS.optimizer + ' is unknown optimizer type')

  # Per-layer sparsity overrides: layer2 is scaled by sparsity_scale and
  # layer3 (the logits layer) is kept dense.
  custom_sparsities = {
      'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
      'layer3': FLAGS.end_sparsity * 0
  }

  if FLAGS.training_method == 'set':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseSETOptimizer(
        opt, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif FLAGS.training_method == 'static':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseStaticOptimizer(
        opt, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif FLAGS.training_method == 'momentum':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseMomentumOptimizer(
        opt, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
  elif FLAGS.training_method == 'rigl':
    # We override the train op to also update the mask.
    opt = sparse_optimizers.SparseRigLOptimizer(
        opt, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
  elif FLAGS.training_method == 'snip':
    opt = sparse_optimizers.SparseSnipOptimizer(
        opt,
        mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity,
        custom_sparsity_map=custom_sparsities,
        use_tpu=False)
  elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)

  train_op = opt.minimize(cross_entropy_train, global_step=global_step)

  if FLAGS.training_method == 'prune':
    # Gradual magnitude pruning: build the hparams string for the
    # model_pruning library and chain its conditional mask-update op after
    # every training step.
    hparams_string = ('begin_pruning_step={0},sparsity_function_begin_step={0},'
                      'end_pruning_step={1},sparsity_function_end_step={1},'
                      'target_sparsity={2},pruning_frequency={3},'
                      'threshold_decay={4}'.format(
                          FLAGS.prune_begin_step, FLAGS.prune_end_step,
                          FLAGS.end_sparsity, FLAGS.pruning_frequency,
                          FLAGS.threshold_decay))
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
    pruning_hparams.set_hparam('weight_sparsity_map',
                               ['{0}:{1}'.format(k, v) for k, v
                                in custom_sparsities.items()])
    print(pruning_hparams)
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()
  weight_sparsity_levels = pruning.get_weight_sparsity()
  global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
  tf.summary.scalar('test_accuracy', accuracy_test)
  tf.summary.scalar('global_sparsity', global_sparsity)
  for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
    tf.summary.scalar('sparsity/%s' % k.name, v)
  if FLAGS.training_method in ('prune', 'snip', 'baseline'):
    # These methods start from a dense network; masks are left untouched.
    mask_init_op = tf.no_op()
    tf.logging.info('No mask is set, starting dense.')
  else:
    # Initialize the masks at the requested sparsity before training starts.
    all_masks = pruning.get_masks()
    mask_init_op = sparse_utils.get_mask_init_fn(
        all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity,
        custom_sparsities)

  if FLAGS.save_model:
    saver = tf.train.Saver()
  init_op = tf.global_variables_initializer()
  # The results file name encodes the hyper parameters of this run.
  hyper_params_string = '_'.join([FLAGS.network_type, str(FLAGS.batch_size),
                                  str(FLAGS.learning_rate),
                                  str(FLAGS.momentum),
                                  FLAGS.optimizer,
                                  str(FLAGS.l2_scale),
                                  FLAGS.training_method,
                                  str(FLAGS.prune_begin_step),
                                  str(FLAGS.prune_end_step),
                                  str(FLAGS.end_sparsity),
                                  str(FLAGS.pruning_frequency),
                                  str(FLAGS.seed)])
  tf.io.gfile.makedirs(FLAGS.save_path)
  filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
  merged_summary_op = tf.summary.merge_all()

  # Run session.
  if not use_model_pruning:
    # Dense baseline training loop.
    with tf.Session() as sess:
      summary_writer = tf.summary.FileWriter(FLAGS.save_path,
                                             graph=tf.get_default_graph())
      print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
      sess.run([init_op])
      tic = time.time()
      with tf.io.gfile.GFile(filename, 'w') as outputfile:
        for i in range(FLAGS.num_epochs * num_batches):
          sess.run([train_op])

          # True on the last batch of every epoch.
          if (i % num_batches) == (-1 % num_batches):
            epoch_time = time.time() - tic
            loss, accuracy, summary = sess.run([cross_entropy_test,
                                                accuracy_test,
                                                merged_summary_op])
            # Write logs at every test iteration.
            summary_writer.add_summary(summary, i)
            log_str = '%d, %.4f, %.4f, %.4f' % (
                i // num_batches, epoch_time, loss, accuracy)
            print(log_str)
            print(log_str, file=outputfile)
            tic = time.time()
      if FLAGS.save_model:
        saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
  else:
    # Sparse-training loop: additionally initializes and records masks and
    # logs per-layer / global sparsity alongside the test metrics.
    with tf.Session() as sess:
      summary_writer = tf.summary.FileWriter(FLAGS.save_path,
                                             graph=tf.get_default_graph())
      log_str = ','.join([
          'Epoch', 'Iteration', 'Test loss', 'Test accuracy', 'G_Sparsity',
          'Sparsity Layer 0', 'Sparsity Layer 1'
      ])
      sess.run(init_op)
      sess.run(mask_init_op)
      tic = time.time()
      mask_records = {}
      with tf.io.gfile.GFile(filename, 'w') as outputfile:
        print(log_str)
        print(log_str, file=outputfile)
        for i in range(FLAGS.num_epochs * num_batches):
          # Periodically snapshot the masks so their evolution can be
          # visualized later (see visualize_mask_records.py).
          if (FLAGS.mask_record_frequency > 0 and
              i % FLAGS.mask_record_frequency == 0):
            mask_vals = sess.run(pruning.get_masks())
            # Cast into bool to save space.
            mask_records[i] = [a.astype(bool) for a in mask_vals]
          sess.run([train_op])
          weight_sparsity, global_sparsity_val = sess.run(
              [weight_sparsity_levels, global_sparsity])
          # True on the last batch of every epoch.
          if (i % num_batches) == (-1 % num_batches):
            epoch_time = time.time() - tic
            loss, accuracy, summary = sess.run([cross_entropy_test,
                                                accuracy_test,
                                                merged_summary_op])
            # Write logs at every test iteration.
            summary_writer.add_summary(summary, i)
            log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                i // num_batches, i, loss, accuracy, global_sparsity_val,
                weight_sparsity[0], weight_sparsity[1])
            print(log_str)
            print(log_str, file=outputfile)
            mask_vals = sess.run(pruning.get_masks())
            if FLAGS.network_type == 'fc':
              sparsities, sizes = get_compressed_fc(mask_vals)
              print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities,
                                                              sizes))
              print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities,
                                                              sizes),
                    file=outputfile)
            tic = time.time()
      if FLAGS.save_model:
        saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
      if mask_records:
        np.save(os.path.join(FLAGS.save_path, 'mask_records'), mask_records)
if __name__ == '__main__':
  # tf.app.run parses the absl flags and then calls main().
  tf.app.run()
================================================
FILE: rigl/mnist/visualize_mask_records.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Visualizes the dumped masks using matplotlib.
We count the number of outgoing edges from the input dimensions. For the first
layer input dimensions correspond to the input pixels and we can visualize it
nicely. You can control which layer is visualized by changing `layer_id` and
`new_shape`. Default is the first layer and we visualize the number of outgoing
connections from individual pixels.
python visualize_mask_records.py --records_path=/tmp/mnist/mask_records.npy
To save the results as gif:
python visualize_mask_records.py --records_path=/path/to/mask_records.npy \
--save_path=/path/to/mask.gif
Modified from:
https://eli.thegreenplace.net/2016/drawing-animated-gifs-with-matplotlib/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v1 as tf
# Command-line flags for the mask-visualization script. Only the help
# strings are corrected here; flag names, types and defaults are unchanged.
flags.DEFINE_string('records_path', '/tmp/mnist/mask_records.npy',
                    'Path to load masks records.')
flags.DEFINE_string('save_path', '', 'Path to save the animation.')
flags.DEFINE_list('new_shape', '28,28', 'Shape for reshaping the units.')
flags.DEFINE_integer('interval', 100, 'Milliseconds between plot updates.')
flags.DEFINE_integer('layer_id', 0, 'Index of the layer of which we plot '
                     'statistics during training.')
flags.DEFINE_integer('skip_mask', 10, 'number of checkpoints to skip for '
                     'each frame.')
flags.DEFINE_integer(
    'slow_until', 50, 'Number of masks to show with slower '
    'speed. After this number of frames, we start skipping '
    'frames to make the video shorter.')
FLAGS = flags.FLAGS
def main(unused_args):
  """Animates per-unit outgoing-connection counts of the recorded masks."""
  figure, axis = plt.subplots()
  figure.set_tight_layout(True)
  # Query the figure's on-screen size and DPI. Note that when saving the
  # figure to a file, we need to provide a DPI for that separately.
  print('fig size: {0} DPI, size in inches {1}'.format(
      figure.get_dpi(), figure.get_size_inches()))

  # Load the recorded masks: a dict mapping training iteration -> mask list.
  records = np.load(FLAGS.records_path, allow_pickle=True).item()
  record_keys = sorted(records.keys())
  target_shape = [int(dim) for dim in FLAGS.new_shape]

  def to_image(mask):
    # Count outgoing edges per unit and arrange them on the requested grid.
    return np.reshape(np.sum(mask, axis=1), target_shape)

  # Draw the initial frame from the earliest recorded mask.
  first_mask = records[record_keys[0]][FLAGS.layer_id]
  image = plt.imshow(
      to_image(first_mask), interpolation='none', vmin=0, vmax=30)
  figure.colorbar(image, ax=axis)

  def update(frame_idx):
    """Redraws the image and axis label for a single animation frame."""
    save_iter = record_keys[frame_idx]
    label = 'timestep {0}'.format(save_iter)
    print(label)
    # Return the artists that must be redrawn for this frame.
    image.set_data(to_image(records[save_iter][FLAGS.layer_id]))
    axis.set_xlabel(label)
    return [image, axis]

  # Show the first FLAGS.slow_until checkpoints one by one, then skip
  # FLAGS.skip_mask checkpoints per frame to keep the animation short.
  cutoff = FLAGS.slow_until
  frame_ids = (list(np.arange(0, cutoff, 1)) +
               list(np.arange(cutoff, len(record_keys), FLAGS.skip_mask)))
  animation = FuncAnimation(
      figure, update, frames=frame_ids, interval=FLAGS.interval)
  if FLAGS.save_path:
    animation.save(FLAGS.save_path, dpi=80, writer='imagemagick')
  else:
    # plt.show() will just loop the animation forever.
    plt.show()
if __name__ == '__main__':
  # tf.app.run parses the absl flags and then invokes the given main().
  tf.app.run(main)
================================================
FILE: rigl/requirements.txt
================================================
absl-py>=0.6.0
gin-config
numpy>=1.15.4
six>=1.12.0
tensorflow>=1.12.0,<2.0 # change to 'tensorflow-gpu' for gpu support
tensorflow-datasets==2.1
tensorflow-model-optimization
================================================
FILE: rigl/rigl_tf2/README.md
================================================
# Gradient Flow in Sparse Neural Networks and How Lottery Tickets Win
**Paper**: [https://arxiv.org/abs/2010.03533](https://arxiv.org/abs/2010.03533)
This code includes a TF-2 implementation of RigL and some other popular sparse training methods along with pruning, scratch and lottery ticket experiments in a unified codebase.
Run pruning experiments.
```
python train.py --gin_config=configs/prune.gin
```
Runs lottery training.
```
# Lottery experiments:
python train.py --logdir=/tmp/sparse_spectrum/lottery --seed=8 \
  --gin_config=configs/lottery.gin
```
Runs scratch training.
```
python train.py --logdir=/tmp/sparse_spectrum/scratch --seed=8 \
--gin_config=configs/scratch.gin
```
For assigning different gin flags, use `gin_bindings`, e.g.
```
--gin_bindings='network.weight_init_method="unit_scaled"'
--gin_bindings='unit_scaled_init.init_method="faninout_uniform"'
```
Calculating eigenvalues of hessian. Use logdir to point different checkpoints.
```
python train.py --mode=hessian \
--gin_config=configs/hessian.gin
```
Point `mlp_configs` to run MLP experiments.
```
python train.py --gin_config=mlp_configs/prune.gin
```
Running interpolation experiments is done as the following:
```
python interpolate.py --logdir=/tmp/sparse_spectrum/scratch \
--gin_config=configs/interpolate.gin \
--ckpt_start=/path_to_lottery_logdir/cp-11719.ckpt \
--ckpt_end=/path_to_prune_logdir/cp-11719.ckpt \
--operative_gin=/path_to_logdir/operative_config.gin \
--logdir=/path_to_prune_logdir/ltsolution2prune/
```
## A journey with train.py
1) check `main()`.
- Load preload_gin_config. This is useful for scratch experiments to use same
hyper_parameters as the pruning experiments. We can overwrite these with
regular `gin_configs/bindings` flags.
- Load data and create the network. Network might load its values from a
checkpoint. These arguments are set through gin. See utils.get_network for
details.
- Then the code either trains the network `mode=train_eval` or calculates the
hessian: `mode=hessian`.
2) train_model()
- Create the optimizer and samples a validation set from the training set.
Validation set is a subset of the training set and used to get better
estimates of certain metrics.
- Create the `mask_updater` object. The returned value can be none, then the
masks are not updated.
- Perform pre-training updates to the network: i.e. meta_initialization.
- Set-up checkpointing so that if a checkpoint exist continue from where it is
left.
- Define gradient function. This function is used during training and for
certain other metrics. Note that we have to manually mask the gradients
since they are dense.
- Define logging function for logging tensorboard event summaries.
- Main training loop: save, log, gradient step, mask update.
================================================
FILE: rigl/rigl_tf2/colabs/MnistProp.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "e5O1UdsY202_"
},
"source": [
"##### Copyright 2020 Google LLC.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jUW1g2_jWmBk"
},
"source": [
"## Measuring Signal Properties of Various Initializations\n",
"For a random signal x ~ normal(0, 1), and a neural network denoted with f(x)=y; ensuring std(y)=1 at initialization is a common goal for popular NN initialization schemes. Here we measure signal propagation for different sparse initializations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "4rvDSX8FFYTI"
},
"outputs": [],
"source": [
"#@title Imports and Definitions\n",
"import numpy as np\n",
"import os\n",
"import tensorflow.compat.v2 as tf\n",
"tf.enable_v2_behavior()\n",
"\n",
"import gin\n",
"from rigl import sparse_utils\n",
"from rigl.rigl_tf2 import init_utils\n",
"from rigl.rigl_tf2 import utils\n",
"from rigl.rigl_tf2 import train\n",
"from rigl.rigl_tf2 import networks\n",
"from rigl.rigl_tf2 import mask_updaters\n",
"\n",
"import functools\n",
"\n",
"pruning_params = utils.get_pruning_params(mode='constant', final_sparsity = 0., begin_step=int(1e10))\n",
"INPUT_SHAPE = (28, 28, 3)\n",
"class Lenet5(tf.keras.Model):\n",
"\n",
" def __init__(self,\n",
" input_shape,\n",
" num_classes,\n",
" activation: str,\n",
" hidden_sizes = (6, 16, 120, 84)):\n",
" super(Lenet5, self).__init__()\n",
" l = tf.keras.layers\n",
" kwargs = {'activation': activation}\n",
" filter_fn = lambda _: True\n",
" wrap_fn = functools.partial(utils.maybe_prune_layer, params=pruning_params, filter_fn=filter_fn)\n",
" self.conv1 = wrap_fn(l.Conv2D(hidden_sizes[0], 5, input_shape=input_shape, **kwargs))\n",
" self.pool1 = l.MaxPool2D(pool_size=(2, 2))\n",
" self.conv2 = wrap_fn(l.Conv2D(hidden_sizes[1], 5, input_shape=input_shape, **kwargs))\n",
" self.pool2 = l.MaxPool2D(pool_size=(2, 2))\n",
" self.flatten = l.Flatten()\n",
" self.dense1 = wrap_fn(l.Dense(hidden_sizes[2], **kwargs))\n",
" self.dense2 = wrap_fn(l.Dense(hidden_sizes[3], **kwargs))\n",
" self.dense3 = wrap_fn(l.Dense(num_classes, **kwargs))\n",
" self.build((1,)+input_shape)\n",
"\n",
" def call(self, inputs):\n",
" x = inputs\n",
" results = {}\n",
" for l_name in ['conv1', 'pool1', 'conv2', 'pool2', 'flatten', 'dense1', 'dense2', 'dense3']:\n",
" x = getattr(self, l_name)(x)\n",
" results[l_name] = x \n",
" return results\n",
"\n",
"def get_mask_random_numpy(mask_shape, sparsity):\n",
" \"\"\"Creates a random sparse mask with deterministic sparsity.\n",
"\n",
" Args:\n",
" mask_shape: list, used to obtain shape of the random mask.\n",
" sparsity: float, between 0 and 1.\n",
"\n",
" Returns:\n",
" numpy.ndarray\n",
" \"\"\"\n",
" all_ones = np.abs(np.ones(mask_shape))\n",
" n_zeros = int(np.floor(sparsity * all_ones.size))\n",
" rand_vals = np.random.uniform(size=mask_shape, high=range(1,mask_shape[-1]+1))\n",
" randflat=rand_vals.flatten()\n",
" randflat.sort()\n",
" t = randflat[n_zeros]\n",
" all_ones[rand_vals\u003c=t] = 0\n",
" return all_ones\n",
"\n",
"def create_convnet(sparsity=0, weight_init_method = None, scale=2, method='fanin_normal'):\n",
" model = Lenet5(INPUT_SHAPE, num_classes, 'relu')\n",
" if sparsity \u003e 0:\n",
" all_masks = [layer.pruning_vars[0][1] for layer in model.layers if isinstance(layer, utils.PRUNING_WRAPPER)]\n",
" for mask in all_masks:\n",
" new_mask = tf.cast(get_mask_random_numpy(mask.shape, sparsity), dtype=mask.dtype)\n",
" mask.assign(new_mask)\n",
" if weight_init_method:\n",
" all_weights = [layer.pruning_vars[0][0] for layer in model.layers if isinstance(layer, utils.PRUNING_WRAPPER)]\n",
" for mask, param in zip(all_masks, all_weights):\n",
" if weight_init_method == 'unit':\n",
" new_init = init_utils.unit_scaled_init(mask, method=method, scale=scale)\n",
" elif weight_init_method == 'layer':\n",
" new_init = init_utils.layer_scaled_init(mask, method=method, scale=scale)\n",
" else:\n",
" raise ValueError\n",
" param.assign(new_init)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fkZ_GNjyYYqZ"
},
"source": [
"Here we demonstrate how we can calculate the standard deviation of random noise at initialization for `layer-wise` scaled initialization of Liu et. al."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NsmPRCuZnxDA"
},
"outputs": [],
"source": [
"# Let's create a 95% sparse Lenet-5.\n",
"model = create_convnet(sparsity=0.95, weight_init_method='layer', scale=2, method='fanin_normal')\n",
"# Random input signal\n",
"random_input = tf.random.normal((1000,) + INPUT_SHAPE)\n",
"output_dict = model(random_input)\n",
"all_stds = []\n",
"for k in ['dense1', 'dense2', 'dense3']:\n",
" out_dim = output_dict[k].shape[-1]\n",
" stds = np.std(np.reshape(output_dict[k], (-1,out_dim)),axis=0)\n",
" all_stds.append(stds)\n",
"print('Mean deviation per neuron', np.mean(np.concatenate(all_stds, axis=0)))\n",
"print('Mean deviation per output neuron', np.mean(all_stds[-1]))\n",
"print('Deviation at output', np.std(random_input))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l3ttY88rYovo"
},
"source": [
"Now we define the code above as a function and use it on a grid to plot signal propagation at different sparsities."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"executionInfo": {
"elapsed": 320,
"status": "ok",
"timestamp": 1613388807790,
"user": {
"displayName": "",
"photoUrl": "",
"userId": ""
},
"user_tz": -180
},
"id": "4rfMGKciOOHf"
},
"outputs": [],
"source": [
"def propagate_signal(sparsity, init_method, batch_size=500):\n",
" model = create_convnet(sparsity=sparsity, weight_init_method=init_method)\n",
" random_input = tf.random.normal((batch_size,) + INPUT_SHAPE)\n",
" # print(np.mean(random_input), np.std(random_input))\n",
" output_dict = model(random_input)\n",
" out_std = np.std(output_dict['dense3'])\n",
" all_stds = []\n",
" for k in ['dense1', 'dense2', 'dense3']:\n",
" out_dim = output_dict[k].shape[-1]\n",
" stds = np.std(np.reshape(output_dict[k], (-1,out_dim)),axis=0)\n",
" all_stds.append(stds)\n",
" meanstd = np.mean(np.concatenate(all_stds, axis=0))\n",
" return meanstd, out_std"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F1rNPLXk7Ins"
},
"outputs": [],
"source": [
"import itertools, collections\n",
"import numpy as np\n",
"all_results = collections.defaultdict(dict)\n",
"\n",
"N_EXP = 3\n",
"for s in np.linspace(0.8,0.98,5):\n",
" print(s)\n",
" for method, name in zip((None, 'unit', 'layer'), ('Masked Dense', 'Ours', 'Scaled-Init')):\n",
" all_results[name][s] = [propagate_signal(s, method) for _ in range(N_EXP)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sbjc7LxpVGl0"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"for k, v in all_results.items():\n",
" # if k == 'Masked Dense':\n",
" # continue\n",
" x = sorted(v.keys())\n",
" y = [np.mean([vv[1] for vv in v[kk]])+1e-5 for kk in x]\n",
" plt.plot(x, y, label=k)\n",
"plt.hlines(y=1, color='r', xmin=0, xmax=1)\n",
"plt.yscale('log')\n",
"plt.title('std(output)')\n",
"plt.legend()\n",
"plt.show()\n",
"\n",
"for k, v in all_results.items():\n",
" # if k == 'Masked Dense':\n",
" # continue\n",
" x = sorted(v.keys())\n",
" y = [np.mean([vv[0] for vv in v[kk]])+1e-5 for kk in x]\n",
" plt.plot(x, y, label=k)\n",
"plt.yscale('log')\n",
"plt.hlines(y=1, color='r', xmin=0, xmax=1)\n",
"plt.title('mean(std_per_neuron)')\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
"kind": "private"
},
"name": "Mnist propagation init sparse .ipynb",
"provenance": [
{
"file_id": "126QJDydlS0V4tQ-KhiN6bSlCOisqLV-Z",
"timestamp": 1612472405306
},
{
"file_id": "137QdNeUdTGoAOEPKpPMC09keiwlu12Bh",
"timestamp": 1601472560303
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: rigl/rigl_tf2/configs/dense.gin
================================================
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Log every 500 steps.
training.log_freq = 200
network.network_name = 'lenet5'
network.weight_decay = 0.0005
# original_hidden_size/sqrt(20) -> 20 comes from 95% sparsity.
# following lenet has 2399 params vs 2396 (95% sparse lenet5).
lenet5.hidden_sizes = (6, 16, 120, 84)
lenet5.use_batch_norm = False
optimizer.name = "momentum"
optimizer.learning_rate = 0.05
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/configs/grasp.gin
================================================
training.use_metainit = False
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
training.gradient_regularization=0
optimizer.name = "momentum"
optimizer.learning_rate = 0.1
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
network.weight_decay = 0.0002
# Disable GMP pruning.
pruning.mode = 'constant'
pruning.final_sparsity = 0.
# Enable one shot pruning.
training.oneshot_prune_fraction = 0.95
training.val_batch_size = 5000
pruning.begin_step = 100000000 # High begin_step, so it never starts.
# Mask Updates
mask_updater.update_alg = 'rigl_grasp' # Prune part of rigl_grasp corresponds to grasp.
mask_updater.last_update_step=0 # Never updates.
================================================
FILE: rigl/rigl_tf2/configs/hessian.gin
================================================
hessian.batch_size = 60000
hessian.rows_at_once = 2
# range(0,100,5) + range(100,2000,100) + range(2000,11719,500)
hessian.ckpt_ids = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, 10500, 11000, 11500]
# range(4000,11719,50)
# For Rigl updates
# hessian.ckpt_ids = [-499, -999, -1499, -1999, -2499, -2999, -3499, -3999, -4499, -4999, -5499, -5999, -6499, -6999, -7499, -7999, -8499, -8999, -9499, -9999, -10499, -10999, -11499, -500, -1000, -1500, -2000, -2500, -3000, -3500, -4000, -4500, -5000, -5500, -6000, -6500, -7000, -7500, -8000, -8500, -9000, -9500, -10000, -10500, -11000, -11500]
# hessian.ckpt_ids = [-100, -99, -199, -200, -500, -499, -999, -1999, -1499, -1500, -1000, -2000]
hessian.overwrite = True
================================================
FILE: rigl/rigl_tf2/configs/interpolate.gin
================================================
interpolate.i_start = -0.20
interpolate.i_end = 1.20
interpolate.n_interpolation = 29
================================================
FILE: rigl/rigl_tf2/configs/lottery.gin
================================================
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
optimizer.name = "momentum"
optimizer.learning_rate = 0.05
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_path = '/tmp/sparse_spectrum/ckpt-0'
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/configs/prune.gin
================================================
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
network.network_name = 'lenet5'
network.mask_init_path = None
network.weight_decay = 0.0005
lenet5.use_batch_norm = False
lenet5.hidden_sizes = (6, 16, 120, 84)
optimizer.name = "momentum"
optimizer.learning_rate = 0.05
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
pruning.mode = 'prune'
pruning.initial_sparsity = 0.0
pruning.final_sparsity = 0.95
pruning.begin_step = 3000
pruning.end_step = 7000
pruning.frequency = 100
================================================
FILE: rigl/rigl_tf2/configs/rigl.gin
================================================
training.use_metainit = False
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
training.gradient_regularization=0
optimizer.name = "momentum"
optimizer.learning_rate = 0.05
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_method = None
network.weight_decay = 0.0005
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
unit_scaled_init.method='fanin_normal'
# Mask Updates
mask_updater.update_alg = 'rigl'
mask_updater.schedule_alg = 'lr'
mask_updater.update_freq = 100
mask_updater.init_drop_fraction = 0.3
mask_updater.last_update_step=-1
================================================
FILE: rigl/rigl_tf2/configs/scratch.gin
================================================
training.use_metainit = False
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
training.gradient_regularization=0
optimizer.name = "momentum"
optimizer.learning_rate = 0.05
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_method = None
network.shuffle_mask = False
network.weight_decay = 0.0005
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/configs/set.gin
================================================
training.use_metainit = False
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
training.gradient_regularization=0
optimizer.name = "momentum"
optimizer.learning_rate = 0.05
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_method = None
network.weight_decay = 0.0005
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
unit_scaled_init.method='fanin_normal'
# Mask Updates
mask_updater.update_alg = 'set'
mask_updater.schedule_alg = 'lr'
mask_updater.update_freq = 100
mask_updater.init_drop_fraction = 0.3
mask_updater.last_update_step=-1
================================================
FILE: rigl/rigl_tf2/configs/small_dense.gin
================================================
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
network.network_name = 'lenet5'
network.weight_decay = 0.0005
# original_hidden_size/sqrt(20) -> 20 comes from 95% sparsity.
# following lenet has 2399 params vs 2396 (95% sparse lenet5).
lenet5.hidden_sizes = (3, 3, 27, 20)
lenet5.use_batch_norm = False
optimizer.name = "momentum"
optimizer.learning_rate = 0.05
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/configs/snip.gin
================================================
training.use_metainit = False
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save every 500 steps.
training.log_freq = 200
training.gradient_regularization=0
optimizer.name = "momentum"
optimizer.learning_rate = 0.1
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
network.weight_decay = 0.0002
# Disable GMP pruning.
pruning.mode = 'constant'
pruning.final_sparsity = 0.
# Enable one shot pruning.
training.oneshot_prune_fraction = 0.95
training.val_batch_size = 5000
pruning.begin_step = 100000000 # High begin_step, so it never starts.
# Mask Updates
mask_updater.update_alg = 'rigl_s' # Prune part of rigl_s corresponds to snip.
mask_updater.last_update_step=0 # Never updates.
================================================
FILE: rigl/rigl_tf2/init_utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements initializations for sparse layers."""
import math
import gin
import tensorflow as tf
@gin.configurable(denylist=['mask'])
def unit_scaled_init(mask, method='fanavg_uniform', scale=1.0):
  """Initializes sparse weights scaling the variance per unit.

  Each non-zero connection is sampled with a variance derived from the
  effective (masked) fan-in/fan-out of the units it connects, instead of
  the dense fan of the layer.

  Args:
    mask: 2d or 4d binary Tensor; the last two dimensions are
      (input_units, output_units).
    method: str, `<mode>_<distribution>` where mode is one of `fanin`,
      `fanout`, `fanavg` and distribution is one of `normal`, `uniform`.
    scale: float, multiplier for the variance.

  Returns:
    Tensor of the same shape as `mask` where the non-zero positions are
    initialized and masked-out positions are zero.

  Raises:
    ValueError: if the mask rank, mode or distribution is invalid.
  """
  mode, distribution = method.strip().split('_')
  # Validate eagerly so an invalid method fails even for an all-zero mask
  # (previously the checks only triggered inside the per-connection loop).
  if mode not in ('fanin', 'fanout', 'fanavg'):
    raise ValueError(f'mode: {mode} must be one of fanin, fanout, fanavg.')
  if distribution not in ('normal', 'uniform'):
    raise ValueError(
        f'distribution: {distribution} must be one of normal, uniform.')
  # Reduce convolutional kernels over the spatial dims so that fan counts
  # are per (input_unit, output_unit) pair.
  if len(mask.shape) == 4:
    mask_reduced2d = tf.reduce_sum(mask, axis=[0, 1])
  elif len(mask.shape) == 2:
    mask_reduced2d = mask
  else:
    raise ValueError(f'mask.shape: {mask.shape} must be 4 or 2 dimensional.')
  fan_ins = tf.reduce_sum(mask_reduced2d, axis=-2)
  fan_outs = tf.reduce_sum(mask_reduced2d, axis=-1)
  non_zero_indices = tf.where(mask)  # shape=(NZ, N_dim)
  # Sample one value per non-zero connection with the fans of the units
  # that connection joins.
  new_vals = []
  for index in non_zero_indices:
    # Get fan_in and fan_out of the units that the non-zero connection
    # connects.
    fan_in = fan_ins[index[-1]]
    fan_out = fan_outs[index[-2]]
    # Following code is modified from `tensorflow/python/ops/init_ops_v2.py`.
    if mode == 'fanin':
      current_scale = scale / max(1., fan_in)
    elif mode == 'fanout':
      current_scale = scale / max(1., fan_out)
    else:  # 'fanavg', validated above.
      current_scale = scale / max(1., (fan_in + fan_out) / 2.)
    if distribution == 'normal':
      stddev = math.sqrt(current_scale)
      new_val = tf.random.normal((1,), 0.0, stddev, mask.dtype)
    else:  # 'uniform', validated above.
      limit = math.sqrt(3.0 * current_scale)
      new_val = tf.random.uniform((1,), -limit, limit, mask.dtype)
    new_vals.append(new_val)
  new_vals = tf.concat(new_vals, axis=-1)
  # Scatter the sampled values back into the dense weight shape.
  new_weights = tf.scatter_nd(
      indices=non_zero_indices,
      updates=new_vals,
      shape=mask.shape)
  return new_weights
@gin.configurable(denylist=['mask'])
def layer_scaled_init(mask, method='fanavg_uniform', scale=1.0):
  """Dense variance-scaling init, rescaled by the layer's overall density.

  Draws a dense variance-scaled initialization and divides it by
  sqrt(density) so the variance of the surviving (masked) weights matches
  the dense target on average over the layer.
  """
  mode, distribution = method.strip().split('_')
  initializer = tf.keras.initializers.VarianceScaling(
      mode=mode.replace('fan', 'fan_'), scale=scale, distribution=distribution)
  dense_weights = initializer(shape=mask.shape, dtype=mask.dtype)
  density = tf.reduce_sum(mask) / tf.size(mask, out_type=mask.dtype)
  return dense_weights / tf.math.sqrt(density)
def unit_scaled_init_tf1(mask,
                         method='fanavg_uniform',
                         scale=1.0,
                         dtype=tf.float32):
  """Graph-mode (TF1) variant of `unit_scaled_init`.

  Scales the variance of each non-zero connection with the effective
  (masked) fan of the units it connects. Uses `tf.map_fn` instead of a
  Python loop so it works without eager execution.

  Args:
    mask: 2d or 4d binary Tensor; the last two dimensions are
      (input_units, output_units).
    method: str, `<mode>_<distribution>` where mode is one of `fanin`,
      `fanout`, `fanavg` and distribution is one of `normal`, `uniform`.
    scale: float, multiplier for the variance.
    dtype: dtype of the returned weights.

  Returns:
    Tensor of the same shape as `mask` with non-zero entries initialized.

  Raises:
    ValueError: if the mask rank, mode or distribution is invalid.
  """
  mode, distribution = method.strip().split('_')
  # Validate the method up front (these are Python-level branches, so they
  # would otherwise be checked only when `map_fn` traces the body).
  if mode not in ('fanin', 'fanout', 'fanavg'):
    raise ValueError(f'mode: {mode} must be one of fanin, fanout, fanavg.')
  if distribution not in ('normal', 'uniform'):
    raise ValueError(
        f'distribution: {distribution} must be one of normal, uniform.')
  # Reduce convolutional kernels over the spatial dims to get per-unit fans.
  if len(mask.shape) == 4:
    mask_reduced2d = tf.reduce_sum(mask, axis=[0, 1])
  elif len(mask.shape) == 2:
    mask_reduced2d = mask
  else:
    raise ValueError(f'mask.shape: {mask.shape} must be 4 or 2 dimensional.')
  fan_ins = tf.reduce_sum(mask_reduced2d, axis=-2)
  fan_outs = tf.reduce_sum(mask_reduced2d, axis=-1)
  non_zero_indices = tf.where(mask)  # shape=(NZ, N_dim)

  def new_val_fn(index):
    # Get fan_in and fan_out of the units that this connection connects.
    fan_in = fan_ins[index[-1]]
    fan_out = fan_outs[index[-2]]
    # Following code is modified from `tensorflow/python/ops/init_ops_v2.py`.
    if mode == 'fanin':
      current_scale = scale / tf.math.maximum(1., fan_in)
    elif mode == 'fanout':
      current_scale = scale / tf.math.maximum(1., fan_out)
    else:  # 'fanavg', validated above.
      current_scale = scale / tf.math.maximum(1., (fan_in + fan_out) / 2.)
    if distribution == 'normal':
      stddev = tf.math.sqrt(current_scale)
      new_val = tf.random.normal((1,), 0.0, stddev, dtype)
    else:  # 'uniform', validated above.
      limit = tf.math.sqrt(3.0 * current_scale)
      new_val = tf.random.uniform((1,), -limit, limit, dtype)
    return new_val

  # One sampled value per non-zero connection.
  new_vals = tf.squeeze(tf.map_fn(new_val_fn, non_zero_indices, dtype=dtype))
  new_weights = tf.scatter_nd(
      indices=non_zero_indices, updates=new_vals, shape=mask.shape)
  return new_weights
================================================
FILE: rigl/rigl_tf2/interpolate.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Script for interpolating between checkpoints.
"""
import os
from absl import app
from absl import flags
from absl import logging
import gin
import numpy as np
from rigl.rigl_tf2 import utils
import tensorflow.compat.v2 as tf
from pyglib import timer
FLAGS = flags.FLAGS
flags.DEFINE_string('logdir', '/tmp/sparse_spectrum/interpolation',
                    'Directory to save experiment in.')
flags.DEFINE_string('ckpt_start', '/tmp/sparse_spectrum/cp-0001.ckpt',
                    'Checkpoint at the start of the interpolation path.')
flags.DEFINE_string('ckpt_end', '/tmp/sparse_spectrum/cp-0041.ckpt',
                    'Checkpoint at the end of the interpolation path.')
flags.DEFINE_string(
    'preload_gin_config', '', 'If non-empty reads a gin file '
    'before parsing gin_config and bindings. This is useful,'
    'when you want to start from a configuration of another '
    'run. Values are then overwritten by additional configs '
    'and bindings provided.')
flags.DEFINE_bool('use_tpu', True, 'Whether to run on TPU or not.')
flags.DEFINE_bool('eval_on_train', True, 'Whether to evaluate on training set.')
flags.DEFINE_integer('load_mask_from', 0, '0 means start checkpoint, 1 means '
                     'end checkpoint. -1 means no mask loaded.')
flags.DEFINE_enum('mode', 'train_eval', ('train_eval', 'hessian'),
                  'Which analysis to run: train/eval interpolation or '
                  'hessian computation.')
flags.DEFINE_string(
    'tpu_job_name', 'tpu_worker',
    'Name of the TPU worker job. This is required when having '
    'multiple TPU worker jobs.')
flags.DEFINE_string('master', None, 'TPU worker.')
flags.DEFINE_multi_string('gin_config', [],
                          'List of paths to the config files.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Newline separated list of Gin parameter bindings.')
def test_model(model, d_test, batch_size=1000):
  """Evaluates `model` on `d_test` and returns (mean loss, accuracy)."""
  loss_metric = tf.keras.metrics.Mean(name='test_loss')
  acc_metric = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
  xent = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  for inputs, labels in d_test.batch(batch_size):
    logits = model(inputs, training=False)
    loss_metric.update_state(xent(labels, logits))
    acc_metric.update_state(labels, logits)
  final_loss = loss_metric.result().numpy()
  final_acc = acc_metric.result().numpy()
  logging.info('Test loss: %f', final_loss)
  logging.info('Test accuracy: %f', final_acc)
  return final_loss, final_acc
@gin.configurable(
    'interpolate',
    denylist=['model_start', 'model_end', 'model_inter', 'd_set'])
def interpolate(model_start, model_end, model_inter, d_set,
                i_start=-0.2, i_end=1.2, n_interpolation=29):
  """Evaluates linear interpolations between two networks.

  For each coefficient c in linspace(i_start, i_end, n_interpolation), sets
  model_inter's trainable variables to (1 - c) * start + c * end and runs
  `test_model` on `d_set`.

  Returns:
    dict mapping interpolation coefficient -> (loss, accuracy).
  """
  coefficients = np.linspace(i_start, i_end, n_interpolation)
  all_scores = {}
  for coef in coefficients:
    logging.info('Interpolating with: %f', coef)
    variable_triples = zip(model_start.trainable_variables,
                           model_end.trainable_variables,
                           model_inter.trainable_variables)
    for v_start, v_end, v_inter in variable_triples:
      v_inter.assign((1 - coef) * v_start + coef * v_end)
    all_scores[coef] = test_model(model_inter, d_set)
  return all_scores
def main(unused_argv):
  """Builds two checkpointed networks and evaluates the linear path between.

  Loads gin configuration, builds three copies of the network (start, end,
  and a scratch copy used for interpolation), evaluates the interpolation
  path, and writes the scores plus the operative gin config to `logdir`.
  """
  init_timer = timer.Timer()
  init_timer.Start()
  if FLAGS.preload_gin_config:
    # Load default values from the original experiment, always the first one.
    with gin.unlock_config():
      gin.parse_config_file(FLAGS.preload_gin_config, skip_unknown=True)
    logging.info('Operative Gin configurations loaded from: %s',
                 FLAGS.preload_gin_config)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  data_train, data_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)
  data_eval = data_train if FLAGS.eval_on_train else data_test
  # Pruning runs in 'constant' mode: masks come from checkpoints, not GMP.
  pruning_params = utils.get_pruning_params(mode='constant')
  mask_load_dict = {-1: None, 0: FLAGS.ckpt_start, 1: FLAGS.ckpt_end}
  mask_path = mask_load_dict[FLAGS.load_mask_from]
  # Currently we interpolate only on the same sparse space: all three
  # networks share the same mask (or no mask).
  model_start = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_start)
  model_start.summary()
  model_end = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_end)
  model_end.summary()
  # Create a third network for interpolation.
  model_inter = utils.get_network(
      pruning_params,
      input_shape,
      num_classes,
      mask_init_path=mask_path,
      weight_init_path=FLAGS.ckpt_end)
  logging.info('Performance at init (model_start):')
  test_model(model_start, data_eval)
  logging.info('Performance at init (model_end):')
  test_model(model_end, data_eval)
  all_results = interpolate(model_start=model_start, model_end=model_end,
                            model_inter=model_inter, d_set=data_eval)
  tf.io.gfile.makedirs(FLAGS.logdir)
  results_path = os.path.join(FLAGS.logdir, 'all_results')
  with tf.io.gfile.GFile(results_path, 'wb') as f:
    np.save(f, all_results)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())
  logconfigfile_path = os.path.join(FLAGS.logdir, 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())


if __name__ == '__main__':
  tf.enable_v2_behavior()
  app.run(main)
================================================
FILE: rigl/rigl_tf2/mask_updaters.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements RigL."""
import gin
from rigl.rigl_tf2 import utils
import tensorflow as tf
def get_all_layers(model, filter_fn=lambda _: True):
  """Flattens a model's layers, recursing into anything exposing `.layers`."""
  collected = []
  for layer in model.layers:
    if hasattr(layer, 'layers'):
      # Nested keras.Model (or layer group): recurse instead of appending.
      collected += get_all_layers(layer, filter_fn=filter_fn)
    elif filter_fn(layer):
      collected.append(layer)
  return collected
def is_pruned(layer):
  """True iff `layer` is a trainable pruning-wrapped layer."""
  is_wrapped = isinstance(layer, utils.PRUNING_WRAPPER)
  return is_wrapped and layer.trainable
class MaskUpdater(object):
  """Base class for mask update algorithms.

  Subclasses define `get_drop_scores` / `get_grow_scores`; this class owns
  the generic prune/grow machinery that turns those scores into mask and
  weight updates.

  Attributes:
    model: tf.keras.Model
    optimizer: tf.train.Optimizer
    use_stateless: bool, if True stateless operations are used. This is
      important for multi-worker jobs not to diverge.
    stateless_seed_offset: int, added to the seed of stateless operations.
      Use this to create randomness without divergence across workers.
  """

  def __init__(self, model, optimizer, use_stateless=True,
               stateless_seed_offset=0, loss_fn=None):
    self._model = model
    self._optimizer = optimizer
    self._use_stateless = use_stateless
    self._stateless_seed_offset = stateless_seed_offset
    # Loss callable (x, y) -> scalar; only required by gradient-based
    # subclasses (see `_get_gradients`).
    self._loss_fn = loss_fn
    # Validation batch for gradient computation; set via
    # `set_validation_data` before gradient-based updates.
    self.val_x = self.val_y = None

  def prune_masks(self, prune_fraction):
    """Updates a fraction of weights in each layer.

    Prune-only variant: `grow_score` is None, so dropped connections are
    not replaced.
    """
    all_masks, all_vars = self.get_vars_and_masks()
    drop_scores = self.get_drop_scores(all_vars, all_masks)
    grow_score = None
    for mask, var, drop_score in zip(all_masks, all_vars, drop_scores):
      self.generic_mask_update(mask, var, drop_score, grow_score,
                               prune_fraction)

  def update_masks(self, drop_fraction):
    """Updates a fraction of weights in each layer.

    Drop-and-grow variant: the same number of connections dropped is also
    re-enabled using the grow scores, so sparsity is preserved.
    """
    all_masks, all_vars = self.get_vars_and_masks()
    drop_scores = self.get_drop_scores(all_vars, all_masks)
    grow_scores = self.get_grow_scores(all_vars, all_masks)
    for mask, var, drop_score, grow_score in zip(all_masks, all_vars,
                                                 drop_scores, grow_scores):
      self.generic_mask_update(mask, var, drop_score, grow_score, drop_fraction)

  def get_all_pruning_layers(self):
    """Returns all pruned layers from the model."""
    if hasattr(self._model, 'layers'):
      return get_all_layers(self._model, filter_fn=is_pruned)
    else:
      # `_model` may itself be a single (wrapped) layer.
      return [self._model] if is_pruned(self._model) else []

  def get_vars_and_masks(self):
    """Gets all masked variables and corresponding masks."""
    all_masks = []
    all_vars = []
    for layer in self.get_all_pruning_layers():
      for var, mask, _ in layer.pruning_vars:
        all_vars.append(var)
        all_masks.append(mask)
    return all_masks, all_vars

  def get_drop_scores(self, all_vars, all_masks):
    """Per-variable scores; lowest-scored active connections get dropped."""
    raise NotImplementedError

  def get_grow_scores(self, all_vars, all_masks):
    """Per-variable scores; highest-scored inactive connections get grown."""
    raise NotImplementedError

  def generic_mask_update(self, mask, var, score_drop, score_grow,
                          drop_fraction, reinit_when_same=False):
    """Prunes+grows connections, all tensors same shape."""
    n_total = tf.size(score_drop)
    n_ones = tf.cast(tf.reduce_sum(mask), dtype=tf.int32)
    # Number of connections to drop (and, when growing, to re-enable).
    n_prune = tf.cast(
        tf.cast(n_ones, dtype=tf.float32) * drop_fraction, tf.int32)
    n_keep = n_ones - n_prune
    # Sort the entire array since the k needs to be constant for TPU.
    _, sorted_indices = tf.math.top_k(
        tf.reshape(score_drop, [-1]), k=n_total)
    sorted_indices_ex = tf.expand_dims(sorted_indices, 1)
    # We will have zeros after having `n_keep` many ones.
    new_values = tf.where(
        tf.range(n_total) < n_keep,
        tf.ones_like(sorted_indices, dtype=mask.dtype),
        tf.zeros_like(sorted_indices, dtype=mask.dtype))
    # mask1: flat mask keeping the top-`n_keep` drop-scored connections.
    mask1 = tf.scatter_nd(sorted_indices_ex, new_values,
                          new_values.shape)
    if score_grow is not None:
      # Flatten the scores.
      score_grow = tf.reshape(score_grow, [-1])
      # Set scores of the enabled connections(ones) to min(s) - 1, so that they
      # have the lowest scores.
      score_grow_lifted = tf.where(
          tf.math.equal(mask1, 1),
          tf.ones_like(mask1) * (tf.reduce_min(score_grow) - 1), score_grow)
      _, sorted_indices = tf.math.top_k(score_grow_lifted, k=n_total)
      sorted_indices_ex = tf.expand_dims(sorted_indices, 1)
      new_values = tf.where(
          tf.range(n_total) < n_prune,
          tf.ones_like(sorted_indices, dtype=mask.dtype),
          tf.zeros_like(sorted_indices, dtype=mask.dtype))
      # mask2: flat mask of the `n_prune` newly grown connections.
      mask2 = tf.scatter_nd(sorted_indices_ex, new_values, new_values.shape)
      # Ensure masks are disjoint.
      tf.debugging.assert_near(tf.reduce_sum(mask1 * mask2), 0.)
      # Let's set the weights of the growed connections.
      mask2_reshaped = tf.reshape(mask2, mask.shape)
      # Set the values of the new connections.
      grow_tensor = tf.zeros_like(var, dtype=var.dtype)
      if reinit_when_same:
        # If dropped and grown, we re-initialize.
        new_connections = tf.math.equal(mask2_reshaped, 1)
      else:
        new_connections = tf.math.logical_and(
            tf.math.equal(mask2_reshaped, 1), tf.math.equal(mask, 0))
      new_weights = tf.where(new_connections, grow_tensor, var)
      var.assign(new_weights)
      # Ensure there is no momentum value for new connections
      self.reset_momentum(var, new_connections)
      mask_combined = tf.reshape(mask1 + mask2, mask.shape)
    else:
      mask_combined = tf.reshape(mask1, mask.shape)
    mask.assign(mask_combined)

  def reset_momentum(self, var, new_connections):
    """Zeros optimizer slot entries (e.g. momentum) for new connections."""
    for s_name in self._optimizer.get_slot_names():
      # Momentum variable for example, we reset the aggregated values to zero.
      optim_var = self._optimizer.get_slot(var, s_name)
      new_values = tf.where(new_connections,
                            tf.zeros_like(optim_var), optim_var)
      optim_var.assign(new_values)

  def _random_uniform(self, *args, **kwargs):
    # `kwargs['seed']` is required; in stateless mode it is combined with
    # the optimizer step so each update draws fresh but worker-consistent
    # randomness.
    if self._use_stateless:
      c_seed = self._stateless_seed_offset + kwargs['seed']
      kwargs['seed'] = tf.cast(
          tf.stack([c_seed, self._optimizer.iterations]), tf.int32)
      return tf.random.stateless_uniform(*args, **kwargs)
    else:
      return tf.random.uniform(*args, **kwargs)

  def _random_normal(self, *args, **kwargs):
    # Same seeding scheme as `_random_uniform`.
    if self._use_stateless:
      c_seed = self._stateless_seed_offset + kwargs['seed']
      kwargs['seed'] = tf.cast(
          tf.stack([c_seed, self._optimizer.iterations]), tf.int32)
      return tf.random.stateless_normal(*args, **kwargs)
    else:
      return tf.random.normal(*args, **kwargs)

  def set_validation_data(self, val_x, val_y):
    """Stores the batch used by `_get_gradients`."""
    self.val_x, self.val_y = val_x, val_y

  def _get_gradients(self, all_vars):
    """Returns the gradients of the given weights using the validation data."""
    with tf.GradientTape() as tape:
      batch_loss = self._loss_fn(self.val_x, self.val_y)
    grads = tape.gradient(batch_loss, all_vars)
    if grads:
      # Sum across replicas so every worker computes identical scores/masks.
      grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    return grads
class SET(MaskUpdater):
  """Sparse Evolutionary Training (SET) mask updater.

  See https://www.nature.com/articles/s41467-018-04316-3

  Drops the lowest-magnitude active weights and grows new connections at
  uniformly random locations. Works alongside a regular optimizer,
  updating masks on the schedule it is given.
  """

  def get_drop_scores(self, all_vars, all_masks, noise_std=0):
    """Magnitude of the masked weights, optionally perturbed with noise."""
    scores = []
    for mask, var in zip(all_masks, all_vars):
      score = tf.math.abs(mask * var)
      if noise_std != 0:
        score += self._random_normal(
            score.shape, stddev=noise_std, dtype=score.dtype,
            seed=(hash(var.name + 'drop')))
      scores.append(score)
    return scores

  def get_grow_scores(self, all_vars, all_masks):
    """Uniform random scores: growth locations are chosen at random."""
    del all_masks  # Unused: growth does not depend on the current mask.
    return [
        self._random_uniform(var.shape, seed=hash(var.name + 'grow'))
        for var in all_vars
    ]
class RigL(MaskUpdater):
  """Rigged Lottery (RigL) mask updater.

  Drops the lowest-magnitude active weights and grows the inactive
  connections with the largest gradient magnitude on validation data.
  """

  def get_drop_scores(self, all_vars, all_masks, noise_std=0):
    """Magnitude of the masked weights, optionally perturbed with noise."""
    scores = []
    for mask, var in zip(all_masks, all_vars):
      magnitude = tf.math.abs(mask * var)
      if noise_std != 0:
        magnitude += self._random_normal(
            magnitude.shape, stddev=noise_std, dtype=magnitude.dtype,
            seed=(hash(var.name + 'drop')))
      scores.append(magnitude)
    return scores

  def get_grow_scores(self, all_vars, all_masks):
    """Absolute gradients of the loss on the stored validation batch."""
    del all_masks  # Unused: scores depend only on gradients.
    gradients = self._get_gradients(all_vars)
    return [tf.abs(g) for g in gradients]
class RigLInverted(RigL):
  """RigL variant that grows connections with the *smallest* gradients.

  Identical to RigL except the grow scores are negated, so top-k growth
  selects the lowest-gradient-magnitude connections.
  """

  def get_grow_scores(self, all_vars, all_masks):
    del all_masks  # Unused: scores depend only on gradients.
    gradients = self._get_gradients(all_vars)
    return [tf.negative(tf.abs(g)) for g in gradients]
class UpdateSchedule(object):
  """Base class for mask update algorithms.

  Decides *when* the wrapped MaskUpdater runs and *what fraction* of
  connections it updates at a given step. Subclasses implement
  `get_drop_fraction`.

  Attributes:
    mask_updater: MaskUpdater, to invoke.
    update_freq: int, frequency of mask updates.
    init_drop_fraction: float, initial drop fraction.
  """

  def __init__(self, mask_updater, init_drop_fraction, update_freq,
               last_update_step):
    self._mask_updater = mask_updater
    self.update_freq = update_freq
    self.last_update_step = last_update_step
    self.init_drop_fraction = tf.convert_to_tensor(init_drop_fraction)
    # Fraction used by the most recent update/prune call (useful for
    # logging/inspection).
    self.last_drop_fraction = 0

  def get_drop_fraction(self, step):
    """Returns the fraction of connections to drop at `step`."""
    raise NotImplementedError

  def is_update_iter(self, step):
    """Returns true if it is a valid mask update step."""
    # last_update_step < 0 means, there is no last step.
    # last_update_step = 0 means, never update.
    tf.debugging.Assert(step >= 0, [step])
    if self.last_update_step < 0:
      is_valid_step = True
    elif self.last_update_step == 0:
      is_valid_step = False
    else:
      is_valid_step = step <= self.last_update_step
    # Must also fall on the update frequency grid.
    return tf.logical_and(is_valid_step, step % self.update_freq == 0)

  def update(self, step, check_update_iter=True):
    """Runs a drop-and-grow mask update if the scheduled fraction is > 0."""
    if check_update_iter:
      tf.debugging.Assert(self.is_update_iter(step), [step])
    self.last_drop_fraction = self.get_drop_fraction(step)

    def true_fn():
      self._mask_updater.update_masks(self.last_drop_fraction)
    # tf.cond: the fraction may be a tensor (e.g. cosine schedule), so the
    # comparison cannot be a Python `if` in graph mode.
    tf.cond(self.last_drop_fraction > 0., true_fn, lambda: None)

  def prune(self, prune_fraction):
    """Prunes connections without growing replacements."""
    self.last_drop_fraction = prune_fraction
    self._mask_updater.prune_masks(self.last_drop_fraction)

  def set_validation_data(self, val_x, val_y):
    # Forwarded to the updater; needed by gradient-based updaters.
    self._mask_updater.set_validation_data(val_x, val_y)
class ConstantUpdateSchedule(UpdateSchedule):
  """Schedule that always drops the same (initial) fraction of connections."""

  def get_drop_fraction(self, step):
    del step  # Unused: the fraction is constant over training.
    return self.init_drop_fraction
class CosineUpdateSchedule(UpdateSchedule):
  """Decays the drop fraction with a cosine schedule.

  The fraction starts at `init_drop_fraction` and follows a cosine decay
  down to 0 (alpha=0) at `last_update_step`.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Cosine decay over `last_update_step` steps; evaluated per-step in
    # `get_drop_fraction`.
    self._drop_fraction_fn = tf.keras.experimental.CosineDecay(
        self.init_drop_fraction,
        self.last_update_step,
        alpha=0.0,
        name='cosine_drop_fraction')

  def get_drop_fraction(self, step):
    return self._drop_fraction_fn(step)
class ScaledLRUpdateSchedule(UpdateSchedule):
  """Scales the drop fraction with learning rate."""

  def __init__(self, mask_updater, init_drop_fraction, update_freq,
               last_update_step, optimizer):
    self._optimizer = optimizer
    # Reference LR at step 0; get_drop_fraction scales the initial fraction
    # by lr(step) / lr(0).
    self._initial_lr = self._get_lr(0)
    super(ScaledLRUpdateSchedule, self).__init__(
        mask_updater, init_drop_fraction, update_freq, last_update_step)

  def _get_lr(self, step):
    # `optimizer.lr` is either a plain tf.Variable (constant LR) or a
    # schedule callable that takes the step.
    if isinstance(self._optimizer.lr, tf.Variable):
      return self._optimizer.lr.numpy()
    else:
      return self._optimizer.lr(step)

  def get_drop_fraction(self, step):
    current_lr = self._get_lr(step)
    return (self.init_drop_fraction / self._initial_lr) * current_lr
@gin.configurable(
    'mask_updater',
    allowlist=[
        'update_alg',
        'schedule_alg',
        'update_freq',
        'init_drop_fraction',
        'last_update_step',
        'use_stateless',
    ])
def get_mask_updater(
    model,
    optimizer,
    loss_fn,
    update_alg='',
    schedule_alg='lr',
    update_freq=100,
    init_drop_fraction=0.3,
    last_update_step=-1,
    use_stateless=True):
  """Builds the requested MaskUpdater wrapped in the requested schedule."""
  if not update_alg:
    # Empty algorithm name disables dynamic mask updates entirely.
    return None
  if update_alg == 'set':
    updater = SET(model, optimizer, use_stateless=use_stateless)
  elif update_alg == 'rigl':
    updater = RigL(
        model, optimizer, loss_fn=loss_fn, use_stateless=use_stateless)
  elif update_alg == 'rigl_inverted':
    updater = RigLInverted(
        model, optimizer, loss_fn=loss_fn, use_stateless=use_stateless)
  else:
    raise ValueError('update_alg:%s is not valid.' % update_alg)

  if schedule_alg == 'lr':
    return ScaledLRUpdateSchedule(
        updater, init_drop_fraction, update_freq, last_update_step, optimizer)
  if schedule_alg == 'cosine':
    return CosineUpdateSchedule(
        updater, init_drop_fraction, update_freq, last_update_step)
  if schedule_alg == 'constant':
    return ConstantUpdateSchedule(
        updater, init_drop_fraction, update_freq, last_update_step)
  raise ValueError('schedule_alg:%s is not valid.' % schedule_alg)
================================================
FILE: rigl/rigl_tf2/metainit.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetaInit algorithm to dynamically initialize neural nets."""
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
class ScaleSGD(tf1.train.Optimizer):
  """SGD optimizer that only trains the scales of the parameters.

  This optimizer only tunes the scales of weight matrices: each update
  rescales a variable while keeping its direction, using sign-SGD with
  momentum on the scale.
  """

  def __init__(self, learning_rate=0.1, momentum=0.9, mindim=3,
               use_locking=False, name="ScaleSGD"):
    super(ScaleSGD, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._momentum = momentum
    # Variables with rank < mindim (e.g. biases) are left untouched by
    # `_resource_apply_dense`.
    self._mindim = mindim
    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._momentum_t = None

  def _prepare(self):
    self._lr_t = tf1.convert_to_tensor(self._lr, name="learning_rate")
    self._momentum_t = tf1.convert_to_tensor(self._momentum, name="momentum_t")

  def _create_slots(self, var_list):
    # One *scalar* momentum slot per variable: the momentum accumulates on
    # the (scalar) scale, not per-element.
    for v in var_list:
      self._get_or_make_slot_with_initializer(v,
                                              tf1.constant_initializer(0),
                                              tf1.TensorShape([]),
                                              tf1.float32,
                                              "m",
                                              self._name)

  def _resource_apply_dense(self, grad, handle):
    var = handle
    m = self.get_slot(var, "m")
    if len(var.shape) < self._mindim:
      # Low-rank variables are not rescaled; return a no-op group.
      return tf.group(*[var, m])
    lr_t = tf1.cast(self._lr_t, var.dtype.base_dtype)
    momentum_t = tf1.cast(self._momentum_t, var.dtype.base_dtype)
    # Current scale of the variable = its L2 norm.
    scale = tf1.sqrt(tf1.reduce_sum(var ** 2))
    # Sign of the gradient w.r.t. the scale (1e-12 avoids division by zero).
    dscale = tf1.sign(tf1.reduce_sum(var * grad) / (scale + 1e-12))
    # Momentum update on the scale, then rescale the variable in place
    # keeping its direction.
    m_t = m.assign(momentum_t * m - lr_t * dscale)
    new_scale = scale + m_t
    var_update = tf1.assign(var, var * new_scale / (scale + 1e-12))
    return tf1.group(*[var_update, m_t])

  def _apply_dense(self, grad, var):
    return self._resource_apply_dense(grad, var)

  def _apply_sparse(self, grad, var):
    raise NotImplementedError("Sparse gradient updates are not supported.")
def meta_init(model, loss, x_shape, y_shape, n_params, learning_rate=0.001,
              momentum=0.9, meta_steps=1000, eps=1e-5, mask_gradient_fn=None):
  """Run MetaInit algorithm. See `https://papers.nips.cc/paper/9427-metainit-initializing-learning-by-learning-to-initialize`"""
  # ScaleSGD tunes only the per-variable scales, as required by MetaInit.
  optimizer = ScaleSGD(learning_rate, momentum=momentum)
  for _ in range(meta_steps):
    # MetaInit uses random data, not real samples.
    x = np.random.normal(0, 1, x_shape)
    y = np.random.randint(0, y_shape[1], y_shape[0])
    # Persistent tape: we differentiate through quantities that are
    # themselves gradients (second-order).
    with tf.GradientTape(persistent=True) as tape:
      batch_loss = loss(y, model(x, training=True))
      grad = tape.gradient(batch_loss, model.trainable_variables)
      if mask_gradient_fn is not None:
        # Zero out gradient entries for masked (pruned) weights.
        grad = mask_gradient_fn(model, grad, model.trainable_variables)
      # Gradient of ||g||^2 / 2 w.r.t. the weights, i.e. a
      # Hessian-gradient product H*g.
      prod = tape.gradient(tf.reduce_sum([tf.reduce_sum(g**2) / 2
                                          for g in grad]),
                           model.trainable_variables)
      if mask_gradient_fn is not None:
        prod = mask_gradient_fn(model, prod, model.trainable_variables)
      # Per-element |1 - (g - H*g) / g|; eps is added with the sign of g
      # (via stop_gradient) to keep the division stable near zero.
      meta_loss = [tf.abs(1 - ((g - p) / (g + eps * tf.stop_gradient(
          (2 * tf.cast(tf.greater_equal(g, 0), tf.float32)) - 1))))
                   for g, p in zip(grad, prod)]
      if mask_gradient_fn is not None:
        meta_loss = mask_gradient_fn(model, meta_loss,
                                     model.trainable_variables)
      # Average over all parameters to get a scalar objective.
      meta_loss = sum([tf.reduce_sum(m) for m in meta_loss]) / n_params
    tf.summary.scalar("meta_loss", meta_loss)
    gradients = tape.gradient(meta_loss, model.trainable_variables)
    if mask_gradient_fn is not None:
      gradients = mask_gradient_fn(model, gradients, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
================================================
FILE: rigl/rigl_tf2/mlp_configs/dense.gin
================================================
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save a checkpoint every 500 steps.
training.log_freq = 200
network.network_name = 'mlp'
network.weight_decay = 0.0001
optimizer.name = "momentum"
optimizer.learning_rate = 0.2
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/mlp_configs/lottery.gin
================================================
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_path = '/tmp/sparse_spectrum/ckpt-0'
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/mlp_configs/prune.gin
================================================
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save a checkpoint every 500 steps.
training.log_freq = 200
network.network_name = 'mlp'
network.mask_init_path = None
network.weight_decay = 0.0001
optimizer.name = "momentum"
optimizer.learning_rate = 0.2
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
pruning.mode = 'prune'
pruning.initial_sparsity = 0.0
pruning.final_sparsity = 0.98
pruning.begin_step = 3000
pruning.end_step = 7000
pruning.frequency = 100
================================================
FILE: rigl/rigl_tf2/mlp_configs/rigl.gin
================================================
training.use_metainit = False
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_method = None
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
unit_scaled_init.method='fanin_normal'
# Mask Updates
mask_updater.update_alg = 'rigl'
mask_updater.schedule_alg = 'lr'
mask_updater.update_freq = 500
mask_updater.init_drop_fraction = 0.3
mask_updater.last_update_step=-1
================================================
FILE: rigl/rigl_tf2/mlp_configs/scratch.gin
================================================
training.use_metainit = False
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_method = None
network.shuffle_mask = False
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/mlp_configs/set.gin
================================================
training.use_metainit = False
# NON-DEFAULT
network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'
network.weight_init_method = None
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
unit_scaled_init.method='fanin_normal'
# Mask Updates
mask_updater.update_alg = 'set'
mask_updater.schedule_alg = 'lr'
mask_updater.update_freq = 500
mask_updater.init_drop_fraction = 0.3
mask_updater.last_update_step=-1
================================================
FILE: rigl/rigl_tf2/mlp_configs/small_dense.gin
================================================
training.total_steps = 11719 # 6e4/128*25 epochs=11719
training.batch_size = 128
training.save_freq = 500 # Save a checkpoint every 500 steps.
training.log_freq = 200
network.network_name = 'mlp'
network.weight_decay = 0.0001
# (28*28*300 + 300*100 + 100*10)*0.02 + 410 = 5734 params
# (28*28*8 + 8*8 + 8*10) + 8+8+10 = 6442
mlp.hidden_sizes = (8, 8)
optimizer.name = "momentum"
optimizer.learning_rate = 0.2
optimizer.momentum = 0.9
optimizer.clipvalue = None
optimizer.clipnorm = None
# NON-DEFAULT
pruning.mode = 'constant'
pruning.final_sparsity = 0.
pruning.begin_step = 100000000 # High begin_step, so it never starts.
================================================
FILE: rigl/rigl_tf2/networks.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module has networks used in experiments.
"""
from typing import Optional, Tuple # Non-expensive-to-import types.
import gin
import tensorflow.compat.v2 as tf
@gin.configurable(allowlist=['hidden_sizes', 'use_batch_norm'])
def lenet5(input_shape,
           num_classes,
           activation,
           kernel_regularizer,
           use_batch_norm = False,
           hidden_sizes = (6, 16, 120, 84)):
  """Builds a LeNet-5 style convnet as a tf.keras Sequential model."""
  layers = tf.keras.layers
  common = {
      'activation': activation,
      'kernel_regularizer': kernel_regularizer,
  }
  model = tf.keras.Sequential()

  def add_batchnorm_if_enabled():
    if use_batch_norm:
      model.add(layers.BatchNormalization())

  # Two conv + max-pool stages.
  model.add(layers.Conv2D(
      hidden_sizes[0], 5, input_shape=input_shape, **common))
  model.add(layers.MaxPool2D(pool_size=(2, 2)))
  add_batchnorm_if_enabled()
  model.add(layers.Conv2D(hidden_sizes[1], 5, **common))
  model.add(layers.MaxPool2D(pool_size=(2, 2)))
  add_batchnorm_if_enabled()
  # Fully-connected head.
  model.add(layers.Flatten())
  model.add(layers.Dense(hidden_sizes[2], **common))
  add_batchnorm_if_enabled()
  model.add(layers.Dense(hidden_sizes[3], **common))
  add_batchnorm_if_enabled()
  # Final logits layer has no activation.
  common['activation'] = None
  model.add(layers.Dense(num_classes, **common))
  return model
@gin.configurable(allowlist=['hidden_sizes', 'use_batch_norm'])
def mlp(input_shape,
        num_classes,
        activation,
        kernel_regularizer,
        use_batch_norm = False,
        hidden_sizes = (300, 100)):
  """Two-hidden-layer MLP: Flatten -> Dense -> Dense -> logits.

  (The docstring previously said "Lenet5 implementation" — a copy-paste
  error from `lenet5` above; this function builds a fully-connected net.)

  Args:
    input_shape: shape of a single input example, used by the Flatten layer.
    num_classes: number of output logits.
    activation: activation used in the hidden Dense layers.
    kernel_regularizer: regularizer applied to each Dense kernel.
    use_batch_norm: if True, a BatchNormalization layer follows each hidden
      Dense layer.
    hidden_sizes: widths of the two hidden layers.

  Returns:
    A `tf.keras.Sequential` model whose final layer outputs raw logits.
  """
  network = tf.keras.Sequential()
  kwargs = {
      'activation': activation,
      'kernel_regularizer': kernel_regularizer
  }
  def maybe_add_batchnorm():
    if use_batch_norm:
      network.add(tf.keras.layers.BatchNormalization())
  network.add(tf.keras.layers.Flatten(input_shape=input_shape))
  network.add(tf.keras.layers.Dense(hidden_sizes[0], **kwargs))
  maybe_add_batchnorm()
  network.add(tf.keras.layers.Dense(hidden_sizes[1], **kwargs))
  maybe_add_batchnorm()
  # The classifier head has no activation: downstream losses expect logits.
  kwargs['activation'] = None
  network.add(tf.keras.layers.Dense(num_classes, **kwargs))
  return network
================================================
FILE: rigl/rigl_tf2/train.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Training script for running experiments.
"""
import os
from typing import List # Non-expensive-to-import types.
from absl import app
from absl import flags
from absl import logging
import gin
import jax
from jax.scipy.linalg import eigh
import numpy as np
from rigl.rigl_tf2 import mask_updaters
from rigl.rigl_tf2 import metainit
from rigl.rigl_tf2 import utils
import tensorflow.compat.v2 as tf
from pyglib import timer
FLAGS = flags.FLAGS
flags.DEFINE_string('logdir', '/tmp/sparse_spectrum',
                    'Directory to save experiment in.')
flags.DEFINE_string('preload_gin_config', '', 'If non-empty reads a gin file '
                    'before parsing gin_config and bindings. This is useful,'
                    'when you want to start from a configuration of another '
                    'run. Values are then overwritten by additional configs '
                    'and bindings provided.')
flags.DEFINE_bool('use_tpu', True, 'Whether to run on TPU or not.')
# BUG FIX: the help string below used to be a copy-paste of the `use_tpu`
# help ("Whether to run on TPU or not."); it now describes the flag.
flags.DEFINE_enum('mode', 'train_eval', ('train_eval', 'hessian'),
                  'Which task to run: train/eval loop, or Hessian spectrum '
                  'calculation on saved checkpoints.')
flags.DEFINE_string(
    'tpu_job_name', 'tpu_worker',
    'Name of the TPU worker job. This is required when having '
    'multiple TPU worker jobs.')
flags.DEFINE_integer('seed', default=0, help=('Sets the random seed.'))
flags.DEFINE_multi_string('gin_config', [],
                          'List of paths to the config files.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Newline separated list of Gin parameter bindings.')
@tf.function
def get_rows(model, variables, masks, ind_l, indices, x_batch, y_batch,
             is_dense_spectrum):
  """Calculates the rows (given by `ind_l`) of the Hessian.

  Args:
    model: Keras model whose training loss defines the Hessian.
    variables: list of parameter tensors (weights and biases of the model).
    masks: list of 0/1 mask tensors aligned one-to-one with `variables`.
    ind_l: index into `variables`/`masks` selecting the variable whose
      gradient entries are differentiated a second time.
    indices: flat indices into the selected variable; one Hessian row is
      produced per index.
    x_batch: batch of inputs used to evaluate the loss.
    y_batch: integer class labels for `x_batch`.
    is_dense_spectrum: if True, keep columns of masked-out weights (zeroed
      via the mask); if False, drop masked-out columns entirely.

  Returns:
    A 2-D tensor of Hessian rows, one per entry of `indices`.
  """
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  # Persistent tape: we differentiate twice (first gradient, then jacobian),
  # and the first-order gradient must itself be recorded on the tape.
  with tf.GradientTape(persistent=True) as tape:
    predictions = model(x_batch, training=True)
    loss = loss_object(y_batch, predictions)
    grads, = tape.gradient(loss, [variables[ind_l]])
    # Since the variables are masked before not during the forward pass,
    # gradients are dense. We need to ensure they are sparse.
    sparse_grads = grads * masks[ind_l]
    single_grad = tf.reshape(sparse_grads, [-1])
    # Keep only the gradient entries whose Hessian rows were requested.
    s_grads = tf.gather(single_grad, indices)
  flattened_list = []
  hessians_slice_vars = tape.jacobian(
      s_grads, variables, experimental_use_pfor=False)
  for h, m in zip(hessians_slice_vars, masks):
    if is_dense_spectrum:
      # We apply the masks since weights are not hard constrained with sparsity.
      vals = tf.reshape(h * m, (h.shape[0], -1))
    else:
      # Drop columns belonging to pruned (mask == 0) weights.
      boolean_mask = tf.broadcast_to(tf.equal(m, 1), h.shape)
      vals = tf.reshape(h[boolean_mask], (h.shape[0], -1))
    flattened_list.append(vals)
  res = tf.concat(flattened_list, 1)
  return res
def sparse_hessian_calculator(model,
                              data,
                              rows_at_once,
                              eigvals_path,
                              overwrite,
                              is_dense_spectrum=False):
  """Calculates the Hessian of the model parameters. Biases are dense.

  Eigenvalues are cached on disk: if `eigvals_path` already exists and
  `overwrite` is False, the stored values are loaded and returned directly.

  Args:
    model: Keras model whose parameterized layers are PruneLowMagnitude
      wrappers (asserted below).
    data: tf.data.Dataset of (x, y) examples used to evaluate the Hessian.
    rows_at_once: number of Hessian rows computed per `get_rows` call;
      trades memory for speed.
    eigvals_path: path where the eigenvalues are cached via `np.save`.
    overwrite: if True, delete any existing result at `eigvals_path` first.
    is_dense_spectrum: if True, masked weights participate (zeroed through
      the masks); otherwise only active (mask == 1) dimensions are used.

  Returns:
    A numpy array of Hessian eigenvalues.
  """
  # Read all data at once
  x_batch, y_batch = list(data.batch(100000))[0]
  if tf.io.gfile.exists(eigvals_path) and overwrite:
    logging.info('Deleting existing Eigvals: %s', eigvals_path)
    tf.io.gfile.rmtree(eigvals_path)
  if tf.io.gfile.exists(eigvals_path):
    with tf.io.gfile.GFile(eigvals_path, 'rb') as f:
      eigvals = np.load(f)
    logging.info('Eigvals exists, skipping :%s', eigvals_path)
    return eigvals
  # First lets create lists that indicate the valid dimension of each variable.
  # If we want to calculate sparse spectrum, then we have to omit masked
  # dimensions. Biases are dense, therefore have masks of 1's.
  masks = []
  variables = []
  layer_group_indices = []
  for l in model.layers:
    if isinstance(l, utils.PRUNING_WRAPPER):
      # TODO following the outcome of b/148083099, update following.
      # Add the weight, mask and the valid dimensions.
      # NOTE(review): assumes the wrapper's weights are ordered
      # [kernel, bias, mask, ...] — confirm against PruneLowMagnitude.
      weight = l.weights[0]
      variables.append(weight)
      mask = l.weights[2]
      masks.append(mask)
      logging.info(mask.shape)
      if is_dense_spectrum:
        # Every dimension participates: use the full index range.
        n_params = tf.size(mask)
        layer_group_indices.append(tf.range(n_params))
      else:
        # Keep only the flat indices of active (mask == 1) weights.
        fmask = tf.reshape(mask, [-1])
        indices = tf.where(tf.equal(fmask, 1))[:, 0]
        layer_group_indices.append(indices)
      # Add the bias mask of ones and all of its dimensions.
      bias = l.weights[1]
      variables.append(bias)
      masks.append(tf.ones_like(bias))
      layer_group_indices.append(tf.range(tf.size(bias)))
    else:
      # For now we assume all parameterized layers are wrapped with
      # PruneLowMagnitude.
      assert not l.trainable_variables
  result_all = []
  init_timer = timer.Timer()
  init_timer.Start()
  n_total = 0
  logging.info('Calculating Hessian...')
  for i, inds in enumerate(layer_group_indices):
    # Compute the rows of each variable group in chunks of `rows_at_once`.
    n_split = np.ceil(tf.size(inds).numpy() / rows_at_once)
    logging.info('Nsplit: %d', n_split)
    for c_slice in np.array_split(inds.numpy(), n_split):
      res = get_rows(model, variables, masks, i, c_slice, x_batch, y_batch,
                     is_dense_spectrum)
      result_all.append(res.numpy())
      n_total += res.shape[0]
      target_n = float(res.shape[1])
      logging.info('%.3f %% ..', (n_total / target_n))
  # We convert in numpy so that it is on cpu automatically and we don't get OOM.
  c_hessian = np.concatenate(result_all, 0)
  logging.info('Total runtime for hessian: %.3f s', init_timer.GetDuration())
  init_timer.Start()
  # Symmetric eigendecomposition on CPU via jax; eigens[0] are eigenvalues.
  eigens = jax.jit(eigh, backend='cpu')(c_hessian)
  eigvals = np.asarray(eigens[0])
  with tf.io.gfile.GFile(eigvals_path, 'wb') as f:
    np.save(f, eigvals)
  logging.info('EigVals saved: %s', eigvals_path)
  logging.info('Total runtime for eigvals: %.3f s', init_timer.GetDuration())
  return eigvals
@gin.configurable(denylist=['model', 'ds_train', 'logdir'])
def hessian(model,
            ds_train,
            logdir,
            ckpt_ids = gin.REQUIRED,
            overwrite = False,
            batch_size = 1000,
            rows_at_once = 10,
            is_dense_spectrum = False):
  """Loads checkpoints under a folder and calculates their hessian spectrum."""
  # The same (deterministic) batch must be used across runs: if the job dies
  # and restarts, results should still be comparable.
  data_hessian = ds_train.take(batch_size)
  for ckpt_id in ckpt_ids:
    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_path = os.path.join(logdir, 'ckpt-%d' % ckpt_id)
    checkpoint.restore(checkpoint_path)
    logging.info('Loaded from: %s', checkpoint_path)
    # Eigenvalues are cached next to the checkpoint they belong to.
    sparse_hessian_calculator(
        model=model,
        data=data_hessian,
        eigvals_path=checkpoint_path + '.eigvals',
        overwrite=overwrite,
        is_dense_spectrum=is_dense_spectrum,
        rows_at_once=rows_at_once)
def update_prune_step(model, step):
  """Sets the pruning step counter of every pruning-wrapped layer."""
  pruned_layers = (
      l for l in model.layers if isinstance(l, utils.PRUNING_WRAPPER))
  for layer in pruned_layers:
    # Assign iteration count to the layer pruning_step.
    layer.pruning_step.assign(step)
def log_sparsities(model):
  """Writes sparsity and threshold summaries for every pruned layer."""
  for layer in model.layers:
    if not isinstance(layer, utils.PRUNING_WRAPPER):
      continue
    for _, mask, threshold in layer.pruning_vars:
      # Sparsity is the fraction of zeros in the (0/1) mask.
      tf.summary.scalar(f'sparsity/{mask.name}', 1 - tf.reduce_mean(mask))
      tf.summary.scalar(f'threshold/{threshold.name}', threshold)
def cosine_distance(x, y):
  """Calculates the distance between 2 tensors of same shape."""
  unit_x = tf.math.l2_normalize(x)
  unit_y = tf.math.l2_normalize(y)
  # 1 - cosine similarity of the flattened, unit-normalized tensors.
  return 1. - tf.reduce_sum(unit_x * unit_y)
def flatten_list_of_vars(var_list):
  """Flattens each variable and concatenates them into one 1-D tensor."""
  return tf.concat([tf.reshape(v, -1) for v in var_list], axis=-1)
def var_to_img(tensor):
  """Reshapes a tensor into a (1, H, W, 1) grayscale image for tf.summary."""
  rank = len(tensor.shape)
  if rank <= 1:
    gray_image = tf.reshape(tensor, [1, -1])  # scalar/vector -> single row
  elif rank == 2:
    gray_image = tensor  # already a 2-D image
  else:
    # Fold all leading dimensions; keep the last one as image width.
    gray_image = tf.reshape(tensor, [-1, tensor.shape[-1]])
  # (H, W) -> (1, H, W, 1)
  return tf.expand_dims(tf.expand_dims(gray_image, 0), -1)
def mask_gradients(model, gradients, variables):
  """Zeroes out gradient entries of pruned weights using the layer masks."""
  grad_by_name = {v.name: g for g, v in zip(gradients, variables)}
  for layer in model.layers:
    if not isinstance(layer, utils.PRUNING_WRAPPER):
      continue
    for weights, mask, _ in layer.pruning_vars:
      if weights.name in grad_by_name:
        # Element-wise mask keeps updates sparse.
        grad_by_name[weights.name] = grad_by_name[weights.name] * mask
  return [grad_by_name[v.name] for v in variables]
@gin.configurable(
    'training', denylist=['model', 'ds_train', 'ds_test', 'logdir'])
def train_model(model,
                ds_train,
                ds_test,
                logdir,
                total_steps = 5000,
                batch_size = 128,
                val_batch_size = 1000,
                save_freq = 5,
                log_freq = 250,
                use_metainit = False,
                oneshot_prune_fraction = 0.,
                gradient_regularization=0):
  """Training of the CNN on MNIST.

  Args:
    model: Keras model, with layers wrapped by PruneLowMagnitude.
    ds_train: training tf.data.Dataset yielding (image, label) pairs.
    ds_test: test dataset, evaluated every `log_freq` steps.
    logdir: directory for summaries and checkpoints.
    total_steps: number of optimizer steps to run.
    batch_size: training batch size.
    val_batch_size: size of each of the two validation batches carved out of
      `ds_train` (one for gradient-norm logging, one for the mask updater).
    save_freq: checkpointing period, in steps.
    log_freq: summary-logging period, in steps.
    use_metainit: if True, re-initializes weights using MetaInit.
    oneshot_prune_fraction: if > 0, prunes this fraction once before training
      (requires a mask updater to be configured).
    gradient_regularization: coefficient of the gradient-norm penalty added
      to the loss; 0 disables it.

  Returns:
    The trained model.
  """
  logging.info('Writing training logs to %s', logdir)
  writer = tf.summary.create_file_writer(os.path.join(logdir, 'train_logs'))
  optimizer = utils.get_optimizer(total_steps)
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  train_batch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_batch_accuracy')
  # Let's create 2 disjoint validation sets.
  (val_x, val_y), (val2_x, val2_y) = [
      d for d in ds_train.take(val_batch_size * 2).batch(val_batch_size)
  ]
  # We use a separate set than the one we are using in our training.
  def loss_fn(x, y):
    """Unreduced (per-example) loss plus regularization, for the updater."""
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    predictions = model(x, training=True)
    reg_loss = tf.add_n(model.losses) if model.losses else 0
    return loss_object(y, predictions) + reg_loss
  mask_updater = mask_updaters.get_mask_updater(model, optimizer, loss_fn)
  if mask_updater:
    mask_updater.set_validation_data(val2_x, val2_y)
  update_prune_step(model, 0)
  if oneshot_prune_fraction > 0:
    logging.info('Running one shot prunning at the beginning.')
    if not mask_updater:
      raise ValueError('mask_updater does not exists. Please set '
                       'mask_updater.update_alg flag for one shot pruning.')
    mask_updater.prune(oneshot_prune_fraction)
  if use_metainit:
    # MetaInit scales inits by the number of *active* (unmasked) parameters.
    n_params = 0
    for layer in model.layers:
      if isinstance(layer, utils.PRUNING_WRAPPER):
        for _, mask, _ in layer.pruning_vars:
          n_params += tf.reduce_sum(mask)
    # NOTE(review): input/output shapes (128, 28, 28, 1)/(128, 10) are
    # hard-coded for MNIST here — confirm before reusing on other data.
    metainit.meta_init(model, loss_object, (128, 28, 28, 1), (128, 10),
                       n_params, mask_gradient_fn=mask_gradients)
  # This is used to calculate some distances, would give incorrect results when
  # we restart the training.
  initial_params = list(map(lambda a: a.numpy(), model.trainable_variables))
  # Create the checkpoint object and restore if there is a checkpoint in the
  # folder.
  ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=ckpt, directory=logdir, max_to_keep=None)
  if ckpt_manager.latest_checkpoint:
    logging.info('Restored from %s', ckpt_manager.latest_checkpoint)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    is_restored = True
  else:
    logging.info('Starting from scratch.')
    is_restored = False
  # Obtain global_step after loading checkpoint.
  global_step = optimizer.iterations
  tf.summary.experimental.set_step(global_step)
  trainable_vars = model.trainable_variables
  def get_gradients(x, y, log_batch_gradient=False, is_regularized=True):
    """Gets sparse gradients and possibly logs some statistics."""
    is_grad_regularized = gradient_regularization != 0
    # Persistent tape only when the gradient-norm penalty needs a second
    # differentiation pass.
    with tf.GradientTape(persistent=is_grad_regularized) as tape:
      predictions = model(x, training=True)
      batch_loss = loss_object(y, predictions)
      if is_regularized and is_grad_regularized:
        # Add ||grad||^2 penalty; computed inside the tape so it is itself
        # differentiated.
        gradients = tape.gradient(batch_loss, trainable_vars)
        gradients = mask_gradients(model, gradients, trainable_vars)
        grad_vec = flatten_list_of_vars(gradients)
        batch_loss += tf.nn.l2_loss(grad_vec) * gradient_regularization
      # Regularization might have been disabled.
      reg_loss = tf.add_n(model.losses) if model.losses else 0
      if is_regularized:
        batch_loss += reg_loss
    gradients = tape.gradient(batch_loss, trainable_vars)
    # Gradients are dense, we should mask them to ensure updates are sparse;
    # So is the norm calculation.
    gradients = mask_gradients(model, gradients, trainable_vars)
    # If batch gradient log it.
    if log_batch_gradient:
      tf.summary.scalar('train_batch_loss', batch_loss)
      tf.summary.scalar('train_batch_reg_loss', reg_loss)
      train_batch_accuracy.update_state(y, predictions)
      tf.summary.scalar('train_batch_accuracy', train_batch_accuracy.result())
      train_batch_accuracy.reset_states()
    return gradients
  def log_fn():
    """Writes evaluation metrics and diagnostics at the current step."""
    logging.info('Logging at iter: %d', global_step.numpy())
    log_sparsities(model)
    test_loss, test_acc = test_model(model, ds_test)
    tf.summary.scalar('test_loss', test_loss)
    tf.summary.scalar('test_acc', test_acc)
    # Log gradient norm.
    # We want to obtain/log gradients without regularization term.
    gradients = get_gradients(val_x, val_y, log_batch_gradient=False,
                              is_regularized=False)
    for var, grad in zip(trainable_vars, gradients):
      tf.summary.scalar(f'gradnorm/{var.name}', tf.norm(grad))
    # Log all gradients together
    all_norm = tf.norm(flatten_list_of_vars(gradients))
    tf.summary.scalar('.allparams/gradnorm', all_norm)
    # Log momentum values:
    for s_name in optimizer.get_slot_names():
      # Currently we only log momentum.
      if s_name not in ['momentum']:
        continue
      all_slots = [optimizer.get_slot(var, s_name) for var in trainable_vars]
      all_norm = tf.norm(flatten_list_of_vars(all_slots))
      tf.summary.scalar(f'.allparams/norm_{s_name}', all_norm)
    # Log distance to init.
    for initial_val, val in zip(initial_params, model.trainable_variables):
      tf.summary.scalar(f'dist_init_l2/{val.name}', tf.norm(initial_val - val))
      cos_distance = cosine_distance(initial_val, val)
      tf.summary.scalar(f'dist_init_cosine/{val.name}', cos_distance)
    # Mask update logs:
    if mask_updater:
      tf.summary.scalar('drop_fraction', mask_updater.last_drop_fraction)
    # Log all distances together.
    flat_initial = flatten_list_of_vars(initial_params)
    flat_current = flatten_list_of_vars(model.trainable_variables)
    tf.summary.scalar('.allparams/dist_init_l2/',
                      tf.norm(flat_initial - flat_current))
    tf.summary.scalar('.allparams/dist_init_cosine/',
                      cosine_distance(flat_initial, flat_current))
    # Log masks
    for layer in model.layers:
      if isinstance(layer, utils.PRUNING_WRAPPER):
        for _, mask, _ in layer.pruning_vars:
          tf.summary.image('mask/%s' % mask.name, var_to_img(mask))
    writer.flush()
  def save_fn(step=None):
    """Saves a checkpoint under `step` (defaults to the global step)."""
    save_step = step if step else global_step
    saved_ckpt = ckpt_manager.save(checkpoint_number=save_step)
    logging.info('Saved checkpoint: %s', saved_ckpt)
  with writer.as_default():
    for x, y in ds_train.repeat().shuffle(
        buffer_size=60000).batch(batch_size):
      if global_step >= total_steps:
        logging.info('Total steps: %d is completed', global_step.numpy())
        save_fn()
        break
      update_prune_step(model, global_step)
      if tf.equal(global_step, 0):
        logging.info('Seed: %s First 10 Label: %s', FLAGS.seed, y[:10])
      if global_step % save_freq == 0:
        # If just loaded, don't save it again.
        if is_restored:
          is_restored = False
        else:
          save_fn()
      if global_step % log_freq == 0:
        log_fn()
      gradients = get_gradients(x, y, log_batch_gradient=True)
      tf.summary.scalar('lr', optimizer.lr(global_step))
      # `global_step` (optimizer.iterations) is incremented by this call.
      optimizer.apply_gradients(zip(gradients, trainable_vars))
      if mask_updater and mask_updater.is_update_iter(global_step):
        # Save the network before mask_update, we want to use negative integers
        # for this.
        save_fn(step=(-global_step + 1))
        # Gradient norm before.
        gradients = get_gradients(
            val_x, val_y, log_batch_gradient=False, is_regularized=False)
        norm_before = tf.norm(flatten_list_of_vars(gradients))
        results = mask_updater.update(global_step)
        # Save network again
        save_fn(step=-global_step)
        if results:
          for mask_name, drop_frac in results.items():
            tf.summary.scalar('drop_fraction/%s' % mask_name, drop_frac)
        # Gradient norm after mask update.
        gradients = get_gradients(
            val_x, val_y, log_batch_gradient=False, is_regularized=False)
        norm_after = tf.norm(flatten_list_of_vars(gradients))
        tf.summary.scalar('.allparams/gradnorm_mask_update_improvment',
                          norm_after - norm_before)
  logging.info('Performance after training:')
  log_fn()
  return model
def test_model(model, d_test, batch_size=1000):
  """Tests the model and calculates cross entropy loss and accuracy."""
  mean_loss = tf.keras.metrics.Mean(name='test_loss')
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  xent = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  # Accumulate metrics over the whole test set, one batch at a time.
  for inputs, labels in d_test.batch(batch_size):
    logits = model(inputs, training=False)
    mean_loss.update_state(xent(labels, logits))
    accuracy.update_state(labels, logits)
  logging.info('Test loss: %f', mean_loss.result().numpy())
  logging.info('Test accuracy: %f', accuracy.result().numpy())
  return mean_loss.result(), accuracy.result()
def main(unused_argv):
  """Parses gin configs, builds the model and dispatches on FLAGS.mode."""
  tf.random.set_seed(FLAGS.seed)
  init_timer = timer.Timer()
  init_timer.Start()
  if FLAGS.mode == 'hessian':
    # Load default values from the original experiment.
    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,
                                            'operative_config.gin')
  # Maybe preload a gin config.
  if FLAGS.preload_gin_config:
    config_path = FLAGS.preload_gin_config
    gin.parse_config_file(config_path)
    logging.info('Gin configuration pre-loaded from: %s', config_path)
  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  ds_train, ds_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)
  pruning_params = utils.get_pruning_params()
  model = utils.get_network(pruning_params, input_shape, num_classes)
  model.summary(print_fn=logging.info)
  if FLAGS.mode == 'train_eval':
    train_model(model, ds_train, ds_test, FLAGS.logdir)
  elif FLAGS.mode == 'hessian':
    test_model(model, ds_test)
    hessian(model, ds_train, FLAGS.logdir)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())
  # BUG FIX: the conditional previously parsed as
  #   'hessian_' if mode == 'hessian' else ('' + 'operative_config.gin')
  # so in hessian mode the file was written as just 'hessian_' and the
  # 'operative_config.gin' suffix was lost. Parenthesize the prefix.
  prefix = 'hessian_' if FLAGS.mode == 'hessian' else ''
  logconfigfile_path = os.path.join(FLAGS.logdir,
                                    prefix + 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
if __name__ == '__main__':
  # TF2 (eager) behavior must be enabled before any TF code runs.
  tf.enable_v2_behavior()
  app.run(main)
================================================
FILE: rigl/rigl_tf2/utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for training.
"""
import functools
from typing import Optional, Tuple
from absl import flags
from absl import logging
import gin
from rigl.rigl_tf2 import init_utils
from rigl.rigl_tf2 import networks
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
FLAGS = flags.FLAGS
# Keras wrapper that adds magnitude-based pruning (mask/threshold vars).
PRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude
# Layer types that are wrapped with the pruning wrapper by default.
PRUNED_LAYER_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.Dense)
@gin.configurable('data')
def get_dataset():
  """Loads MNIST and returns (ds_train, ds_test, info).

  Both splits are cached and preprocessed to (float image in [0, 1], label).

  Returns:
    Tuple of (train dataset, test dataset, tfds DatasetInfo).
  """
  # the data, shuffled and split between train and test sets.
  datasets, info = tfds.load('mnist', with_info=True)
  ds_train, ds_test = datasets['train'].cache(), datasets['test'].cache()
  preprocess_fn = lambda x: (tf.cast(x['image'], tf.float32) / 255., x['label'])
  ds_train = ds_train.map(preprocess_fn)
  # BUG FIX: the test split was loaded a second time via
  # tfds.load('mnist', split='test'), discarding the cached `datasets['test']`
  # that was already in hand. Reuse the existing split instead.
  ds_test = ds_test.map(preprocess_fn)
  return ds_train, ds_test, info
@gin.configurable('pruning')
def get_pruning_params(mode='prune',
                       initial_sparsity=0.0,
                       final_sparsity=0.8,
                       begin_step=2000,
                       end_step=4000,
                       frequency=200):
  """Gets pruning hyper-parameters."""
  if mode == 'prune':
    # Sparsity ramps polynomially from initial to final between begin/end.
    schedule = pruning_schedule.PolynomialDecay(
        initial_sparsity=initial_sparsity,
        final_sparsity=final_sparsity,
        begin_step=begin_step,
        end_step=end_step,
        frequency=frequency)
  elif mode == 'constant':
    # Fixed sparsity level, applied from `begin_step` onwards.
    schedule = pruning_schedule.ConstantSparsity(
        target_sparsity=final_sparsity, begin_step=begin_step)
  else:
    raise ValueError('Mode: %s, is not valid' % mode)
  return {'pruning_schedule': schedule}
# Forked from tensorflow_model_optimization/python/core/sparsity/keras/prune.py
def maybe_prune_layer(layer, params, filter_fn):
  """Wraps `layer` with the pruning wrapper when `filter_fn(layer)` is true."""
  if not filter_fn(layer):
    return layer
  return PRUNING_WRAPPER(layer, **params)
@gin.configurable('network')
def get_network(
    pruning_params,
    input_shape,
    num_classes,
    activation = 'relu',
    network_name = 'lenet5',
    mask_init_path = None,
    shuffle_mask = False,
    weight_init_path = None,
    weight_init_method = None,
    weight_decay = 0.,
    noise_stddev = 0.,
    pruned_layer_types = PRUNED_LAYER_TYPES):
  """Creates the network.

  Builds a model from `networks`, wraps its Conv2D/Dense layers with the
  pruning wrapper, then optionally (a) loads masks from a checkpoint,
  (b) loads or re-initializes weights, and (c) perturbs the init with noise.

  Args:
    pruning_params: kwargs forwarded to the PruneLowMagnitude wrapper.
    input_shape: shape of a single input example.
    num_classes: number of output classes.
    activation: activation for the hidden layers.
    network_name: name of a constructor in the `networks` module.
    mask_init_path: optional checkpoint path to copy masks from.
    shuffle_mask: if True, randomly shuffles each loaded mask (keeps the
      per-layer sparsity level, destroys the structure).
    weight_init_path: optional checkpoint path to copy weights from; takes
      precedence over `weight_init_method`.
    weight_init_method: 'unit_scaled' or 'layer_scaled' (see init_utils),
      used only when `weight_init_path` is not given.
    weight_decay: l2 regularization coefficient; 0 disables the regularizer.
    noise_stddev: if > 0, adds Gaussian noise to every trainable variable.
    pruned_layer_types: layer classes that get wrapped for pruning.

  Returns:
    The wrapped (and possibly re-initialized) Keras model.
  """
  kernel_regularizer = (
      tf.keras.regularizers.l2(weight_decay) if (weight_decay > 0) else None)
  # (1) Create keras model.
  model = getattr(networks, network_name)(
      input_shape, num_classes, activation=activation,
      kernel_regularizer=kernel_regularizer)
  model.summary(print_fn=logging.info)
  # (2) Adding wrappers. i.e. sparsify if conv or dense.
  filter_fn = lambda layer: isinstance(layer, pruned_layer_types)
  clone_fn = functools.partial(maybe_prune_layer,
                               params=pruning_params,
                               filter_fn=filter_fn)
  model = tf.keras.models.clone_model(model, clone_function=clone_fn)
  # (3) Update parameters of the model as necessary.
  if mask_init_path:
    logging.info('Loading masks from: %s', mask_init_path)
    # Restore into a throwaway clone so only the masks are copied over.
    mask_init_model = tf.keras.models.clone_model(model)
    ckpt = tf.train.Checkpoint(model=mask_init_model)
    ckpt.restore(mask_init_path)
    for l_source, l_target in zip(mask_init_model.layers, model.layers):
      if isinstance(l_source, PRUNING_WRAPPER):
        # l.pruning_vars[0][1] is the mask.
        mask = l_target.pruning_vars[0][1]
        n_active = tf.reduce_sum(mask)
        n_dense = tf.cast(tf.size(mask), dtype=n_active.dtype)
        # Log density (fraction of active weights) before and after the copy.
        logging.info('Before: %s, %.2f', l_target.name,
                     (n_active / n_dense).numpy())
        loaded_mask = l_source.pruning_vars[0][1]
        if shuffle_mask:
          # tf shuffle shuffles along the first dim, so we need to flatten.
          loaded_mask = tf.reshape(
              tf.random.shuffle(tf.reshape(loaded_mask, -1)), loaded_mask.shape)
        mask.assign(loaded_mask)
        n_active = tf.reduce_sum(mask)
        n_dense = tf.cast(tf.size(mask), dtype=n_active.dtype)
        logging.info('After: %s, %.2f', l_target.name,
                     (n_active / n_dense).numpy())
    del mask_init_model
  if weight_init_path:
    logging.info('Loading weights from: %s', weight_init_path)
    # Same throwaway-clone trick: copy only trainable variables.
    weight_init_model = tf.keras.models.clone_model(model)
    ckpt = tf.train.Checkpoint(model=weight_init_model)
    ckpt.restore(weight_init_path)
    for l_source, l_target in zip(weight_init_model.layers, model.layers):
      for var_source, var_target in zip(l_source.trainable_variables,
                                        l_target.trainable_variables):
        var_target.assign(var_source)
        logging.info('Weight %s loaded from ckpt.', var_target.name)
    del weight_init_model
  elif weight_init_method == 'unit_scaled':
    logging.info('Using unit_scaled initialization.')
    for layer in model.layers:
      if isinstance(layer, PRUNING_WRAPPER):
        # TODO following the outcome of b/148083099, update following.
        # Add the weight, mask and the valid dimensions.
        # NOTE(review): assumes wrapper weight order [kernel, bias, mask].
        weight = layer.weights[0]
        mask = layer.weights[2]
        new_init = init_utils.unit_scaled_init(mask)
        weight.assign(new_init)
        logging.info('Weight %s updated init.', weight.name)
  elif weight_init_method == 'layer_scaled':
    logging.info('Using layer_scaled initialization.')
    for layer in model.layers:
      if isinstance(layer, PRUNING_WRAPPER):
        # TODO following the outcome of b/148083099, update following.
        # Add the weight, mask and the valid dimensions.
        weight = layer.weights[0]
        mask = layer.weights[2]
        new_init = init_utils.layer_scaled_init(mask)
        weight.assign(new_init)
        logging.info('Weight %s updated init.', weight.name)
  if noise_stddev > 0.:
    logging.info('Adding noise to the initial point')
    for layer in model.layers:
      for var in layer.trainable_variables:
        noise = tf.random.normal(var.shape, mean=0, stddev=noise_stddev)
        var.assign_add(noise)
  # Do this call to mask the weights with existing masks if it is not done
  # already. This is needed for example when we use initial parameters to cal-
  # culate distance.
  model(tf.expand_dims(tf.ones(input_shape), 0))
  return model
@gin.configurable('optimizer', denylist=['total_steps'])
def get_optimizer(total_steps,
                  name = 'adam',
                  learning_rate = 0.001,
                  clipnorm = None,
                  clipvalue = None,
                  momentum = None):
  """Creates the optimizer according to the arguments."""
  name = name.lower()
  # We use cosine decay.
  schedule = tf.keras.experimental.CosineDecay(learning_rate, total_steps)
  clip_kwargs = {}
  if clipnorm:
    # Not correct implementation, see http://b/152868229 .
    clip_kwargs['clipnorm'] = clipnorm
  if clipvalue:
    clip_kwargs['clipvalue'] = clipvalue
  if name == 'adam':
    return tf.keras.optimizers.Adam(schedule, **clip_kwargs)
  elif name == 'momentum':
    return tf.keras.optimizers.SGD(schedule, momentum=momentum, **clip_kwargs)
  elif name == 'sgd':
    return tf.keras.optimizers.SGD(schedule, **clip_kwargs)
  elif name == 'rmsprop':
    return tf.keras.optimizers.RMSprop(
        schedule, momentum=momentum, **clip_kwargs)
  raise NotImplementedError(f'Optimizers {name} not implemented.')
================================================
FILE: rigl/rl/README.md
================================================
# The State of Sparse Training in Deep Reinforcement Learning
[**Paper**] [goo.gle/sparserl-paper](https://goo.gle/sparserl-paper)
[**Video**] [goo.gle/sparserl-video](https://goo.gle/sparserl-video)
This code requires Tensorflow 2.0; therefore we need to use a separate
requirements file. Please follow the instructions below:
First clone this repo.
```bash
git clone https://github.com/google-research/rigl.git
cd rigl
```
We use [Neurips 2019 MicroNet Challenge](https://micronet-challenge.github.io/)
code for counting operations and size of our networks. Let's clone the
google_research repo and add current folder to the python path.
```bash
git clone https://github.com/google-research/google-research.git
mv google-research/ google_research/
export PYTHONPATH=$PYTHONPATH:$PWD
```
Now we can run some tests. The following script creates a virtual environment
and installs the necessary libraries. Finally, it runs a few tests.
```bash
virtualenv -p python3 env_sparserl
source env_sparserl/bin/activate
pip install -r rigl/rl/requirements.txt
python -m rigl.sparse_utils_test
```
Follow instructions here to install MuJoCo: https://github.com/openai/mujoco-py#install-mujoco
To run PPO:
```
python3 rigl/rl/tfagents/ppo_train_eval.py \
--gin_file=rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin \
--root_dir=/tmp/sparserl/ --is_mujoco=True
```
To run SAC:
```
python3 rigl/rl/tfagents/sac_train_eval.py \
--gin_file=rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin \
--root_dir=/tmp/sparserl/ --is_mujoco=True
```
**Citation**:
```
@InProceedings{graesser22a,
title = {The State of Sparse Training in Deep Reinforcement Learning},
author = {Graesser, Laura and Evci, Utku and Elsen, Erich and Castro, Pablo Samuel},
booktitle = {Proceedings of the 39th International Conference on Machine Learning},
pages = {7766--7792},
year = {2022},
editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
volume = {162},
series = {Proceedings of Machine Learning Research},
month = {17--23 Jul},
publisher = {PMLR},
pdf = {https://proceedings.mlr.press/v162/graesser22a/graesser22a.pdf},
url = {https://proceedings.mlr.press/v162/graesser22a.html},
}
```
================================================
FILE: rigl/rl/dqn_agents.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Variants of DQN with sparsity."""
import functools
import math
from absl import logging
from dopamine.agents.dqn import dqn_agent
from dopamine.discrete_domains import atari_lib
import gin
from rigl.rl import sparse_utils
import tensorflow as tf
import tensorflow.compat.v1 as tf1
# Valid training modes: 'dense' applies no sparsity modification; 'prune'
# prunes the agent during/after training (dense-to-sparse); 'rigl', 'static'
# and 'set' run the corresponding sparse-to-sparse training algorithm.
LEARNER_MODES = ('dense', 'prune', 'rigl', 'static', 'set')
def flatten_list_of_vars(var_list):
  """Flattens each variable and concatenates them into one 1-D tensor."""
  return tf.concat([tf.reshape(var, [-1]) for var in var_list], axis=-1)
def _get_bn_layer_name(block_id, i):
return f'batch_norm_{block_id},{i}'
def _get_conv_layer_name(block_id, i):
return f'conv_{block_id},{i}'
class _Stack(tf.keras.Model):
  """Stack of pooling and convolutional blocks with residual connections.

  Structure: downscale Conv2D (+ optional max-pool and batch-norm), followed
  by `num_blocks` residual blocks of two relu->Conv2D(->BatchNorm) branches.
  The per-block conv/batch-norm layers are registered as attributes via
  `setattr` under names produced by `_get_conv_layer_name` /
  `_get_bn_layer_name`, and retrieved with `getattr` in `call`.
  """
  def __init__(self,
               num_ch,
               num_blocks,
               use_max_pooling=True,
               use_batch_norm=False,
               name='stack'):
    """Builds the stack.

    Args:
      num_ch: number of channels of every Conv2D in the stack.
      num_blocks: number of residual blocks.
      use_max_pooling: if True, max-pools (stride 2) after the first conv.
      use_batch_norm: if True, adds BatchNormalization after the first conv
        and after each residual conv.
      name: Keras model name.
    """
    super(_Stack, self).__init__(name=name)
    self._conv = tf.keras.layers.Conv2D(num_ch, 3, strides=1, padding='same')
    self.use_max_pooling = use_max_pooling
    self.use_batch_norm = use_batch_norm
    self.num_blocks = num_blocks
    if self.use_batch_norm:
      self._batch_norm = tf.keras.layers.BatchNormalization()
    if self.use_max_pooling:
      self._max_pool = tf.keras.layers.MaxPool2D(
          pool_size=3, padding='same', strides=2)
    for block_id in range(num_blocks):
      for i in range(2):
        # Register each residual conv (and optional batch-norm) as an
        # attribute so Keras tracks its variables.
        name = _get_conv_layer_name(block_id, i)
        layer = tf.keras.layers.Conv2D(
            num_ch, 3, strides=1, padding='same',
            name=f'res_{block_id}/conv2d_{i}')
        setattr(self, name, layer)
        if self.use_batch_norm:
          name = _get_bn_layer_name(block_id, i)
          setattr(self, name, tf.keras.layers.BatchNormalization())
  def call(self, conv_out, training=False):
    """Applies the downscaling stage and then the residual blocks."""
    # Downscale.
    conv_out = self._conv(conv_out)
    if self.use_max_pooling:
      conv_out = self._max_pool(conv_out)
    if self.use_batch_norm:
      conv_out = self._batch_norm(conv_out, training=training)
    # Residual block(s).
    for block_id in range(self.num_blocks):
      block_input = conv_out
      for i in range(2):
        conv_out = tf.nn.relu(conv_out)
        conv_layer = getattr(self, _get_conv_layer_name(block_id, i))
        conv_out = conv_layer(conv_out)
        if self.use_batch_norm:
          bn_layer = getattr(self, _get_bn_layer_name(block_id, i))
          conv_out = bn_layer(conv_out, training=training)
      # Skip connection around the two-conv branch.
      conv_out += block_input
    return conv_out
@gin.configurable
class ImpalaNetwork(tf.keras.Model):
  """Agent with ResNet, but without LSTM and additional inputs.

  The deep model used for DQN which follows
  "IMPALA: Scalable Distributed Deep-RL with Importance Weighted
  Actor-Learner Architectures" by Espeholt, Soyer, Munos et al.

  Original implementation by Rishabh Agarwal, with minor modifications as
  follows:

  * rename nn_scale to width to fit with the sparserl API
  * allow for non-integer widths.
  * add training mode.
  * removed the option to have multiple heads.
  * modified the call function to return a compatible type.
  * added custom logic for sparse training.
  """

  def __init__(self,
               num_actions,
               width=1.0,
               mode='dense',
               name='impala_deep_network',
               prune_allow_key='',
               use_batch_norm=False):
    """Builds the network and (optionally) wraps its layers for pruning.

    Args:
      num_actions: int, number of actions (size of the output layer).
      width: float, uniformly scales the network width; channel counts are
        rounded up with ceil so they never reach zero.
      mode: str, one of LEARNER_MODES; any value other than 'dense' wraps
        the conv/dense layers with the pruning wrapper.
      name: str, keras name of the model.
      prune_allow_key: str; when mode == 'prune' and non-empty, only layers
        whose name contains this key receive a non-zero sparsity.
      use_batch_norm: bool, whether the conv stacks use batch normalization.
    """
    super().__init__(name=name)
    self._width = width
    self._mode = mode

    def _scale_width(n):
      # Ceil keeps at least one channel/unit for any positive width.
      return int(math.ceil(n * width))

    self.num_actions = num_actions
    self.use_batch_norm = use_batch_norm
    logging.info('Using batch norm in %s: %s', name, use_batch_norm)
    stack_fn = functools.partial(_Stack, use_batch_norm=use_batch_norm)
    # Parameters and layers for _torso.
    self._stacks = [
        stack_fn(_scale_width(32), 2, name='stack1'),
        stack_fn(_scale_width(64), 2, name='stack2'),
        stack_fn(_scale_width(64), 2, name='stack3'),
    ]
    self._dense1 = tf.keras.layers.Dense(_scale_width(256))
    self._dense2 = tf.keras.layers.Dense(
        self.num_actions, name='policy_logits')
    # These shapes are only used to compute per-layer sparsities below; they
    # are never used to build layers.
    # NOTE(review): `_dense1` is built with _scale_width(256) units but is
    # listed here as (7744, 512) (and `_dense2` as (512, num_actions)),
    # which does not track `width` — confirm these shapes are intended.
    layer_shape_dict = {
        '_dense1': (7744, 512),
        '_dense2': (512, self.num_actions),
    }

    def add_stack_shapes(name, in_width, out_width):
      # First conv
      layer_shape_dict[f'{name}/_conv'] = (3, 3, in_width, out_width)
      for i in range(2):
        for j in range(2):
          l_name = _get_conv_layer_name(i, j)
          layer_shape_dict[f'{name}/{l_name}'] = (3, 3, out_width, out_width)

    # The 'stackN' prefix below is parsed as an index into self._stacks
    # (stack0 -> self._stacks[0]); it intentionally starts at 0 even though
    # the keras names of the stacks are 'stack1'..'stack3'.  The input has 4
    # channels — presumably stacked Atari frames; TODO confirm.
    add_stack_shapes('stack0', 4, _scale_width(32))
    add_stack_shapes('stack1', _scale_width(32), _scale_width(64))
    add_stack_shapes('stack2', _scale_width(64), _scale_width(64))
    if mode != 'dense':
      custom_sparsities = sparse_utils.get_pruning_sparsities(layer_shape_dict)
      for l_name, sparsity in custom_sparsities.items():
        logging.info('pruning, layer: %s, sparsity: %.4f', l_name, sparsity)
        if l_name.startswith('stack'):
          # stack1 -> 1
          stack_id = int(l_name[len('stack')])
          c_module = self._stacks[stack_id]
          # `stack1/_conv` -> `_conv`
          l_name = l_name.split('/')[1]
        else:
          c_module = self
        if mode == 'prune':
          if prune_allow_key and (prune_allow_key not in l_name):
            # Exempt this layer from pruning by forcing zero sparsity.
            sparsity = 0
            logging.info('%s not pruned since, prune_allow_key: %s', l_name,
                         prune_allow_key)
          wrapped_layer = sparse_utils.maybe_prune_layer(
              getattr(c_module, l_name),
              params=sparse_utils.get_pruning_params(
                  mode, final_sparsity=sparsity))
        else:
          # Sparse-to-sparse modes (rigl/set/static): masks are managed by
          # the sparse optimizer, so no pruning schedule is passed here.
          wrapped_layer = sparse_utils.maybe_prune_layer(
              getattr(c_module, l_name),
              params=sparse_utils.get_pruning_params(mode))
        # Replace the layer attribute in place with its wrapped version.
        setattr(c_module, l_name, wrapped_layer)

  def get_features(self, state, training=True):
    """Runs the full network on `state`.

    Note: despite the name, this applies both dense layers and therefore
    returns the final per-action output, not intermediate features.

    Args:
      state: tensor of stacked frames; cast to float32 and scaled by 1/255.
      training: bool, forwarded to batch-norm layers inside the stacks.

    Returns:
      Tensor of per-action values, shape (batch, num_actions).
    """
    x = tf.cast(state, tf.float32)
    x /= 255
    conv_out = x
    for stack in self._stacks:
      conv_out = stack(conv_out, training=training)
    conv_out = tf.nn.relu(conv_out)
    conv_out = tf.keras.layers.Flatten()(conv_out)
    out = self._dense1(conv_out)
    out = tf.nn.relu(out)
    out = self._dense2(out)
    return out

  def call(self, state, training=True):
    """Returns atari_lib.DQNNetworkType wrapping the Q-values for `state`."""
    out = self.get_features(state, training=training)
    return atari_lib.DQNNetworkType(out)
@gin.configurable
class NatureDQNNetwork(tf.keras.Model):
  """The convolutional network used to compute the agent's Q-values.

  Three convolutions followed by two dense layers, optionally width-scaled,
  and optionally wrapped for pruning / sparse-to-sparse training depending
  on `mode`.
  """

  def __init__(self, num_actions, width=1, mode='dense', name=None):
    """Creates the layers used for calculating Q-values.

    Args:
      num_actions: int, number of actions.
      width: float, Scales the width of the network uniformly.
      mode: str, one of LEARNER_MODES.
      name: str, used to create scope for network parameters.
    """
    super().__init__(name=name)
    self.num_actions = num_actions
    self._width = width
    self._mode = mode

    def _scale_width(n):
      # Ceil keeps at least one unit/filter for any positive width.
      return int(math.ceil(n * width))

    # Defining layers.
    activation_fn = tf.keras.activations.relu
    # Setting names of the layers manually to make variable names more similar
    # with tf.slim variable names/checkpoints.
    self.conv1 = tf.keras.layers.Conv2D(
        _scale_width(32), [8, 8],
        strides=4,
        padding='same',
        activation=activation_fn,
        name='Conv')
    self.conv2 = tf.keras.layers.Conv2D(
        _scale_width(64), [4, 4],
        strides=2,
        padding='same',
        activation=activation_fn,
        name='Conv')
    self.conv3 = tf.keras.layers.Conv2D(
        _scale_width(64), [3, 3],
        strides=1,
        padding='same',
        activation=activation_fn,
        name='Conv')
    self.flatten = tf.keras.layers.Flatten()
    self.dense1 = tf.keras.layers.Dense(
        _scale_width(512), activation=activation_fn,
        name='fully_connected')
    self.dense2 = tf.keras.layers.Dense(num_actions, name='fully_connected')
    # Shapes used only to compute per-layer sparsities in 'prune' mode.
    # NOTE(review): the conv entries are ordered (out_ch, kh, kw, in_ch),
    # unlike the (kh, kw, in_ch, out_ch) ordering used by ImpalaNetwork and
    # by TF kernel shapes; 7744 is a hard-coded flattened conv output size
    # (presumably for 84x84 Atari frames) — confirm both are intended.
    layer_shape_dict = {
        'conv1': (_scale_width(32), 8, 8, 4),
        'conv2': (_scale_width(64), 4, 4, _scale_width(32)),
        'conv3': (_scale_width(64), 3, 3, _scale_width(64)),
        'dense1': (7744, _scale_width(512)),
        'dense2': (_scale_width(512), num_actions)
    }
    if mode == 'dense':
      # No wrapping: train the dense network as-is.
      pass
    elif mode == 'prune':
      # Gradual magnitude pruning: each layer gets its own target sparsity.
      custom_sparsities = sparse_utils.get_pruning_sparsities(layer_shape_dict)
      for l_name, sparsity in custom_sparsities.items():
        logging.info('pruning, layer: %s, sparsity: %.4f', l_name, sparsity)
        wrapped_layer = sparse_utils.maybe_prune_layer(
            getattr(self, l_name),
            params=sparse_utils.get_pruning_params(
                mode, final_sparsity=sparsity))
        setattr(self, l_name, wrapped_layer)
    else:
      # static, rigl, set.
      # Masks are managed by the sparse optimizer, so the layers are wrapped
      # with a schedule that never fires (see get_pruning_params).
      for l_name in layer_shape_dict:
        wrapped_layer = sparse_utils.maybe_prune_layer(
            getattr(self, l_name),
            params=sparse_utils.get_pruning_params(mode))
        setattr(self, l_name, wrapped_layer)

  def call(self, state):
    """Creates the output tensor/op given the state tensor as input.

    See https://www.tensorflow.org/api_docs/python/tf/keras/Model for more
    information on this. Note that tf.keras.Model implements `call` which is
    wrapped by `__call__` function by tf.keras.Model.

    Parameters created here will have scope according to the `name` argument
    given at `.__init__()` call.

    Args:
      state: Tensor, input tensor.

    Returns:
      collections.namedtuple, output ops (graph mode) or output tensors (eager).
    """
    x = tf.cast(state, tf.float32)
    x = x / 255
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.flatten(x)
    x = self.dense1(x)
    return atari_lib.DQNNetworkType(self.dense2(x))
@gin.configurable
class SparseDQNAgent(dqn_agent.DQNAgent):
  """A variant of DQN that is trained with sparse backbones.

  Depending on `mode`, the Q-network's conv/dense layers are pruned on a
  schedule ('prune'), trained with a sparse-to-sparse method ('rigl', 'set',
  'static'), or left unchanged ('dense').  The agent adds the bookkeeping
  ops (pruning-step updates, mask updates/initialization) that those modes
  require on top of dopamine's DQNAgent.
  """

  def __init__(self,
               sess,
               num_actions,
               mode='dense',
               weight_decay=0.,
               summary_writer=None):
    """Initializes the agent and constructs graph components.

    Args:
      sess: tf.Session, for executing ops.
      num_actions: int, number of actions the agent can take at any state.
      mode: str, one of LEARNER_MODES.
      weight_decay: float, used to regularize online_convnet.
      summary_writer: tf.SummaryWriter, for Tensorboard.

    Raises:
      ValueError: if `mode` is not one of LEARNER_MODES.
    """
    self._weight_decay = weight_decay
    if mode in LEARNER_MODES:
      self._mode = mode
    else:
      raise ValueError(f'mode:{mode} not one of {LEARNER_MODES}')
    # Shared global step; also used to drive the pruning schedules.
    self._global_step = tf1.train.get_or_create_global_step()
    # update_period=1, we always update as the supervisor is fixed.
    super().__init__(
        sess, num_actions, summary_writer=summary_writer)

  def _create_network(self, name):
    """Builds a Q-network configured with this agent's sparsity mode.

    Args:
      name: str, prefix used for the network's name ('Online'/'Target').

    Returns:
      The instantiated `self.network` keras model.
    """
    network = self.network(
        self.num_actions,
        name=name + 'learner',
        mode=self._mode)
    return network

  def _set_additional_ops(self):
    """Creates the step-update, mask-update and mask-init ops for `mode`.

    'dense' uses no-ops; sparse-to-sparse modes additionally wrap
    self.optimizer with the matching sparse-training optimizer; 'prune'
    relies on the pruning wrapper's schedule, so no mask init is needed.

    Raises:
      ValueError: if self._mode is not a known mode.
    """
    if self._mode == 'dense':
      self.step_update_op = tf.no_op()
      self.mask_update_op = tf.no_op()
      self.mask_init_op = tf.no_op()
    elif self._mode in ['rigl', 'set', 'static']:
      self.step_update_op = sparse_utils.update_prune_step(
          self.online_convnet, self._global_step)
      # This ensures sparse masks are applied before each run.
      self.mask_update_op = sparse_utils.update_prune_masks(self.online_convnet)
      self.mask_init_op = sparse_utils.init_masks(self.online_convnet)
      # Wrap the optimizer.
      if self._mode == 'rigl':
        self.optimizer = sparse_utils.UpdatedRigLOptimizer(self.optimizer)
        self.optimizer.set_model(self.online_convnet)
      elif self._mode == 'set':
        self.optimizer = sparse_utils.UpdatedSETOptimizer(self.optimizer)
        self.optimizer.set_model(self.online_convnet)
    elif self._mode == 'prune':
      self.step_update_op = sparse_utils.update_prune_step(
          self.online_convnet, self._global_step)
      self.mask_update_op = sparse_utils.update_prune_masks(self.online_convnet)
      self.mask_init_op = tf.no_op()
    else:
      raise ValueError(f'Invalid mode: {self._mode}')

  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')
    # Q-value of the action actually taken in each replayed transition.
    replay_chosen_q = tf.reduce_sum(
        self._replay_net_outputs.q_values * replay_action_one_hot,
        axis=1,
        name='replay_chosen_q')
    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf1.losses.huber_loss(
        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)
    loss = tf.reduce_mean(loss)
    if self.summary_writer is not None:
      tf1.summary.scalar('Losses/HuberLoss', loss)
    # L2 regularization over all non-bias variables; summaries below are
    # added even when weight_decay is 0 (reg_loss stays the float 0.).
    reg_loss = 0.
    if self._weight_decay:
      for v in self.online_convnet.trainable_variables:
        if 'bias' not in v.name:
          reg_loss += tf.nn.l2_loss(v) * self._weight_decay
      loss += reg_loss
    tf1.summary.scalar('Losses/RegLoss', reg_loss)
    tf1.summary.scalar('Losses/TotalLoss', loss)
    sparse_utils.log_sparsities(self.online_convnet)
    # Creates step/mask ops and possibly wraps self.optimizer, so this must
    # run before compute_gradients below.
    self._set_additional_ops()
    grads_and_vars = self.optimizer.compute_gradients(loss)
    # apply_gradients also increments self._global_step.
    train_op = self.optimizer.apply_gradients(
        grads_and_vars, global_step=self._global_step)
    self._create_summary_ops(grads_and_vars)
    return train_op

  def _create_summary_ops(self, grads_and_vars):
    """Adds summaries: weight/grad norms and active parameter counts.

    Args:
      grads_and_vars: list of (gradient, variable) pairs from the optimizer.
    """
    with tf1.variable_scope('Norm'):
      all_norm = tf.norm(
          flatten_list_of_vars(self.online_convnet.trainable_variables))
      tf1.summary.scalar('online_convnet/weights_norm', all_norm)
      all_norm = tf.norm(
          flatten_list_of_vars(self.target_convnet.trainable_variables))
      tf1.summary.scalar('target_convnet/weights_norm', all_norm)
      # Only gradients belonging to the online network's variables.
      all_grad_norm = tf.norm(
          flatten_list_of_vars([
              g for g, v in grads_and_vars
              if v in self.online_convnet.trainable_variables
          ]))
      tf1.summary.scalar('online_convnet/grad_norm', all_grad_norm)
    total_params, nparam_dict = sparse_utils.get_total_params(
        self.online_convnet)
    tf1.summary.scalar('params/total', total_params)
    for k, val in nparam_dict.items():
      tf1.summary.scalar('params/' + k, val)
    if self._mode == 'rigl':
      tf1.summary.scalar('drop_fraction', self.optimizer.drop_fraction)

  def update_prune_step(self):
    """Syncs each pruning layer's step counter with the global step."""
    self._sess.run(self.step_update_op)

  def maybe_update_and_apply_masks(self):
    """Runs mask update/apply ops; a no-op in 'dense' mode."""
    self._sess.run(self.mask_update_op)

  def maybe_init_masks(self):
    """Runs the one-time mask initialization op."""
    # If `dense`; no initialization.
    self._sess.run(self.mask_init_op)

  def _train_step(self):
    """Runs one training step.

    Overrides the parent to interleave pruning-step and mask updates with
    the gradient update, and to re-apply masks before syncing the target
    network so the target copies masked weights.
    """
    if self._replay.memory.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self.update_prune_step()
        self.maybe_update_and_apply_masks()
        self._sess.run(self._train_op)
        # Summary cadence is driven by the optimizer's global step.
        c_step = self._sess.run(self._global_step)
        if (self.summary_writer is not None and
            self._merged_summaries is not None and
            c_step % self.summary_writing_frequency == 0):
          summary = self._sess.run(self._merged_summaries)
          self.summary_writer.add_summary(summary, c_step)
      if self.training_steps % self.target_update_period == 0:
        # Mask weights before syncing
        self.maybe_update_and_apply_masks()
        self._sess.run(self._sync_qt_ops)
    self.training_steps += 1

  def _build_sync_op(self):
    """Builds ops for assigning weights from online to target network.

    Returns:
      ops: A list of ops assigning weights from online to target network.
    """
    # Get trainable variables from online and target DQNs
    sync_qt_ops = []
    # Includes pruning masks so the target network stays sparse too.
    online_vars = sparse_utils.get_all_variables_and_masks(self.online_convnet)
    target_vars = sparse_utils.get_all_variables_and_masks(self.target_convnet)
    for (v_online, v_target) in zip(online_vars, target_vars):
      # Assign weights from online to target network.
      sync_qt_ops.append(v_target.assign(v_online, use_locking=True))
    return sync_qt_ops

  def _build_networks(self):
    """Builds the Q-value network computations needed for acting and training.

    Same as the `super` class except training=True flags are passed.

    These are:
      self.online_convnet: For computing the current state's Q-values.
      self.target_convnet: For computing the next state's target Q-values.
      self._net_outputs: The actual Q-values.
      self._q_argmax: The action maximizing the current state's Q-values.
      self._replay_net_outputs: The replayed states' Q-values.
      self._replay_next_target_net_outputs: The replayed next states' target
        Q-values (see Mnih et al., 2015 for details).
    """
    self.online_convnet = self._create_network(name='Online')
    self.target_convnet = self._create_network(name='Target')
    self._net_outputs = self.online_convnet(self.state_ph, training=True)
    self._q_argmax = tf.argmax(self._net_outputs.q_values, axis=1)[0]
    self._replay_net_outputs = self.online_convnet(self._replay.states,
                                                   training=True)
    # NOTE(review): the target network is called without `training`, so the
    # network's default applies — confirm this is intended.
    self._replay_next_target_net_outputs = self.target_convnet(
        self._replay.next_states)
================================================
FILE: rigl/rl/requirements.txt
================================================
absl-py>=0.6.0
dopamine-rl==4.0.5
gin-config
mujoco-py<2.2,>=2.1
numpy>=1.15.4
six>=1.12.0
tensorflow==2.9.1 # change to 'tensorflow-gpu' for gpu support
tensorflow-datasets==2.1
tensorflow-model-optimization==0.7.2
tf-agents[reverb]==0.13.0
================================================
FILE: rigl/rl/run.sh
================================================
#!/bin/bash
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Creates a virtualenv, installs dependencies, and launches the SAC
# train/eval job.  The shebang was moved to the first line: the kernel only
# honors `#!` when it starts the file.
set -e
set -x

virtualenv -p python3 .
source ./bin/activate

pip install tensorflow
# NOTE(review): paths reference `sparse_rl/` while this script lives under
# `rigl/rl/` — confirm the intended working directory / package name.
pip install -r sparse_rl/requirements.txt

# `python -m` takes a module path, not a filename: the trailing `.py` in the
# original invocation would make the module import fail.
python -m sparse_rl.tfagents.sac_train_eval \
  --gin_file=sparse_rl/tfagents/configs/sac_mujoco_sparse_config.gin
================================================
FILE: rigl/rl/run_experiment.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run policy evaluation as supervised learning, reloading representations."""
import sys
from absl import logging
from dopamine.discrete_domains import gym_lib
from dopamine.discrete_domains import run_experiment
import gin
import numpy as np
from rigl.rl import dqn_agents
import tensorflow.compat.v1 as tf1
# Fraction of the final training iterations whose eval returns are averaged
# to produce the reported final reward (last 10%).
AVG_REWARD_FRAC = 0.1
@gin.configurable
def create_sparse_agent(sess, num_actions, agent=None, summary_writer=None):
  """Creates a sparse agent.

  Args:
    sess: tf.Session.
    num_actions: int, number of actions.
    agent: str, type of learner/actor agent to create; currently only 'dqn'
      is supported and the value must not be None.
    summary_writer: tf.SummaryWriter, for Tensorboard.

  Returns:
    A learner/actor agent.

  Raises:
    ValueError: if `agent` is None or not a supported agent type.
  """
  # Explicit exception instead of `assert`: asserts are stripped when Python
  # runs with -O, which would silently skip this validation.
  if agent is None:
    raise ValueError('`agent` must be specified (e.g. "dqn").')
  if agent == 'dqn':
    return dqn_agents.SparseDQNAgent(
        sess, num_actions, summary_writer=summary_writer)
  else:
    raise ValueError('Unknown learner agent: {}'.format(agent))
@gin.configurable
class SparseTrainRunner(run_experiment.Runner):
  """Policy evaluation as supervised learning, from a loaded representation."""

  def __init__(self,
               base_dir,
               agent_type,
               checkpoint_file_prefix='ckpt',
               logging_file_prefix='log',
               log_every_n=1,
               num_iterations=200,
               training_steps=250000,
               evaluation_steps=125000,
               max_steps_per_episode=27000,
               load_env_fn=gym_lib.create_gym_environment,
               clip_rewards=True,
               atari_100k_eval=False,
               num_eval_episodes=100,
               observation_noise=None):
    """Initialize SparseTrainRunner in charge of running the experiment.

    Note: this deliberately does NOT call super().__init__; it re-implements
    the base Runner's setup (directories, session, agent, checkpointing).

    Args:
      base_dir: str, the base directory to host all required sub-directories.
      agent_type: str, defines the type of targets to be learned. Can be one of
        {'dqn', 'rainbow'}.
      checkpoint_file_prefix: str, the prefix to use for checkpoint files.
      logging_file_prefix: str, prefix to use for the log files.
      log_every_n: int, the frequency for writing logs.
      num_iterations: int, the iteration number threshold (must be greater than
        start_iteration).
      training_steps: int, the number of training steps to perform.
      evaluation_steps: int, the number of evaluation steps to perform.
      max_steps_per_episode: int, maximum number of steps after which an episode
        terminates.
      load_env_fn: fn, function which loads and returns an environment.
      clip_rewards: bool, whether to clip rewards in [-1, 1].
      atari_100k_eval: bool, whether we are using the eval for Atari 100K.
      num_eval_episodes: int, the number of full episodes to run during eval,
        only used if atari_100k_eval is True.
      observation_noise: float (optional), the stddev to use to add noise to the
        observations before sending to the agent.
    """
    self._logging_file_prefix = logging_file_prefix
    self._log_every_n = log_every_n
    self._num_iterations = num_iterations
    self._training_steps = training_steps
    self._evaluation_steps = evaluation_steps
    self._max_steps_per_episode = max_steps_per_episode
    self._clip_rewards = clip_rewards
    self._atari_100k_eval = atari_100k_eval
    self._num_eval_episodes = num_eval_episodes
    self._base_dir = base_dir
    self._create_directories()
    self._summary_writer = tf1.summary.FileWriter(self._base_dir)
    self._observation_noise = observation_noise
    self._environment = load_env_fn()
    num_actions = self._environment.action_space.n
    config = tf1.ConfigProto(allow_soft_placement=True)
    # Allocate only subset of the GPU memory as needed which allows for running
    # multiple agents/workers on the same GPU.
    config.gpu_options.allow_growth = True
    # Set up a session and initialize variables.
    # NOTE(review): session target 'local' — confirm intended (an empty
    # string is the typical in-process target).
    self._sess = tf1.Session('local', config=config)
    self._agent = create_sparse_agent(
        self._sess, num_actions, agent=agent_type,
        summary_writer=self._summary_writer)
    self._summary_writer.add_graph(graph=tf1.get_default_graph())
    self._sess.run(tf1.global_variables_initializer())
    self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)

  def _run_one_phase_fix_episodes(self, max_episodes, statistics):
    """Run one eval phase for the Atari 100k benchmark.

    As opposed to the standard eval phase which runs for a fixed number of
    steps, this will run for a fixed number of episodes, producing less noisy
    results.

    Args:
      max_episodes: int, max number of episodes to run.
      statistics: `IterationStatistics` object which records the experimental
        results.

    Returns:
      Tuple containing the number of steps taken in this phase (int), the sum of
      returns (float), and the number of episodes performed (int).
    """
    step_count = 0
    num_episodes = 0
    sum_returns = 0.
    while num_episodes < max_episodes:
      episode_length, episode_return = self._run_one_episode()
      statistics.append({
          'eval_episode_lengths': episode_length,
          'eval_episode_returns': episode_return
      })
      step_count += episode_length
      sum_returns += episode_return
      num_episodes += 1
      # We use sys.stdout.write instead of logging so as to flush frequently
      # without generating a line break.
      sys.stdout.write('Steps executed: {} '.format(step_count) +
                       'Episode length: {} '.format(episode_length) +
                       'Num episodes: {} '.format(num_episodes) +
                       'Return: {}\r'.format(episode_return))
      sys.stdout.flush()
    return step_count, sum_returns, num_episodes

  def _run_eval_phase(self, statistics):
    """Runs evaluation; episode-based for Atari 100K, step-based otherwise."""
    if not self._atari_100k_eval:
      return super()._run_eval_phase(statistics)
    self._agent.eval_mode = True
    _, sum_returns, num_episodes = self._run_one_phase_fix_episodes(
        self._num_eval_episodes, statistics)
    average_return = sum_returns / num_episodes if num_episodes > 0 else 0.0
    logging.info('Average undiscounted return per evaluation episode: %.2f',
                 average_return)
    statistics.append({'eval_average_return': average_return})
    return num_episodes, average_return

  def _run_one_step(self, action):
    """Maybe adds noise to observations."""
    observation, reward, is_terminal, _ = self._environment.step(action)
    if self._observation_noise is not None:
      # Gaussian noise, cast back to the observation's dtype.
      observation += np.random.normal(
          scale=self._observation_noise,
          size=observation.shape).astype(observation.dtype)
    return observation, reward, is_terminal

  def run_experiment(self):
    """Runs a full experiment, spread over multiple iterations."""
    logging.info('Beginning training...')
    if self._num_iterations <= self._start_iteration:
      logging.warning('num_iterations (%d) < start_iteration(%d)',
                      self._num_iterations, self._start_iteration)
      return
    self._agent.update_prune_step()
    self._agent.maybe_init_masks()
    all_eval_returns = []
    for iteration in range(self._start_iteration, self._num_iterations):
      statistics = self._run_one_iteration(iteration)
      # Assumes each iteration ran an eval phase that appended
      # 'eval_average_return' — TODO confirm for all configurations.
      all_eval_returns.append(statistics['eval_average_return'][-1])
      self._log_experiment(iteration, statistics)
      self._checkpoint_experiment(iteration)
    # NOTE(review): if num_iterations * AVG_REWARD_FRAC < 1 then last_n is 0
    # and the slice [-0:] averages over ALL returns — confirm intended.
    last_n = int(self._num_iterations * AVG_REWARD_FRAC)
    avg_return = np.mean(all_eval_returns[-last_n:])
    logging.info('Step %d, Average Return: %f', iteration, avg_return)
================================================
FILE: rigl/rl/sparse_utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines pruning and sparse training utilities."""
import functools
import re
import gin
from rigl import sparse_optimizers_base as sparse_opt_base
from rigl import sparse_utils
from rigl.rigl_tf2 import init_utils
import tensorflow as tf
import tensorflow.compat.v1 as tf1
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
# Alias for the tf.model_optimization pruning wrapper used throughout this
# module, and the keras layer types that are eligible for pruning.
PRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude
PRUNED_LAYER_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.Dense)
def get_total_params(model):
  """Obtains total active parameters of a given network.

  For pruning-wrapped layers only the surviving (mask == 1) kernel entries
  are counted, plus the full bias; all weights of unwrapped layers count in
  full.

  Args:
    model: keras model (possibly containing pruning-wrapped layers).

  Returns:
    Tuple of (total_count, per_layer_dict) as float tensors.
  """
  nparams_dict = {}
  total_count = 0.
  for layer in get_all_layers(model):
    if isinstance(layer, PRUNING_WRAPPER):
      # Active kernel entries are those with mask value 1.
      mask = layer.pruning_vars[0][1]
      layer_count = tf.reduce_sum(mask) + tf.size(
          layer.weights[1], out_type=tf.float32)
    else:
      layer_count = 0.
      for weight in layer.weights:
        layer_count += tf.size(weight, out_type=tf.float32)
    nparams_dict[layer.name] = layer_count
    total_count += layer_count
  return total_count, nparams_dict
@gin.configurable(denylist=['layer_dict'])
def get_pruning_sparsities(
    layer_dict,
    mask_init_method='erdos_renyi_kernel',
    target_sparsity=0.9,
    erk_power_scale=1.,
    custom_sparsity_map=None):
  """Creates name/sparsity dict using the name/shapes dict (layer_dict).

  Args:
    layer_dict: dict mapping layer name to its kernel shape (tuple of ints).
    mask_init_method: str, method used to distribute sparsity over layers
      (e.g. 'erdos_renyi_kernel').
    target_sparsity: float, overall fraction of pruned weights; if 0, every
      layer gets sparsity 0.
    erk_power_scale: float, power scale for the erdos-renyi based methods.
    custom_sparsity_map: optional dict of per-layer sparsity overrides.

  Returns:
    dict mapping each key of `layer_dict` to its float sparsity.
  """
  if target_sparsity == 0:
    # Idiomatic replacement for `{k: 0 for k in layer_dict.keys()}`.
    return dict.fromkeys(layer_dict, 0)
  if custom_sparsity_map is None:
    custom_sparsity_map = {}
  # `get_sparsities` keys its result by tensor name (e.g. 'ones:0').  Build
  # dummy mask tensors with the requested shapes and remember which layer
  # each tensor name belongs to, so the result can be translated back.
  extract_name_fn = lambda x: re.findall('(.+):0', x)[0]
  dummy_masks_dict = {k: tf.ones(v) for k, v in layer_dict.items()}
  reverse_dict = {v.name: k
                  for k, v in dummy_masks_dict.items()}
  sparsity_dict = sparse_utils.get_sparsities(
      list(dummy_masks_dict.values()),
      mask_init_method,
      target_sparsity,
      custom_sparsity_map,
      extract_name_fn=extract_name_fn,
      erk_power_scale=erk_power_scale)
  renamed_sparsity_dict = {reverse_dict[k]: float(v)
                           for k, v in sparsity_dict.items()}
  return renamed_sparsity_dict
@gin.configurable('pruning')
def get_pruning_params(mode,
                       initial_sparsity=0.0,
                       final_sparsity=0.95,
                       begin_step=30000,
                       end_step=100000,
                       frequency=1000):
  """Gets pruning hyper-parameters.

  Args:
    mode: str, 'prune' for scheduled magnitude pruning, or one of
      ('rigl', 'static', 'set') for sparse-to-sparse training.
    initial_sparsity: float, sparsity at `begin_step` (prune mode only).
    final_sparsity: float, sparsity reached at `end_step` (prune mode only).
    begin_step: int, step at which pruning starts.
    end_step: int, step at which pruning reaches `final_sparsity`.
    frequency: int, steps between mask updates.

  Returns:
    dict with a 'pruning_schedule' entry for the pruning wrapper.

  Raises:
    ValueError: if `mode` is not recognized.
  """
  if mode == 'prune':
    schedule = pruning_schedule.PolynomialDecay(
        initial_sparsity=initial_sparsity,
        final_sparsity=final_sparsity,
        begin_step=begin_step,
        end_step=end_step,
        frequency=frequency)
  elif mode in ('rigl', 'static', 'set'):
    # Sparse-to-sparse training maintains the masks itself, so the pruning
    # library's schedule must never fire: zero target sparsity with a
    # begin_step far in the future (1B) disables it.
    schedule = pruning_schedule.ConstantSparsity(
        target_sparsity=0, begin_step=1000000000)
  else:
    raise ValueError('Mode: %s, is not valid' % mode)
  return {'pruning_schedule': schedule}
def maybe_prune_layer(layer, params, filter_fn=None):
  """Wraps `layer` with the pruning wrapper when it is a prunable type.

  Args:
    layer: a keras layer.
    params: dict of kwargs for the pruning wrapper (e.g. pruning_schedule).
    filter_fn: optional predicate deciding whether to wrap; defaults to a
      membership check against PRUNED_LAYER_TYPES.

  Returns:
    The wrapped layer, or `layer` unchanged when the predicate is False.
  """
  if filter_fn is None:
    def filter_fn(candidate):
      return isinstance(candidate, PRUNED_LAYER_TYPES)
  if filter_fn(layer):
    return PRUNING_WRAPPER(layer, **params)
  return layer
def get_wrap_fn(mode):
  """Creates a function that wraps a given layer conditionally.

  Args:
    mode: str, If 'dense' no modification done. Otherwise the layer is pruned.

  Returns:
    function that accepts layer and returns a possibly wrapped one.
  """
  if mode == 'dense':
    # Identity: dense training leaves layers untouched.
    def wrap_fn(layer):
      return layer
    return wrap_fn
  return functools.partial(maybe_prune_layer, params=get_pruning_params(mode))
def update_prune_step(model, step):
  """Updates the pruning steps of each pruning layer."""
  assign_ops = [
      # The pruning wrapper requires its step counter to be > 0, hence the
      # tf.maximum clamp.
      tf1.assign(layer.pruning_step, tf.maximum(step, 1))
      for layer in get_all_pruning_layers(model)
  ]
  return tf.group(assign_ops)
def update_prune_masks(model):
  """Updates the masks if it is an update iteration."""
  prune_ops = []
  for op in model.updates:
    # Keep only ops created by the pruning wrapper (skip e.g. batch-norm
    # moving-average updates).
    if 'prune_low_magnitude' in op.name:
      prune_ops.append(op)
  return tf.group(prune_ops)
def get_all_layers(model, filter_fn=lambda _: True):
  """Gets all layers of a model and layers of a layer if it is a keras.Model.

  Args:
    model: object with a `.layers` sequence; children with their own
      `.layers` attribute are recursed into instead of being collected.
    filter_fn: predicate applied to leaf layers; only matches are returned.

  Returns:
    Flat list of leaf layers, in traversal order.
  """
  collected = []
  for child in model.layers:
    if hasattr(child, 'layers'):
      # Nested model/container: recurse and splice its leaves in.
      collected += get_all_layers(child, filter_fn=filter_fn)
    elif filter_fn(child):
      collected.append(child)
  return collected
def get_all_variables_and_masks(model):
  """Gets all trainable variables (+their masks) of a model."""
  collected = []
  for layer in get_all_layers(model):
    collected += list(layer.trainable_variables)
    if isinstance(layer, PRUNING_WRAPPER):
      # pruning_vars holds (weight, mask, threshold) tuples; add the mask.
      collected.append(layer.pruning_vars[0][1])
  return collected
def get_all_pruning_layers(model):
  """Gets all pruned layers of a model and layers of a layer if keras.Model."""
  def _is_pruning_wrapped(layer):
    return isinstance(layer, PRUNING_WRAPPER)
  return get_all_layers(model, filter_fn=_is_pruning_wrapped)
def log_sparsities(model):
  """Adds tf1 summaries for the sparsity of every pruning mask in `model`."""
  for layer in get_all_pruning_layers(model):
    for _, mask, threshold in layer.pruning_vars:
      sparsity = 1 - tf.reduce_mean(mask)
      if len(mask.shape) == 2:
        # 2-D (dense-layer) masks can also be rendered as an image summary.
        reshaped_mask = tf.expand_dims(tf.expand_dims(mask, 0), -1)
        tf1.summary.image(f'img/{mask.name}', reshaped_mask)
      tf1.summary.scalar(f'sparsity/{mask.name}', sparsity)
      tf1.summary.scalar(f'threshold/{threshold.name}', threshold)
class SparseOptTf2Mixin:
  """Tf2 model_optimization pruning library specific variable retrieval.

  Mixed into the sparse-optimizer base classes to tell them how to locate
  weights and masks created by the tf.model_optimization pruning wrapper.
  """

  def compute_gradients(self, *args, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    return self._optimizer.compute_gradients(*args, **kwargs)

  def set_model(self, model):
    # Stored so get_weights/get_masks can walk the pruning layers later.
    self.model = model

  def get_weights(self):
    """Returns the kernel of every pruning-wrapped layer."""
    return [
        layer.pruning_vars[0][0]
        for layer in get_all_pruning_layers(self.model)
    ]

  def get_masks(self):
    """Returns the mask of every pruning-wrapped layer."""
    return [
        layer.pruning_vars[0][1]
        for layer in get_all_pruning_layers(self.model)
    ]

  def get_masked_weights(self):
    """Returns elementwise weight * mask for every pruning-wrapped layer."""
    masked = []
    for weight, mask in zip(self.get_weights(), self.get_masks()):
      masked.append(weight * mask)
    return masked
@gin.configurable()
class UpdatedSETOptimizer(SparseOptTf2Mixin,
                          sparse_opt_base.SparseSETOptimizerBase):
  """SET optimizer adapted to tf.model_optimization pruning wrappers."""

  def _before_apply_gradients(self, grads_and_vars):
    # SET needs no extra bookkeeping before the gradient update.
    return tf1.no_op()
@gin.configurable()
class UpdatedRigLOptimizer(SparseOptTf2Mixin,
                           sparse_opt_base.SparseRigLOptimizerBase):
  """RigL optimizer adapted to tf.model_optimization pruning wrappers."""

  def _before_apply_gradients(self, grads_and_vars):
    """Updates momentum before updating the weights with gradient."""
    # Cache gradients keyed by variable name; the RigL base class uses these
    # to score which connections to grow.
    self._weight2masked_grads = {w.name: g for g, w in grads_and_vars}
    return tf1.no_op()
@gin.configurable()
def init_masks(model,
               mask_init_method='random',
               sparsity=0.9,
               erk_power_scale=1.,
               custom_sparsity_map=None,
               fixed_sparse_init=False):
  """Inits the masks randomly according to the given sparsity.

  Args:
    model: keras model whose pruning-wrapped layers' masks get initialized.
    mask_init_method: str, method understood by
      sparse_utils.get_mask_init_fn (e.g. 'random', 'erdos_renyi_kernel').
    sparsity: float, overall fraction of weights masked out; 0 disables
      initialization entirely.
    erk_power_scale: float, power scale for the erdos-renyi based methods.
    custom_sparsity_map: optional dict of per-layer sparsity overrides.
    fixed_sparse_init: bool; if True, the weights are additionally
      re-initialized with a unit-scaled initializer after the masks are set.

  Returns:
    A TF op performing the initialization, or None when sparsity == 0.
  """
  if sparsity == 0:
    return None
  if custom_sparsity_map is None:
    custom_sparsity_map = {}
  all_masks = [
      layer.pruning_vars[0][1] for layer in get_all_pruning_layers(model)
  ]
  assigner = sparse_utils.get_mask_init_fn(
      all_masks,
      mask_init_method,
      sparsity,
      custom_sparsity_map,
      erk_power_scale=erk_power_scale)
  if fixed_sparse_init:
    all_weights = [
        layer.pruning_vars[0][0] for layer in get_all_pruning_layers(model)
    ]
    # The control dependency forces mask assignment to happen before the
    # weights are re-initialized from the (new) masks.
    with tf.control_dependencies([assigner]):
      assign_ops = []
      for param, mask in zip(all_weights, all_masks):
        new_init = init_utils.unit_scaled_init_tf1(mask)
        assign_ops.append(tf1.assign(param, new_init))
      assigner = tf.group(assign_ops)
  return assigner
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_dense.gin
================================================
include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin'
import rigl.rl.dqn_agents
DQNAgent.network = @dqn_agents.NatureDQNNetwork
DQNAgent.optimizer = @tf.train.AdamOptimizer()
tf.train.AdamOptimizer.learning_rate = 0.00025
WrappedReplayBuffer.batch_size = 32 # Same as original
SparseDQNAgent.mode = 'dense'
SparseDQNAgent.weight_decay = 0.0
atari_lib.create_atari_environment.game_name = 'Pong'
SparseTrainRunner.load_env_fn = @atari_lib.create_atari_environment
SparseTrainRunner.agent_type = 'dqn'
SparseTrainRunner.num_iterations = 40
SparseTrainRunner.training_steps = 250000
SparseTrainRunner.evaluation_steps = 125000
SparseTrainRunner.max_steps_per_episode = 27000 # Default max episode length.
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin
================================================
include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin'
import rigl.rl.dqn_agents
DQNAgent.network = @dqn_agents.ImpalaNetwork
DQNAgent.optimizer = @tf.train.AdamOptimizer()
tf.train.AdamOptimizer.learning_rate = 0.0001
tf.train.AdamOptimizer.epsilon = 0.0003125
WrappedReplayBuffer.batch_size = 32 # Same as original
SparseDQNAgent.mode = 'dense'
SparseDQNAgent.weight_decay = 1e-05
atari_lib.create_atari_environment.game_name = 'Pong'
SparseTrainRunner.load_env_fn = @atari_lib.create_atari_environment
SparseTrainRunner.agent_type = 'dqn'
SparseTrainRunner.num_iterations = 40
SparseTrainRunner.training_steps = 250000
SparseTrainRunner.evaluation_steps = 125000
SparseTrainRunner.max_steps_per_episode = 27000 # Default max episode length.
ImpalaNetwork.use_batch_norm = False
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_prune.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'
SparseDQNAgent.mode = 'prune'
get_pruning_sparsities.target_sparsity = 0.95
get_pruning_sparsities.mask_init_method = 'erdos_renyi_kernel'
pruning.initial_sparsity = 0.0
# 0.5M = 20% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
pruning.begin_step = 500000 # 500k
# 2M = 80% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
pruning.end_step = 2000000 # 2M
pruning.frequency = 5000
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_prune_impala_net.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'
SparseDQNAgent.mode = 'prune'
get_pruning_sparsities.target_sparsity = 0.95
get_pruning_sparsities.mask_init_method = 'erdos_renyi_kernel'
pruning.initial_sparsity = 0.0
# 0.5M = 20% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
pruning.begin_step = 500000 # 500k
# 2M = 80% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
pruning.end_step = 2000000 # 2M
pruning.frequency = 5000
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_rigl.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'
SparseDQNAgent.mode = 'rigl'
# For sparse training methods we don't use the pruning library to update the
# masks. Therefore we need to disable it. Following `pruning` flags serve that
# purpose.
pruning.final_sparsity = 0.
pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.fixed_sparse_init = True
init_masks.sparsity = 0.9
UpdatedRigLOptimizer.begin_step = 0
# 2M = 80% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
UpdatedRigLOptimizer.end_step = 2000000
UpdatedRigLOptimizer.frequency = 5000
UpdatedRigLOptimizer.drop_fraction_anneal = 'cosine'
UpdatedRigLOptimizer.drop_fraction = 0.3
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_rigl_impala_net.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'
SparseDQNAgent.mode = 'rigl'
# For sparse training methods we don't use the pruning library to update the
# masks. Therefore we need to disable it. Following `pruning` flags serve that
# purpose.
pruning.final_sparsity = 0.
pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.fixed_sparse_init = True
init_masks.sparsity = 0.9
UpdatedRigLOptimizer.begin_step = 0
# 2M = 80% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
UpdatedRigLOptimizer.end_step = 2000000
UpdatedRigLOptimizer.frequency = 5000
UpdatedRigLOptimizer.drop_fraction_anneal = 'cosine'
UpdatedRigLOptimizer.drop_fraction = 0.3
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_set.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'
SparseDQNAgent.mode = 'set'
# For sparse training methods we don't use the pruning library to update the
# masks. Therefore we need to disable it. Following `pruning` flags serve that
# purpose.
pruning.final_sparsity = 0.
pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.fixed_sparse_init = True
init_masks.sparsity = 0.9
UpdatedSETOptimizer.begin_step = 0
# 2M = 80% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
UpdatedSETOptimizer.end_step = 2000000
UpdatedSETOptimizer.frequency = 5000
UpdatedSETOptimizer.drop_fraction_anneal = 'cosine'
UpdatedSETOptimizer.drop_fraction = 0.3
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_set_impala_net.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'
SparseDQNAgent.mode = 'set'
# For sparse training methods we don't use the pruning library to update the
# masks. Therefore we need to disable it. Following `pruning` flags serve that
# purpose.
pruning.final_sparsity = 0.
pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.fixed_sparse_init = True
init_masks.sparsity = 0.9
UpdatedSETOptimizer.begin_step = 0
# 2M = 80% optimizer steps when training for 40M env steps with a frame skip
# of 4 (= 10M transitions), and training every 4th env transition (2.5M train
# steps in total).
UpdatedSETOptimizer.end_step = 2000000
UpdatedSETOptimizer.frequency = 5000
UpdatedSETOptimizer.drop_fraction_anneal = 'cosine'
UpdatedSETOptimizer.drop_fraction = 0.3
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_static.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'
SparseDQNAgent.mode = 'static'
# For sparse training methods we don't use the pruning library to update the
# masks. Therefore we need to disable it. Following `pruning` flags serve that
# purpose.
pruning.final_sparsity = 0.
pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.sparsity = 0.95
================================================
FILE: rigl/rl/sparsetrain_configs/dqn_atari_static_impala_net.gin
================================================
include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'
SparseDQNAgent.mode = 'static'
# For sparse training methods we don't use the pruning library to update the
# masks. Therefore we need to disable it. Following `pruning` flags serve that
# purpose.
pruning.final_sparsity = 0.
pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.sparsity = 0.95
================================================
FILE: rigl/rl/tfagents/configs/dqn_gym_dense_config.gin
================================================
# Configs to run DQN training for dense networks on classic control environments.
train_eval.env_name='CartPole-v0'
train_eval.fc_layer_params = (512, 512)
train_eval.target_update_period = 100
train_eval.batch_size = 128
# Environment:train steps ratio is 1:1
train_eval.num_iterations = 100000
train_eval.weight_decay = 1e-6
train_eval.width = 1.0
train_eval.policy_save_interval = 10000
train_eval.epsilon_greedy = 0.01
train_eval.eval_interval = 2000
train_eval.eval_episodes = 20
train_eval.sparse_output_layer = False
train_eval.train_mode = 'dense'
mask_updater.update_alg = ''
mask_updater.schedule_alg = ''
log_snr.freq=5000
================================================
FILE: rigl/rl/tfagents/configs/dqn_gym_pruning_config.gin
================================================
include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin'
# Configs to run DQN training for pruning on classic control environments.
train_eval.sparse_output_layer = True
train_eval.train_mode = 'sparse'
# This must be set to 0 when pruning to avoid
# initializing the masks
init_masks.sparsity = 0.0
wrap_all_layers.mode = 'prune'
wrap_all_layers.initial_sparsity = 0.0
wrap_all_layers.final_sparsity = 0.9
wrap_all_layers.mask_init_method = 'erdos_renyi_kernel'
# Environment:train steps ratio is 1:1
# We start pruning after 20% training (20,000) and stop after 75% (75,000)
wrap_all_layers.begin_step = 20000
wrap_all_layers.end_step = 75000
wrap_all_layers.frequency = 1000
log_sparsities.log_images = False
================================================
FILE: rigl/rl/tfagents/configs/dqn_gym_sparse_config.gin
================================================
include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin'
# Configs to run DQN training for static, set, and rigl on classic control
# environments.
train_eval.sparse_output_layer = True
train_eval.train_mode = 'sparse'
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.fixed_sparse_init = True
init_masks.sparsity = 0.9
# For static, set this to ''
# For rigl set this to 'rigl'
# For set set this to 'set'
mask_updater.update_alg = ''
mask_updater.schedule_alg = 'cosine'
mask_updater.update_freq = 1000
mask_updater.init_drop_fraction = 0.5
# Environment:train steps ratio is 1:1, we stop after 75% training = 75,000
mask_updater.last_update_step = 75000
mask_updater.use_stateless = False
wrap_all_layers.mode = 'constant'
log_sparsities.log_images = False
================================================
FILE: rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin
================================================
# Config to run training for dense on mujoco environments.
train_eval.env_name='HalfCheetah-v2'
train_eval.actor_fc_layers = (64, 64)
train_eval.value_fc_layers = (64, 64)
# In order to execute ~1M environment steps, we run 489 iterations
# (`--num_iterations=489`) which results in 1,001,472 environment steps. Each
# iteration results in 320 training steps (or 320 gradient updates, this is
# calculated from environment_steps * num_epochs / minibatch_size) and 2,048
# environment steps. Thus 489 *2,048 = 1,001,472 environment steps and
# 489 * 320 = 156,480 training steps.
train_eval.num_iterations = 489
train_eval.weight_decay = 1e-6
train_eval.width = 1.0
train_eval.policy_save_interval = 51000
train_eval.num_epochs = 10
train_eval.eval_interval = 2000
train_eval.eval_episodes = 20
train_eval.sparse_output_layer = False
train_eval.train_mode_actor = 'dense'
train_eval.train_mode_value = 'dense'
mask_updater.update_alg = ''
mask_updater.schedule_alg = ''
log_snr.freq=5000
================================================
FILE: rigl/rl/tfagents/configs/ppo_mujoco_pruning_config.gin
================================================
include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin'
train_eval.sparse_output_layer = True
train_eval.train_mode_actor = 'sparse'
train_eval.train_mode_value = 'sparse'
# This must be set to 0 when pruning to avoid
# initializing the masks
init_masks.sparsity = 0.0
wrap_all_layers.mode = 'prune'
wrap_all_layers.initial_sparsity = 0.0
wrap_all_layers.final_sparsity = 0.9
wrap_all_layers.mask_init_method = 'erdos_renyi_kernel'
# 156,480 steps total
# Start at ~20% = 31,296
# End at ~75% = 117,360
wrap_all_layers.begin_step = 32000
wrap_all_layers.end_step = 120000
wrap_all_layers.frequency = 500
log_sparsities.log_images = False
================================================
FILE: rigl/rl/tfagents/configs/ppo_mujoco_sparse_config.gin
================================================
include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin'
# Config to run PPO training for static, set, and rigl on mujoco environments.
train_eval.sparse_output_layer = True
train_eval.train_mode_actor = 'sparse'
train_eval.train_mode_value = 'sparse'
train_eval.weight_decay = 1e-4
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.fixed_sparse_init = True
init_masks.sparsity = 0.9
# For static, set this to ''
# For rigl set this to 'rigl'
# For set set this to 'set'
mask_updater.update_alg = ''
mask_updater.schedule_alg = 'cosine'
mask_updater.update_freq = 250
mask_updater.init_drop_fraction = 0.3
# 156,480 steps total, end at 75% = 117,360
mask_updater.last_update_step = 120000
mask_updater.use_stateless = False
wrap_all_layers.mode = 'constant'
log_sparsities.log_images = False
================================================
FILE: rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin
================================================
# Config to run SAC training for dense on mujoco environments.
train_eval.env_name = 'Humanoid-v2'
train_eval.initial_collect_steps = 1000
train_eval.num_iterations = 1000000 # 1M
train_eval.width = 1.0
train_eval.weight_decay = 1e-4
================================================
FILE: rigl/rl/tfagents/configs/sac_mujoco_pruning_config.gin
================================================
include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin'
# Configs to run SAC training for pruning on mujoco environments.
train_eval.train_mode_actor = 'sparse'
# Both critics
train_eval.train_mode_value = 'sparse'
train_eval.sparse_output_layer = True
init_masks.fixed_sparse_init = True
# This must be set to 0 when pruning to avoid
# initializing the masks
init_masks.sparsity = 0.0
wrap_all_layers.mode = 'prune'
wrap_all_layers.initial_sparsity = 0.0
wrap_all_layers.final_sparsity = 0.9
wrap_all_layers.mask_init_method = 'erdos_renyi_kernel'
# 1M steps total
# Start at 20%, end at 80%
wrap_all_layers.begin_step = 200000
wrap_all_layers.end_step = 800000
wrap_all_layers.frequency = 1000
log_sparsities.log_images = False
================================================
FILE: rigl/rl/tfagents/configs/sac_mujoco_sparse_config.gin
================================================
include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin'
# Configs to run SAC training for static, set, and rigl on mujoco
# environments.
train_eval.sparse_output_layer = True
train_eval.train_mode_actor = 'sparse'
# Both critics
train_eval.train_mode_value = 'sparse'
train_eval.actor_critic_sparsities_str = ''
train_eval.weight_decay = 1e-6
init_masks.mask_init_method = 'erdos_renyi_kernel'
init_masks.fixed_sparse_init = True
init_masks.sparsity = 0.9
mask_updater.update_alg = ''
mask_updater.schedule_alg = 'cosine'
mask_updater.update_freq = 1000
mask_updater.init_drop_fraction = 0.5
# 1M / train_eval.num_iterations * 0.8
mask_updater.last_update_step = 800000
mask_updater.use_stateless = False
wrap_all_layers.mode = 'constant'
log_sparsities.log_images = False
================================================
FILE: rigl/rl/tfagents/dqn_train_eval.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Sparse training DQN using actor/learner in a gym environment.
"""
import functools
import os
from typing import Tuple
from absl import app
from absl import flags
from absl import logging
import gin
import numpy as np
import reverb
from rigl.rigl_tf2 import mask_updaters
from rigl.rl import sparse_utils
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow.compat.v2 as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_atari
from tf_agents.environments import suite_gym
from tf_agents.metrics import py_metrics
from tf_agents.networks import sequential
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.system import system_multiprocessing as multiprocessing
from tf_agents.train import actor
from tf_agents.train import learner
from tf_agents.train import triggers
from tf_agents.train.utils import train_utils
from tf_agents.utils import common
from tf_agents.utils import eager_utils
# Command-line flags controlling where results go, how reverb is reached, and
# how gin configuration is supplied/overridden.
FLAGS = flags.FLAGS

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_integer(
    'reverb_port', None,
    'Port for reverb server, if None, use a randomly chosen unused port.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "train_eval.env_name=Acrobot-v1",'
    ' "init_masks.sparsity=0.9").')
# Used at the end of training to average the last fraction of evaluation
# returns into a single, lower-variance final metric.
flags.DEFINE_float(
    'average_last_fraction', 0.1,
    'Tells what fraction latest evaluation scores are averaged. This is used'
    ' to reduce variance.')
@gin.configurable
class SparseDqnAgent(dqn_agent.DqnAgent):
  """Wrapped DqnAgent that supports sparse training.

  On construction, the q-network's sparsity masks are initialized via
  `sparse_utils.init_masks`, and a mask updater is requested from
  `mask_updaters.get_mask_updater` (it may be None, e.g. for dense
  training). `_train` overrides the base training step so that mask
  updates and sparsity logging are interleaved with the gradient step.
  """

  def __init__(self, *args, **kwargs):
    """Initializes the base agent, the masks and the mask updater."""
    super().__init__(*args, **kwargs)
    # Return value unused; the call initializes the q-network's masks.
    _ = sparse_utils.init_masks(self._q_network)

    def loss_fn(experience_data, weights_data):
      # The following is just to fit to the existing API: the mask updater
      # expects a (experience, weights) -> per-sample-loss callable.
      loss_info = self._loss(
          experience_data,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights_data,
          training=True)
      return loss_info.extra.td_loss

    # Create mask updater if one is configured; None otherwise.
    self._mask_updater = mask_updaters.get_mask_updater(
        self._q_network, self._optimizer, loss_fn)

  def _train(self, experience, weights):
    """Runs one training step: loss, (optional) mask update, gradient apply.

    Args:
      experience: Batched trajectories used to compute the TD loss.
      weights: Per-sample loss weights (also forwarded to the mask updater).

    Returns:
      The `LossInfo` computed for this step.
    """
    tf.compat.v2.summary.experimental.set_step(self.train_step_counter)
    # Keep the pruning layers' internal step in sync with the train step.
    tf_sparse_utils.update_prune_step(self._q_network, self._train_step_counter)
    # persistent=True because the tape is consumed twice: once for the
    # gradients below, and again inside `log_snr`.
    with tf.GradientTape(persistent=True) as tape:
      loss_info = self._loss(
          experience,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights,
          training=True)
    tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan')
    variables_to_train = self._q_network.trainable_weights
    non_trainable_weights = self._q_network.non_trainable_weights
    assert list(variables_to_train), "No variables in the agent's q_network."
    grads = tape.gradient(loss_info.loss, variables_to_train)
    tf_sparse_utils.log_snr(tape, loss_info.extra.td_loss,
                            self.train_step_counter, variables_to_train)
    # Tuple is used for py3, where zip is a generator producing values once.
    grads_and_vars = list(zip(grads, variables_to_train))

    def _mask_update_step():
      # Second argument is not used.
      self._mask_updater.set_validation_data(experience, weights)
      self._mask_updater.update(self.train_step_counter)
      with tf.name_scope('/'):
        tf.summary.scalar(
            name='drop_fraction', data=self._mask_updater.last_drop_fraction)

    tf_sparse_utils.log_sparsities(self._q_network)
    if self._mask_updater is not None:
      # Only run the (expensive) mask update on scheduled iterations.
      is_update = self._mask_updater.is_update_iter(self.train_step_counter)
      tf.cond(is_update, _mask_update_step, lambda: None)
    if self._gradient_clipping is not None:
      grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars,
                                                       self._gradient_clipping)
    if self._summarize_grads_and_vars:
      grads_and_vars_with_non_trainable = (
          grads_and_vars + [(None, v) for v in non_trainable_weights])
      eager_utils.add_variables_summaries(grads_and_vars_with_non_trainable,
                                          self.train_step_counter)
      eager_utils.add_gradients_summaries(grads_and_vars,
                                          self.train_step_counter)
    self._optimizer.apply_gradients(grads_and_vars)
    self.train_step_counter.assign_add(1)
    self._update_target()
    return loss_info
def _scale_width(num_units, width):
assert width > 0
return int(max(1, num_units * width))
def build_network(
    fc_layer_params,
    num_actions,
    is_sparse,
    input_dim,
    width = 1.0,
    weight_decay = 0.0,
    sparse_output_layer = True
):
  """Builds a Sequential Q-network.

  The network is a stack of ReLU Dense layers (one per entry of
  `fc_layer_params`, each scaled by `width`) followed by a linear Dense
  output with `num_actions` units — one q-value per action. When
  `is_sparse`, the layers are wrapped for sparse training; the output
  layer is wrapped only when `sparse_output_layer` is set.
  """

  def _hidden_layer(units):
    # ReLU layer with fan-in variance scaling and L2 weight decay.
    return tf.keras.layers.Dense(
        units,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0, mode='fan_in', distribution='truncated_normal'),
        kernel_regularizer=tf.keras.regularizers.L2(weight_decay),)

  hidden = [
      _hidden_layer(_scale_width(num_units, width=width))
      for num_units in fc_layer_params
  ]
  # Linear head producing one q-value per available action.
  output_layer = tf.keras.layers.Dense(
      num_actions,
      activation=None,
      kernel_initializer=tf.keras.initializers.RandomUniform(
          minval=-0.03, maxval=0.03),
      bias_initializer=tf.keras.initializers.Constant(-0.2))
  layers = hidden + [output_layer]
  if is_sparse:
    if sparse_output_layer:
      layers = tf_sparse_utils.wrap_all_layers(layers, input_dim)
    else:
      # Keep the output layer dense; wrap only the hidden layers.
      layers = (tf_sparse_utils.wrap_all_layers(layers[:-1], input_dim)
                + layers[-1:])
  return sequential.Sequential(layers)
@gin.configurable
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    # Training params
    update_frequency=1,
    initial_collect_steps=1000,
    num_iterations=100000,
    fc_layer_params=(100,),
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,
    batch_size=64,
    learning_rate=1e-3,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=100,
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=100000,
    # Others
    policy_save_interval=1000,
    eval_interval=1000,
    eval_episodes=10,
    weight_decay = 0.0,
    width = 1.0,
    debug_summaries=False,
    sparse_output_layer=True,
    train_mode='dense'):
  """Trains and evaluates DQN.

  Builds a (possibly sparse) Q-network, a `SparseDqnAgent`, a reverb replay
  buffer with collect/eval actors, then alternates collection and learner
  steps for `num_iterations` iterations. Periodically evaluates and, at the
  end, logs the mean of the last `FLAGS.average_last_fraction` of evaluation
  returns.

  Args:
    root_dir: Directory for checkpoints, saved policies and summaries.
    env_name: Gym environment id loaded for both collect and eval.
    update_frequency: Environment steps collected per learner iteration.
    initial_collect_steps: Random-policy steps used to seed the replay buffer.
    num_iterations: Number of collect+train iterations.
    fc_layer_params: Hidden-layer widths of the Q-network.
    epsilon_greedy: Exploration epsilon (used as a constant here).
    epsilon_decay_period: Unused in this function body — kept for the gin
      interface; see `decay_fn` below.
    batch_size: Replay sample batch size.
    learning_rate: Adam learning rate.
    n_step_update: N-step return length for the agent.
    gamma: Discount factor.
    target_update_tau: Soft target-update coefficient.
    target_update_period: Steps between target-network updates.
    reward_scale_factor: Multiplier applied to rewards in the loss.
    reverb_port: Port for the reverb server (None picks an unused one).
    replay_capacity: Maximum replay buffer size.
    policy_save_interval: Steps between saved-model exports.
    eval_interval: Steps between evaluations (0/None disables eval).
    eval_episodes: Episodes per evaluation.
    weight_decay: L2 regularization strength for the Q-network kernels.
    width: Multiplier applied to each hidden-layer width.
    debug_summaries: Forwarded to the agent.
    sparse_output_layer: Whether the output layer is wrapped for sparsity.
    train_mode: 'sparse' enables sparse layer wrapping; anything else is
      dense.
  """
  logging.info('DQN params: Fc layer params: %s', fc_layer_params)
  logging.info('DQN params: Train mode: %s', train_mode)
  logging.info('DQN params: Target update period: %s', target_update_period)
  logging.info('DQN params: Policy save interval: %s', policy_save_interval)
  logging.info('DQN params: Eval interval: %s', eval_interval)
  logging.info('DQN params: Environment name: %s', env_name)
  logging.info('DQN params: Weight decay: %s', weight_decay)
  logging.info('DQN params: Width: %s', width)
  logging.info('DQN params: Batch size: %s', batch_size)
  logging.info('DQN params: Target update period: %s', target_update_period)
  logging.info('DQN params: Learning rate: %s', learning_rate)
  logging.info('DQN params: Num iterations: %s', num_iterations)
  logging.info('DQN params: Sparse output layer: %s', sparse_output_layer)
  collect_env = suite_gym.load(env_name)
  eval_env = suite_gym.load(env_name)
  logging.info('Collect env: %s', collect_env)
  logging.info('Eval env: %s', eval_env)
  time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
  action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())
  train_step = train_utils.create_train_step()
  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
  observation_shape = collect_env.observation_spec().shape
  # Build network and get pruning params.
  # NOTE(review): `is_atari` is hardcoded False, so only the non-Atari path
  # exists; if it were ever True, `q_net` below would be undefined.
  is_atari = False
  if not is_atari:
    q_net = build_network(
        fc_layer_params=fc_layer_params,
        num_actions=num_actions,
        is_sparse=(train_mode == 'sparse'),
        # observation_shape is 1-dimensional. We need this so that we can
        # calculate the dimensions of the first layer.
        input_dim=observation_shape[-1],
        width=width,
        weight_decay=weight_decay,
        sparse_output_layer=sparse_output_layer)
  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  loss = common.element_wise_squared_loss
  # Constant epsilon; `epsilon_decay_period` is not applied here.
  decay_fn = epsilon_greedy
  agent = SparseDqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=q_net,
      epsilon_greedy=decay_fn,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=optimizer,
      td_errors_loss_fn=loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)
  # Replay buffer: uniform sampling, FIFO eviction, backed by a local reverb
  # server.
  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client, table_name,
      sequence_length=2,
      stride_length=1)
  dataset = reverb_replay.as_dataset(
      num_parallel_calls=3, sample_batch_size=batch_size,
      num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]
  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      run_optimizer_variable_init=False)
  # If we haven't trained yet make sure we collect some random samples first to
  # fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                  collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()
  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                      use_tf_function=True)
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=update_frequency,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )
  # Greedy (non-exploring) policy for evaluation.
  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                     use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, 'eval'),
  )
  average_returns = []
  if eval_interval:
    # Evaluate once before any training.
    logging.info('Evaluating.')
    eval_actor.run_and_log()
    for metric in eval_actor.metrics:
      if isinstance(metric, py_metrics.AverageReturnMetric):
        # NOTE(review): reaches into the metric's private `_buffer`; there is
        # no public accessor for the running mean here.
        average_returns.append(metric._buffer.mean())
  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)
    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()
      for metric in eval_actor.metrics:
        if isinstance(metric, py_metrics.AverageReturnMetric):
          average_returns.append(metric._buffer.mean())
  # Log last section of evaluation scores for the final metric.
  idx = int(FLAGS.average_last_fraction * len(average_returns))
  avg_return = np.mean(average_returns[-idx:])
  logging.info('Step %d, Average Return: %f', env_step_metric.result(),
               avg_return)
  rb_observer.close()
  reverb_server.stop()
def main(_):
  """Entry point: configures TF and gin, then runs `train_eval`."""
  # Run tf.functions as compiled graphs rather than eagerly.
  tf.config.experimental_run_functions_eagerly(False)
  logging.set_verbosity(logging.INFO)
  tf.enable_v2_behavior()
  # Gin must be parsed before train_eval so its bindings take effect.
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  logging.info('Gin bindings: %s', FLAGS.gin_bindings)
  train_eval(
      FLAGS.root_dir,
      reverb_port=FLAGS.reverb_port)


if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  # handle_main is required by tf_agents multiprocessing-aware actors.
  multiprocessing.handle_main(functools.partial(app.run, main))
================================================
FILE: rigl/rl/tfagents/ppo_train_eval.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Sparse training PPO using actor/learner in a gym environment.
"""
import collections
import functools
import os
from typing import Optional
from absl import app
from absl import flags
from absl import logging
import gin
import numpy as np
import reverb
from rigl.rigl_tf2 import mask_updaters
from rigl.rl import sparse_utils
from rigl.rl.tfagents import sparse_ppo_actor_network
from rigl.rl.tfagents import sparse_ppo_discrete_actor_network
from rigl.rl.tfagents import sparse_value_network
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow.compat.v2 as tf
from tf_agents.agents import tf_agent
from tf_agents.agents.ppo import ppo_clip_agent
from tf_agents.agents.ppo import ppo_utils
from tf_agents.environments import suite_gym
from tf_agents.environments import suite_mujoco
from tf_agents.metrics import py_metrics
from tf_agents.networks import network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.system import system_multiprocessing as multiprocessing
from tf_agents.train import actor
from tf_agents.train import learner
from tf_agents.train import ppo_learner
from tf_agents.train import triggers
from tf_agents.train.utils import spec_utils
from tf_agents.train.utils import train_utils
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import common
from tf_agents.utils import eager_utils
from tf_agents.utils import nest_utils
from tf_agents.utils import object_identity
# Command-line flags: output location, reverb wiring, gin configuration and
# environment family selection.
FLAGS = flags.FLAGS

flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_integer(
    'reverb_port', None,
    'Port for reverb server, if None, use a randomly chosen unused port.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "train_eval.env_name=Acrobot-v1",'
    ' "init_masks.sparsity=0.9").')
# Env params
flags.DEFINE_bool('is_atari', False, 'Whether the env is an atari game.')
flags.DEFINE_bool('is_mujoco', False, 'Whether the env is a mujoco game.')
flags.DEFINE_bool('is_classic', False,
                  'Whether the env is a classic control game.')
# Used at the end of training to average the last fraction of evaluation
# returns into a single, lower-variance final metric.
flags.DEFINE_float(
    'average_last_fraction', 0.1,
    'Tells what fraction latest evaluation scores are averaged. This is used'
    ' to reduce variance.')

# Loss components tracked per training step; `total_loss_per_sample` is the
# per-sample loss consumed by the mask updaters' loss_fn.
SparsePPOLossInfo = collections.namedtuple('SparsePPOLossInfo', (
    'policy_gradient_loss',
    'value_estimation_loss',
    'l2_regularization_loss',
    'entropy_regularization_loss',
    'kl_penalty_loss',
    'total_loss_per_sample',
))
def _normalize_advantages(advantages, axes=(0,), variance_epsilon=1e-8):
  """Standardizes `advantages` to zero mean and unit variance along `axes`."""
  mean, variance = tf.nn.moments(advantages, axes=axes, keepdims=True)
  # batch_normalization with offset=None/scale=None reduces to
  # (advantages - mean) / sqrt(variance + variance_epsilon).
  return tf.nn.batch_normalization(
      advantages,
      mean,
      variance,
      offset=None,
      scale=None,
      variance_epsilon=variance_epsilon)
@gin.configurable
class SparsePPOAgent(ppo_clip_agent.PPOClipAgent):
  """Wrapped PPOClipAgent that supports sparse training."""
  def __init__(self,
               *args,
               policy_l2_reg=0.0,
               value_function_l2_reg=0.0,
               shared_vars_l2_reg=0.0,
               **kwargs):
    """Builds the agent and wires up the sparse-training machinery.

    Initializes pruning step counters on both networks, creates the initial
    sparsity masks, and constructs mask updaters that score connections with a
    per-sample loss function.

    Args:
      *args: Positional arguments forwarded to `PPOClipAgent.__init__`.
      policy_l2_reg: L2 regularization coefficient applied to actor variables.
      value_function_l2_reg: L2 regularization coefficient applied to value
        variables.
      shared_vars_l2_reg: L2 regularization coefficient applied to variables
        shared between the actor and value networks.
      **kwargs: Keyword arguments forwarded to `PPOClipAgent.__init__`.
    """
    super().__init__(*args,
                     policy_l2_reg=policy_l2_reg,
                     value_function_l2_reg=value_function_l2_reg,
                     shared_vars_l2_reg=shared_vars_l2_reg,
                     **kwargs)
    # Name scoping has been removed here so
    # debug_summaries are permanently disabled. To restore with proper
    # scoping.
    self._debug_summaries = False
    # Pruning layer requires the pruning_step to be >1 during forward pass.
    tf_sparse_utils.update_prune_step(
        self._actor_net, self.train_step_counter + 1)
    tf_sparse_utils.update_prune_step(
        self._value_net, self.train_step_counter + 1)
    _ = sparse_utils.init_masks(self._actor_net)
    _ = sparse_utils.init_masks(self._value_net)
    # BEGIN: sparse training create mask updaters
    # Per-sample loss used by the mask updaters to score connections.
    def loss_fn(experience_data, weights_data):
      # The following is just to fit to the existing API.
      (time_steps, actions, old_act_log_probs, returns, normalized_advantages,
       old_action_distribution_parameters, masked_weights,
       old_value_predictions) = self._process_experience_weights(
           experience_data, weights_data)
      loss_info = self.get_loss(
          time_steps,
          actions,
          old_act_log_probs,
          returns,
          normalized_advantages,
          old_action_distribution_parameters,
          masked_weights,
          self.train_step_counter,
          False,  # debug_summaries disabled inside the mask-update loss.
          old_value_predictions=old_value_predictions,
          training=True)
      return loss_info.extra.total_loss_per_sample
    self._mask_updater_actor = mask_updaters.get_mask_updater(
        self._actor_net, self._optimizer, loss_fn)
    self._mask_updater_value = mask_updaters.get_mask_updater(
        self._value_net, self._optimizer, loss_fn)
    # END: sparse training create mask updaters
    logging.info('SparsePPOAgent: policy_l2_reg %.5f.', policy_l2_reg)
    logging.info('SparsePPOAgent: value_function_l2_reg %.5f.',
                 value_function_l2_reg)
    logging.info('SparsePPOAgent: shared_vars_l2_reg %.5f.', shared_vars_l2_reg)
  def _process_experience_weights(self, experience, weights):
    """Converts raw experience into the tensors consumed by `get_loss`.

    Args:
      experience: Collected experience, convertible to a `Trajectory`.
      weights: Optional importance weights, or None.

    Returns:
      A tuple (time_steps, actions, old_act_log_probs, returns,
      normalized_advantages, old_action_distribution_parameters,
      masked_weights, old_value_predictions).
    """
    experience = self._as_trajectory(experience)
    if self._compute_value_and_advantage_in_train:
      processed_experience = self._preprocess(experience)
    else:
      processed_experience = experience
    # Mask trajectories that cannot be used for training.
    valid_mask = ppo_utils.make_trajectory_mask(processed_experience)
    if weights is None:
      masked_weights = valid_mask
    else:
      masked_weights = weights * valid_mask
    # Reconstruct per-timestep policy distribution from stored distribution
    # parameters.
    old_action_distribution_parameters = processed_experience.policy_info[
        'dist_params']
    old_actions_distribution = (
        ppo_utils.distribution_from_spec(
            self._action_distribution_spec,
            old_action_distribution_parameters,
            legacy_distribution_network=isinstance(
                self._actor_net, network.DistributionNetwork)))
    # Compute log probability of actions taken during data collection, using the
    # collect policy distribution.
    old_act_log_probs = common.log_probability(old_actions_distribution,
                                               processed_experience.action,
                                               self._action_spec)
    if self._debug_summaries and not tf.config.list_logical_devices('TPU'):
      actions_list = tf.nest.flatten(processed_experience.action)
      show_action_index = len(actions_list) != 1
      for i, single_action in enumerate(actions_list):
        action_name = ('actions_{}'.format(i)
                       if show_action_index else 'actions')
        tf.compat.v2.summary.histogram(
            name=action_name, data=single_action, step=self.train_step_counter)
    time_steps = ts.TimeStep(
        step_type=processed_experience.step_type,
        reward=processed_experience.reward,
        discount=processed_experience.discount,
        observation=processed_experience.observation)
    actions = processed_experience.action
    returns = processed_experience.policy_info['return']
    advantages = processed_experience.policy_info['advantage']
    normalized_advantages = _normalize_advantages(advantages,
                                                  variance_epsilon=1e-8)
    if self._debug_summaries and not tf.config.list_logical_devices('TPU'):
      tf.compat.v2.summary.histogram(
          name='advantages_normalized',
          data=normalized_advantages,
          step=self.train_step_counter)
    old_value_predictions = processed_experience.policy_info['value_prediction']
    return (time_steps, actions, old_act_log_probs, returns,
            normalized_advantages, old_action_distribution_parameters,
            masked_weights, old_value_predictions)
  def _train(self, experience, weights):
    """Runs `num_epochs` gradient steps (and mask updates) on `experience`.

    Returns:
      The `LossInfo` from the final epoch, with tensors detached via
      `tf.identity`.
    """
    tf.compat.v2.summary.experimental.set_step(self.train_step_counter)
    (time_steps, actions, old_act_log_probs, returns, normalized_advantages,
     old_action_distribution_parameters, masked_weights,
     old_value_predictions) = self._process_experience_weights(
         experience, weights)
    # NOTE(review): when _compute_value_and_advantage_in_train is set, this
    # repeats the _preprocess already done inside _process_experience_weights;
    # only `processed_experience.reward` is needed below — consider reusing.
    if self._compute_value_and_advantage_in_train:
      processed_experience = self._preprocess(experience)
    else:
      processed_experience = experience
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    # Loss tensors across batches will be aggregated for summaries.
    policy_gradient_losses = []
    value_estimation_losses = []
    l2_regularization_losses = []
    entropy_regularization_losses = []
    kl_penalty_losses = []
    loss_info = None
    variables_to_train = list(
        object_identity.ObjectIdentitySet(self._actor_net.trainable_weights +
                                          self._value_net.trainable_weights))
    # Sort to ensure tensors on different processes end up in same order.
    variables_to_train = sorted(variables_to_train, key=lambda x: x.name)
    for _ in range(self._num_epochs):
      # Name scoping has been removed here so
      # debug_summaries are permanently disabled. To restore with proper
      # scoping.
      debug_summaries = False
      with tf.GradientTape(persistent=True) as tape:
        loss_info = self.get_loss(
            time_steps,
            actions,
            old_act_log_probs,
            returns,
            normalized_advantages,
            old_action_distribution_parameters,
            masked_weights,
            self.train_step_counter,
            debug_summaries,
            old_value_predictions=old_value_predictions,
            training=True)
      grads = tape.gradient(loss_info.loss, variables_to_train)
      tf_sparse_utils.log_snr(tape, loss_info.extra.total_loss_per_sample,
                              self.train_step_counter, variables_to_train)
      # BEGIN sparse training mask update
      # We use the latest set of gradients to update the masks for sparse
      # training. Note, we do this before gradient clipping.
      def _mask_update_step(mask_updater, updater_name):
        mask_updater.set_validation_data(experience, weights)
        mask_updater.update(self.train_step_counter)
        with tf.name_scope('Drop_fraction/'):
          tf.summary.scalar(
              name=f'{updater_name}',
              data=mask_updater.last_drop_fraction)
      # Partials so the steps can be used as parameterless tf.cond branches.
      mask_update_step_actor = functools.partial(
          _mask_update_step, self._mask_updater_actor, 'actor')
      mask_update_step_value = functools.partial(
          _mask_update_step, self._mask_updater_value, 'value')
      tf_sparse_utils.log_sparsities(self._actor_net, 'actor')
      tf_sparse_utils.log_sparsities(self._value_net, 'value')
      tf_sparse_utils.log_total_params([self._actor_net, self._value_net])
      if self._mask_updater_actor is not None:
        is_update_actor = self._mask_updater_actor.is_update_iter(
            self.train_step_counter)
        tf.cond(is_update_actor, mask_update_step_actor, lambda: None)
      if self._mask_updater_value is not None:
        is_update_value = self._mask_updater_value.is_update_iter(
            self.train_step_counter)
        tf.cond(is_update_value, mask_update_step_value, lambda: None)
      # END sparse training mask update
      if self._gradient_clipping > 0:
        grads, _ = tf.clip_by_global_norm(grads, self._gradient_clipping)
      # Tuple is used for py3, where zip is a generator producing values once.
      grads_and_vars = tuple(zip(grads, variables_to_train))
      # If summarize_gradients, create functions for summarizing both
      # gradients and variables.
      if self._summarize_grads_and_vars and debug_summaries:
        eager_utils.add_gradients_summaries(grads_and_vars,
                                            self.train_step_counter)
        eager_utils.add_variables_summaries(grads_and_vars,
                                            self.train_step_counter)
      self._optimizer.apply_gradients(grads_and_vars)
      self.train_step_counter.assign_add(1)
      policy_gradient_losses.append(loss_info.extra.policy_gradient_loss)
      value_estimation_losses.append(loss_info.extra.value_estimation_loss)
      l2_regularization_losses.append(loss_info.extra.l2_regularization_loss)
      entropy_regularization_losses.append(
          loss_info.extra.entropy_regularization_loss)
      kl_penalty_losses.append(loss_info.extra.kl_penalty_loss)
    if self._initial_adaptive_kl_beta > 0:
      # After update epochs, update adaptive kl beta, then update observation
      # normalizer and reward normalizer.
      policy_state = self._collect_policy.get_initial_state(batch_size)
      # Compute the mean kl from previous action distribution.
      kl_divergence = self._kl_divergence(
          time_steps, old_action_distribution_parameters,
          self._collect_policy.distribution(time_steps, policy_state).action)
      self.update_adaptive_kl_beta(kl_divergence)
    if self.update_normalizers_in_train:
      self.update_observation_normalizer(time_steps.observation)
      self.update_reward_normalizer(processed_experience.reward)
    loss_info = tf.nest.map_structure(tf.identity, loss_info)
    # Make summaries for total loss averaged across all epochs.
    # The *_losses lists will have been populated by
    # calls to self.get_loss. Assumes all the losses have same length.
    with tf.name_scope('Losses/'):
      num_epochs = len(policy_gradient_losses)
      total_policy_gradient_loss = tf.add_n(policy_gradient_losses) / num_epochs
      total_value_estimation_loss = tf.add_n(
          value_estimation_losses) / num_epochs
      total_l2_regularization_loss = tf.add_n(
          l2_regularization_losses) / num_epochs
      total_entropy_regularization_loss = tf.add_n(
          entropy_regularization_losses) / num_epochs
      total_kl_penalty_loss = tf.add_n(kl_penalty_losses) / num_epochs
      tf.compat.v2.summary.scalar(
          name='policy_gradient_loss',
          data=total_policy_gradient_loss,
          step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='value_estimation_loss',
          data=total_value_estimation_loss,
          step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='l2_regularization_loss',
          data=total_l2_regularization_loss,
          step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_regularization_loss',
          data=total_entropy_regularization_loss,
          step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='kl_penalty_loss',
          data=total_kl_penalty_loss,
          step=self.train_step_counter)
      total_abs_loss = (
          tf.abs(total_policy_gradient_loss) +
          tf.abs(total_value_estimation_loss) +
          tf.abs(total_entropy_regularization_loss) +
          tf.abs(total_l2_regularization_loss) + tf.abs(total_kl_penalty_loss))
      tf.compat.v2.summary.scalar(
          name='total_abs_loss',
          data=total_abs_loss,
          step=self.train_step_counter)
    with tf.name_scope('LearningRate/'):
      learning_rate = ppo_utils.get_learning_rate(self._optimizer)
      tf.compat.v2.summary.scalar(
          name='learning_rate',
          data=learning_rate,
          step=self.train_step_counter)
    if self._summarize_grads_and_vars and not tf.config.list_logical_devices(
        'TPU'):
      with tf.name_scope('Variables/'):
        all_vars = (
            self._actor_net.trainable_weights +
            self._value_net.trainable_weights)
        for var in all_vars:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)
    return loss_info
  def get_loss(self,
               time_steps,
               actions,
               act_log_probs,
               returns,
               normalized_advantages,
               action_distribution_parameters,
               weights,
               train_step,
               debug_summaries,
               old_value_predictions = None,
               training = False):
    """Compute the loss and create optimization op for one training epoch.
    All tensors should have a single batch dimension.
    Args:
      time_steps: A minibatch of TimeStep tuples.
      actions: A minibatch of actions.
      act_log_probs: A minibatch of action probabilities (probability under the
        sampling policy).
      returns: A minibatch of per-timestep returns.
      normalized_advantages: A minibatch of normalized per-timestep advantages.
      action_distribution_parameters: Parameters of data-collecting action
        distribution. Needed for KL computation.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights. Includes a mask for invalid timesteps.
      train_step: A train_step variable to increment for each train step.
        Typically the global_step.
      debug_summaries: True if debug summaries should be created.
      old_value_predictions: (Optional) The saved value predictions, used for
        calculating the value estimation loss when value clipping is performed.
      training: Whether this loss is being used for training.
    Returns:
      A tf_agent.LossInfo named tuple with the total_loss and all intermediate
      losses in the extra field contained in a PPOLossInfo named tuple.
    """
    # Evaluate the current policy on timesteps.
    # batch_size from time_steps
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    policy_state = self._collect_policy.get_initial_state(batch_size)
    # We must use _distribution because the distribution API doesn't pass down
    # the training= kwarg.
    distribution_step = self._collect_policy._distribution(
        time_steps,
        policy_state,
        training=training)
    current_policy_distribution = distribution_step.action
    # Call all loss functions and add all loss values.
    (value_estimation_loss,
     value_estimation_loss_per_sample) = self.value_estimation_loss(
         time_steps=time_steps,
         returns=returns,
         old_value_predictions=old_value_predictions,
         weights=weights,
         debug_summaries=debug_summaries,
         training=training)
    (policy_gradient_loss,
     policy_gradient_loss_per_sample) = self.policy_gradient_loss(
         time_steps,
         actions,
         tf.stop_gradient(act_log_probs),
         tf.stop_gradient(normalized_advantages),
         current_policy_distribution,
         weights,
         debug_summaries=debug_summaries)
    if (self._policy_l2_reg > 0.0 or self._value_function_l2_reg > 0.0 or
        self._shared_vars_l2_reg > 0.0):
      l2_regularization_loss = self.l2_regularization_loss(debug_summaries)
    else:
      l2_regularization_loss = tf.zeros_like(policy_gradient_loss)
    # Scalar regularizer spread uniformly across the batch.
    l2_regularization_loss_per_sample = tf.repeat(
        l2_regularization_loss / tf.cast(batch_size, tf.float32), batch_size)
    if self._entropy_regularization > 0.0:
      (entropy_regularization_loss, entropy_regularization_loss_per_sample
      ) = self.entropy_regularization_loss(time_steps,
                                           current_policy_distribution, weights,
                                           debug_summaries)
    else:
      entropy_regularization_loss = tf.zeros_like(policy_gradient_loss)
      entropy_regularization_loss_per_sample = tf.repeat(
          tf.constant(0, dtype=tf.float32), batch_size)
    if self._initial_adaptive_kl_beta == 0:
      kl_penalty_loss = tf.zeros_like(policy_gradient_loss)
    else:
      kl_penalty_loss = self.kl_penalty_loss(time_steps,
                                             action_distribution_parameters,
                                             current_policy_distribution,
                                             weights, debug_summaries)
    # Scalar KL penalty spread uniformly across the batch.
    kl_penalty_loss_per_sample = tf.repeat(
        kl_penalty_loss / tf.cast(batch_size, tf.float32), batch_size)
    total_loss = (
        policy_gradient_loss + value_estimation_loss + l2_regularization_loss +
        entropy_regularization_loss + kl_penalty_loss)
    total_loss_per_sample = (
        policy_gradient_loss_per_sample + value_estimation_loss_per_sample +
        l2_regularization_loss_per_sample +
        entropy_regularization_loss_per_sample + kl_penalty_loss_per_sample)
    return tf_agent.LossInfo(
        total_loss,
        SparsePPOLossInfo(
            policy_gradient_loss=policy_gradient_loss,
            value_estimation_loss=value_estimation_loss,
            l2_regularization_loss=l2_regularization_loss,
            entropy_regularization_loss=entropy_regularization_loss,
            kl_penalty_loss=kl_penalty_loss,
            total_loss_per_sample=total_loss_per_sample
        ))
  def value_estimation_loss(self,
                            time_steps,
                            returns,
                            weights,
                            old_value_predictions = None,
                            debug_summaries = False,
                            training = False):
    """Computes the value estimation loss for actor-critic training.
    All tensors should have a single batch dimension.
    Args:
      time_steps: A batch of timesteps.
      returns: Per-timestep returns for value function to predict. (Should come
        from TD-lambda computation.)
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights. Includes a mask for invalid timesteps.
      old_value_predictions: (Optional) The saved value predictions from
        policy_info, required when self._value_clipping > 0.
      debug_summaries: True if debug summaries should be created.
      training: Whether this loss is going to be used for training.
    Returns:
      value_estimation_loss: A scalar value_estimation_loss loss.
      value_estimation_loss_per_sample: Per-sample losses (mean of the squared
        error over axis 0). NOTE(review): these ignore `weights` — confirm
        that is intended for the mask-updater scoring.
    Raises:
      ValueError: If old_value_predictions was not passed in, but value clipping
        was performed.
    """
    observation = time_steps.observation
    if debug_summaries and not tf.config.list_logical_devices('TPU'):
      observation_list = tf.nest.flatten(observation)
      show_observation_index = len(observation_list) != 1
      for i, single_observation in enumerate(observation_list):
        observation_name = ('observations_{}'.format(i)
                            if show_observation_index else 'observations')
        tf.compat.v2.summary.histogram(
            name=observation_name,
            data=single_observation,
            step=self.train_step_counter)
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    value_state = self._collect_policy.get_initial_value_state(batch_size)
    value_preds, _ = self._collect_policy.apply_value_network(
        time_steps.observation,
        time_steps.step_type,
        value_state=value_state,
        training=training)
    value_estimation_error = tf.math.squared_difference(returns, value_preds)
    if self._value_clipping > 0:
      if old_value_predictions is None:
        raise ValueError(
            'old_value_predictions is None but needed for value clipping.')
      # PPO-style clipped value loss: take the pessimistic (max) error of the
      # clipped and unclipped predictions.
      clipped_value_preds = old_value_predictions + tf.clip_by_value(
          value_preds - old_value_predictions, -self._value_clipping,
          self._value_clipping)
      clipped_value_estimation_error = tf.math.squared_difference(
          returns, clipped_value_preds)
      value_estimation_error = tf.maximum(value_estimation_error,
                                          clipped_value_estimation_error)
    if self._aggregate_losses_across_replicas:
      value_estimation_loss = (
          common.aggregate_losses(
              per_example_loss=value_estimation_error,
              sample_weight=weights).total_loss * self._value_pred_loss_coef)
    else:
      value_estimation_loss = tf.math.reduce_mean(
          value_estimation_error * weights) * self._value_pred_loss_coef
    value_estimation_loss_per_sample = tf.reduce_mean(value_estimation_error,
                                                      axis=0)
    if debug_summaries:
      tf.compat.v2.summary.scalar(
          name='value_pred_avg',
          data=tf.reduce_mean(input_tensor=value_preds),
          step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='value_actual_avg',
          data=tf.reduce_mean(input_tensor=returns),
          step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='value_estimation_loss',
          data=value_estimation_loss,
          step=self.train_step_counter)
      if not tf.config.list_logical_devices('TPU'):
        tf.compat.v2.summary.histogram(
            name='value_preds', data=value_preds, step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='value_estimation_error',
            data=value_estimation_error,
            step=self.train_step_counter)
    if self._check_numerics:
      value_estimation_loss = tf.debugging.check_numerics(
          value_estimation_loss, 'value_estimation_loss')
      value_estimation_loss_per_sample = tf.debugging.check_numerics(
          value_estimation_loss_per_sample, 'value_estimation_loss_per_sample')
    return value_estimation_loss, value_estimation_loss_per_sample
  def policy_gradient_loss(
      self,
      time_steps,
      actions,
      sample_action_log_probs,
      advantages,
      current_policy_distribution,
      weights,
      debug_summaries = False):
    """Create tensor for policy gradient loss.
    All tensors should have a single batch dimension.
    Args:
      time_steps: TimeSteps with observations for each timestep.
      actions: Tensor of actions for timesteps, aligned on index.
      sample_action_log_probs: Tensor of sample probability of each action.
      advantages: Tensor of advantage estimate for each timestep, aligned on
        index. Works better when advantage estimates are normalized.
      current_policy_distribution: The policy distribution, evaluated on all
        time_steps.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights. Includes a mask for invalid timesteps.
      debug_summaries: True if debug summaries should be created.
    Returns:
      policy_gradient_loss: A tensor that will contain policy gradient loss for
        the on-policy experience.
      policy_gradient_loss_per_sample: Per-sample losses (mean over axis 0).
        NOTE(review): these ignore `weights` — confirm that is intended.
    """
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    action_log_prob = common.log_probability(current_policy_distribution,
                                             actions, self._action_spec)
    action_log_prob = tf.cast(action_log_prob, tf.float32)
    if self._log_prob_clipping > 0.0:
      action_log_prob = tf.clip_by_value(action_log_prob,
                                         -self._log_prob_clipping,
                                         self._log_prob_clipping)
    if self._check_numerics:
      action_log_prob = tf.debugging.check_numerics(action_log_prob,
                                                    'action_log_prob')
    # Prepare both clipped and unclipped importance ratios.
    importance_ratio = tf.exp(action_log_prob - sample_action_log_probs)
    importance_ratio_clipped = tf.clip_by_value(
        importance_ratio, 1 - self._importance_ratio_clipping,
        1 + self._importance_ratio_clipping)
    if self._check_numerics:
      importance_ratio = tf.debugging.check_numerics(importance_ratio,
                                                     'importance_ratio')
      if self._importance_ratio_clipping > 0.0:
        importance_ratio_clipped = tf.debugging.check_numerics(
            importance_ratio_clipped, 'importance_ratio_clipped')
    # Pessimistically choose the minimum objective value for clipped and
    # unclipped importance ratios.
    per_timestep_objective = importance_ratio * advantages
    per_timestep_objective_clipped = importance_ratio_clipped * advantages
    per_timestep_objective_min = tf.minimum(per_timestep_objective,
                                            per_timestep_objective_clipped)
    if self._importance_ratio_clipping > 0.0:
      policy_gradient_loss = -per_timestep_objective_min
    else:
      policy_gradient_loss = -per_timestep_objective
    policy_gradient_loss_per_sample = tf.reduce_mean(policy_gradient_loss,
                                                     axis=0)
    if self._aggregate_losses_across_replicas:
      policy_gradient_loss = common.aggregate_losses(
          per_example_loss=policy_gradient_loss,
          sample_weight=weights).total_loss
    else:
      policy_gradient_loss = tf.math.reduce_mean(policy_gradient_loss * weights)
    if debug_summaries:
      if self._importance_ratio_clipping > 0.0:
        clip_fraction = tf.reduce_mean(
            input_tensor=tf.cast(
                tf.greater(
                    tf.abs(importance_ratio -
                           1.0), self._importance_ratio_clipping), tf.float32))
        tf.compat.v2.summary.scalar(
            name='clip_fraction',
            data=clip_fraction,
            step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='importance_ratio_mean',
          data=tf.reduce_mean(input_tensor=importance_ratio),
          step=self.train_step_counter)
      entropy = common.entropy(current_policy_distribution, self.action_spec)
      tf.compat.v2.summary.scalar(
          name='policy_entropy_mean',
          data=tf.reduce_mean(input_tensor=entropy),
          step=self.train_step_counter)
      if not tf.config.list_logical_devices('TPU'):
        tf.compat.v2.summary.histogram(
            name='action_log_prob',
            data=action_log_prob,
            step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='action_log_prob_sample',
            data=sample_action_log_probs,
            step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='importance_ratio',
            data=importance_ratio,
            step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='importance_ratio_clipped',
            data=importance_ratio_clipped,
            step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='per_timestep_objective',
            data=per_timestep_objective,
            step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='per_timestep_objective_clipped',
            data=per_timestep_objective_clipped,
            step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='per_timestep_objective_min',
            data=per_timestep_objective_min,
            step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='policy_entropy', data=entropy, step=self.train_step_counter)
        for i, (single_action, single_distribution) in enumerate(
            zip(
                tf.nest.flatten(self.action_spec),
                tf.nest.flatten(current_policy_distribution))):
          # Categorical distribution (used for discrete actions) doesn't have a
          # mean.
          distribution_index = '_{}'.format(i) if i > 0 else ''
          if not tensor_spec.is_discrete(single_action):
            tf.compat.v2.summary.histogram(
                name='actions_distribution_mean' + distribution_index,
                data=single_distribution.mean(),
                step=self.train_step_counter)
            tf.compat.v2.summary.histogram(
                name='actions_distribution_stddev' + distribution_index,
                data=single_distribution.stddev(),
                step=self.train_step_counter)
        tf.compat.v2.summary.histogram(
            name='policy_gradient_loss',
            data=policy_gradient_loss,
            step=self.train_step_counter)
    if self._check_numerics:
      policy_gradient_loss = tf.debugging.check_numerics(
          policy_gradient_loss, 'policy_gradient_loss')
      policy_gradient_loss_per_sample = tf.debugging.check_numerics(
          policy_gradient_loss_per_sample, 'policy_gradient_loss_per_sample')
    return policy_gradient_loss, policy_gradient_loss_per_sample
  def entropy_regularization_loss(
      self,
      time_steps,
      current_policy_distribution,
      weights,
      debug_summaries = False):
    """Create regularization loss tensor based on agent parameters.

    Returns a (scalar loss, per-sample loss) tuple. Unlike the base class,
    this override raises if entropy regularization is disabled; `get_loss`
    only calls it when `self._entropy_regularization > 0`.
    """
    if self._entropy_regularization > 0:
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      with tf.name_scope('entropy_regularization'):
        entropy = tf.cast(
            common.entropy(current_policy_distribution, self.action_spec),
            tf.float32)
        if self._aggregate_losses_across_replicas:
          entropy_reg_loss = common.aggregate_losses(
              per_example_loss=-entropy,
              sample_weight=weights).total_loss * self._entropy_regularization
        else:
          entropy_reg_loss = (
              tf.math.reduce_mean(-entropy * weights) *
              self._entropy_regularization)
        if self._check_numerics:
          entropy_reg_loss = tf.debugging.check_numerics(
              entropy_reg_loss, 'entropy_reg_loss')
        if debug_summaries and not tf.config.list_logical_devices('TPU'):
          tf.compat.v2.summary.histogram(
              name='entropy_reg_loss',
              data=entropy_reg_loss,
              step=self.train_step_counter)
    else:
      raise ValueError('This is not allowed, this is handled at loss level.')
    # NOTE(review): per-sample term is the raw negative entropy, without the
    # self._entropy_regularization coefficient or `weights` — confirm intended.
    entropy_reg_loss_per_sample = -entropy
    if self._check_numerics:
      entropy_reg_loss_per_sample = tf.debugging.check_numerics(
          entropy_reg_loss_per_sample, 'entropy_reg_loss_per_sample')
    return entropy_reg_loss, entropy_reg_loss_per_sample
class ReverbFixedLengthSequenceObserver(reverb_utils.ReverbAddTrajectoryObserver
                                       ):
  """Reverb fixed length sequence observer.
  This is a specialized observer similar to ReverbAddTrajectoryObserver but each
  sequence contains a fixed number of steps and can span multiple episodes. This
  implementation is consistent with (Schulman, 17).
  **Note**: Counting of steps in drivers does not include boundary steps. To
  guarantee only 1 item is pushed to the replay when collecting n steps with a
  `sequence_length` of n make sure to set the `stride_length`.
  """
  def __call__(self, trajectory):
    """Writes the trajectory into the underlying replay buffer.
    Allows trajectory to be a flattened trajectory. No batch dimension allowed.
    Args:
      trajectory: The trajectory to be written which could be (possibly nested)
        trajectory object or a flattened version of a trajectory. It assumes
        there is *no* batch dimension.
    """
    # Unlike the parent observer, no episode-boundary handling: every step is
    # appended and flushed by count, so sequences may span episodes.
    self._writer.append(trajectory)
    self._cached_steps += 1
    self._write_cached_steps()
@gin.configurable
def train_eval(
root_dir,
env_name='HalfCheetah-v2',
# Training params
num_iterations=1600,
actor_fc_layers=(64, 64),
value_fc_layers=(64, 64),
learning_rate=3e-4,
collect_sequence_length=2048,
minibatch_size=64,
num_epochs=10,
# Agent params
importance_ratio_clipping=0.2,
lambda_value=0.95,
discount_factor=0.99,
entropy_regularization=0.,
value_pred_loss_coef=0.5,
use_gae=True,
use_td_lambda_return=True,
gradient_clipping=0.5,
value_clipping=None,
# Replay params
reverb_port=None,
replay_capacity=10000,
# Others
policy_save_interval=5000,
summary_interval=1000,
eval_interval=10000,
eval_episodes=100,
debug_summaries=False,
summarize_grads_and_vars=False,
train_mode_actor='dense',
train_mode_value='dense',
sparse_output_layer=True,
weight_decay=0.0,
width=1.0):
"""Trains and evaluates DQN."""
logging.info('Actor fc layer params: %s', actor_fc_layers)
logging.info('Value fc layer params: %s', value_fc_layers)
logging.info('Policy save interval: %s', policy_save_interval)
logging.info('Eval interval: %s', eval_interval)
logging.info('Environment name: %s', env_name)
logging.info('Learning rate: %s', learning_rate)
logging.info('Num iterations: %s', num_iterations)
logging.info('Sparse output layer: %s', sparse_output_layer)
logging.info('Train mode actor: %s', train_mode_actor)
logging.info('Train mode value: %s', train_mode_value)
logging.info('Width: %s', width)
logging.info('Weight decay: %s', weight_decay)
if FLAGS.is_mujoco:
collect_env = suite_mujoco.load(env_name)
eval_env = suite_mujoco.load(env_name)
logging.info('Loaded Mujoco environment %s', env_name)
elif FLAGS.is_classic:
collect_env = suite_gym.load(env_name)
eval_env = suite_gym.load(env_name)
logging.info('Loaded Classic control environment %s', env_name)
else:
raise ValueError('Environment init for Atari not supported yet.')
num_environments = 1
observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
spec_utils.get_tensor_specs(collect_env))
observation_tensor_spec = tf.TensorSpec(
dtype=tf.float32, shape=observation_tensor_spec.shape)
train_step = train_utils.create_train_step()
if FLAGS.is_classic:
actor_net_constructor = sparse_ppo_discrete_actor_network.PPODiscreteActorNetwork
else:
actor_net_constructor = sparse_ppo_actor_network.PPOActorNetwork
actor_net_builder = actor_net_constructor(
is_sparse=train_mode_actor == 'sparse',
sparse_output_layer=sparse_output_layer,
weight_decay=0,
width=width)
actor_net = actor_net_builder.create_sequential_actor_net(
actor_fc_layers, action_tensor_spec,
input_dim=time_step_tensor_spec.observation.shape[0])
value_net = sparse_value_network.ValueNetwork(
observation_tensor_spec,
fc_layer_params=value_fc_layers,
kernel_initializer=tf.keras.initializers.Orthogonal(),
is_sparse=train_mode_value == 'sparse',
sparse_output_layer=sparse_output_layer,
weight_decay=0,
width=width)
logging.info('Train eval: weight decay %.5f.', weight_decay)
current_iteration = tf.Variable(0, dtype=tf.int64)
def learning_rate_fn():
# Linearly decay the learning rate.
return learning_rate * (1 - current_iteration / num_iterations)
agent = SparsePPOAgent(
time_step_tensor_spec,
action_tensor_spec,
optimizer=tf.compat.v1.train.AdamOptimizer(
learning_rate=learning_rate_fn, epsilon=1e-5),
actor_net=actor_net,
value_net=value_net,
importance_ratio_clipping=importance_ratio_clipping,
lambda_value=lambda_value,
discount_factor=discount_factor,
entropy_regularization=entropy_regularization,
value_pred_loss_coef=value_pred_loss_coef,
policy_l2_reg=weight_decay,
value_function_l2_reg=weight_decay,
shared_vars_l2_reg=weight_decay,
# This is a legacy argument for the number of times we repeat the data
# inside of the train function, incompatible with mini batch learning.
# We set the epoch number from the replay buffer and tf.Data instead.
num_epochs=1,
use_gae=use_gae,
use_td_lambda_return=use_td_lambda_return,
gradient_clipping=gradient_clipping,
value_clipping=value_clipping,
compute_value_and_advantage_in_train=False,
# Skips updating normalizers in the agent, as it's handled in the learner.
update_normalizers_in_train=False,
debug_summaries=debug_summaries,
summarize_grads_and_vars=summarize_grads_and_vars,
train_step_counter=train_step)
agent.initialize()
reverb_server = reverb.Server(
[
reverb.Table( # Replay buffer storing experience for training.
name='training_table',
sampler=reverb.selectors.Fifo(),
remover=reverb.selectors.Fifo(),
rate_limiter=reverb.rate_limiters.MinSize(1),
max_size=replay_capacity,
max_times_sampled=1,
),
reverb.Table( # Replay buffer storing experience for normalization.
name='normalization_table',
sampler=reverb.selectors.Fifo(),
remover=reverb.selectors.Fifo(),
rate_limiter=reverb.rate_limiters.MinSize(1),
max_size=replay_capacity,
max_times_sampled=1,
)
],
port=reverb_port)
# Create the replay buffer.
reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(
agent.collect_data_spec,
sequence_length=collect_sequence_length,
table_name='training_table',
server_address='localhost:{}'.format(reverb_server.port),
# The only collected sequence is used to populate the batches.
max_cycle_length=1,
num_workers_per_iterator=1,
max_samples_per_stream=1,
rate_limiter_timeout_ms=1000)
reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(
agent.collect_data_spec,
sequence_length=collect_sequence_length,
table_name='normalization_table',
server_address='localhost:{}'.format(reverb_server.port),
# The only collected sequence is used to populate the batches.
max_cycle_length=1,
num_workers_per_iterator=1,
max_samples_per_stream=1,
rate_limiter_timeout_ms=1000)
rb_observer = ReverbFixedLengthSequenceObserver(
reverb_replay_train.py_client, ['training_table', 'normalization_table'],
sequence_length=collect_sequence_length,
stride_length=collect_sequence_length)
saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
collect_env_step_metric = py_metrics.EnvironmentSteps()
learning_triggers = [
triggers.PolicySavedModelTrigger(
saved_model_dir,
agent,
train_step,
interval=policy_save_interval,
metadata_metrics={
triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
}),
triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
]
def training_dataset_fn():
return reverb_replay_train.as_dataset(
sample_batch_size=num_environments,
sequence_preprocess_fn=agent.preprocess_sequence)
def normalization_dataset_fn():
return reverb_replay_normalization.as_dataset(
sample_batch_size=num_environments,
sequence_preprocess_fn=agent.preprocess_sequence)
agent_learner = ppo_learner.PPOLearner(
root_dir,
train_step,
agent,
experience_dataset_fn=training_dataset_fn,
normalization_dataset_fn=normalization_dataset_fn,
num_samples=1,
summary_interval=10,
num_epochs=num_epochs,
minibatch_size=minibatch_size,
shuffle_buffer_size=collect_sequence_length,
triggers=learning_triggers)
tf_collect_policy = agent.collect_policy
collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
tf_collect_policy, use_tf_function=True)
collect_actor = actor.Actor(
collect_env,
collect_policy,
train_step,
steps_per_run=collect_sequence_length,
observers=[rb_observer, collect_env_step_metric],
metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],
reference_metrics=[collect_env_step_metric],
summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
summary_interval=summary_interval)
eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
agent.policy, use_tf_function=True)
average_returns = []
if eval_interval:
logging.info('Intial evaluation.')
eval_actor = actor.Actor(
eval_env,
eval_greedy_policy,
train_step,
metrics=actor.eval_metrics(eval_episodes),
reference_metrics=[collect_env_step_metric],
summary_dir=os.path.join(root_dir, 'eval'),
episodes_per_run=eval_episodes)
eval_actor.run_and_log()
for metric in eval_actor.metrics:
if isinstance(metric, py_metrics.AverageReturnMetric):
average_returns.append(metric._buffer.mean())
logging.info('Training on %s', env_name)
last_eval_step = 0
for i in range(num_iterations):
logging.info('collect_actor.run')
collect_actor.run()
# Reset the reverb observer to make sure the data collected is flushed and
# written to the RB.
# At this point, there a small number of steps left in the cache because the
# actor does not count a boundary step as a step, whereas it still gets
# added to Reverb for training. We throw away those extra steps without
# padding to align with the paper implementation which never collects them
# in the first place.
logging.info('rb_observer.reset')
rb_observer.reset(write_cached_steps=False)
logging.info('reverb_replay_normalization.size: %d',
reverb_replay_normalization.get_table_info().current_size)
logging.info('reverb_replay_train.size: %d',
reverb_replay_train.get_table_info().current_size)
logging.info('agent_learner.run')
agent_learner.run()
logging.info('reverb_replay_train.clear')
reverb_replay_train.clear()
logging.info('reverb_replay_normalization.clear')
reverb_replay_normalization.clear()
current_iteration.assign_add(1)
# Eval only if `eval_interval` has been set. Then, eval if the current train
# step is equal or greater than the `last_eval_step` + `eval_interval` or if
# this is the last iteration. This logic exists because agent_learner.run()
# does not return after every train step.
if (eval_interval and
(agent_learner.train_step_numpy >= eval_interval + last_eval_step
or i == num_iterations - 1)):
logging.info('Evaluating.')
eval_actor.run_and_log()
last_eval_step = agent_learner.train_step_numpy
for metric in eval_actor.metrics:
if isinstance(metric, py_metrics.AverageReturnMetric):
average_returns.append(metric._buffer.mean())
# Log last section of evaluation scores for the final metric.
idx = int(FLAGS.average_last_fraction * len(average_returns))
avg_return = np.mean(average_returns[-idx:])
logging.info('Step %d, Average Return: %f', collect_env_step_metric.result(),
avg_return)
rb_observer.close()
reverb_server.stop()
def main(_):
  """Entry point: configures TF, parses gin bindings, and runs training.

  Uses the stable `tf.config.run_functions_eagerly` /
  `tf.compat.v1.enable_v2_behavior` APIs (consistent with the SAC
  train/eval binary) instead of the deprecated
  `tf.config.experimental_run_functions_eagerly` / `tf.enable_v2_behavior`.
  """
  tf.config.run_functions_eagerly(False)
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  logging.info('Gin bindings: %s', FLAGS.gin_bindings)
  train_eval(
      FLAGS.root_dir,
      reverb_port=FLAGS.reverb_port)
if __name__ == '__main__':
  # `root_dir` has no usable default; fail fast if the caller omits it.
  flags.mark_flag_as_required('root_dir')
  # tf_agents' multiprocessing wrapper sets up subprocess-safe state before
  # delegating to absl's app.run(main).
  multiprocessing.handle_main(functools.partial(app.run, main))
================================================
FILE: rigl/rl/tfagents/sac_train_eval.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Train and Eval SAC.
"""
import functools
import os
from absl import app
from absl import flags
from absl import logging
import gin
import numpy as np
import reverb
from rigl.rigl_tf2 import mask_updaters
from rigl.rl import sparse_utils
from rigl.rl.tfagents import sparse_tanh_normal_projection_network
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow as tf
from tf_agents.agents import tf_agent
from tf_agents.agents.sac import sac_agent
from tf_agents.environments import suite_mujoco
from tf_agents.keras_layers import inner_reshape
from tf_agents.metrics import py_metrics
from tf_agents.networks import nest_map
from tf_agents.networks import sequential
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.train import actor
from tf_agents.train import learner
from tf_agents.train import triggers
from tf_agents.train.utils import spec_utils
from tf_agents.train.utils import strategy_utils
from tf_agents.train.utils import train_utils
from tf_agents.utils import common
from tf_agents.utils import object_identity
FLAGS = flags.FLAGS
# Command-line flags shared by this SAC train/eval binary.
flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_integer(
    'reverb_port', None,
    'Port for reverb server, if None, use a randomly chosen unused port.')
flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')
flags.DEFINE_multi_string('gin_bindings', [], 'Gin binding parameters.')
# Env params
flags.DEFINE_bool('is_atari', False, 'Whether the env is an atari game.')
flags.DEFINE_bool('is_mujoco', False, 'Whether the env is a mujoco game.')
flags.DEFINE_bool('is_classic', False,
                  'Whether the env is a classic control game.')
flags.DEFINE_float(
    'average_last_fraction', 0.1,
    'Tells what fraction latest evaluation scores are averaged. This is used'
    ' to reduce variance.')
# Factory for the default hidden layer used throughout this file: a ReLU
# Dense layer with Glorot-uniform initialization.
dense = functools.partial(
    tf.keras.layers.Dense,
    activation=tf.keras.activations.relu,
    kernel_initializer='glorot_uniform')
def create_fc_layers(layer_units, width=1.0, weight_decay=0):
  """Builds one ReLU `Dense` layer per entry of `layer_units`.

  Each layer's unit count is scaled by `width` (via
  `tf_sparse_utils.scale_width`) and carries an L2 kernel regularizer.
  """
  built_layers = []
  for units in layer_units:
    scaled_units = tf_sparse_utils.scale_width(units, width=width)
    regularizer = tf.keras.regularizers.L2(weight_decay)
    built_layers.append(dense(scaled_units, kernel_regularizer=regularizer))
  return built_layers
def create_identity_layer():
  """Returns a Keras layer that passes its input through unchanged."""
  passthrough = lambda tensor: tensor
  return tf.keras.layers.Lambda(passthrough)
def create_sequential_critic_network(obs_fc_layer_units,
                                     action_fc_layer_units,
                                     joint_fc_layer_units,
                                     input_dim,
                                     is_sparse = False,
                                     width = 1.0,
                                     weight_decay = 0.0,
                                     sparse_output_layer = True):
  """Create a sequential critic network.

  Args:
    obs_fc_layer_units: Units for the observation-only FC stack, or None.
    action_fc_layer_units: Units for the action-only FC stack, or None.
    joint_fc_layer_units: Units for the joint (obs+action) FC stack, or None.
    input_dim: Flattened input dimension, used to distribute sparsities.
    is_sparse: If True, wrap the layers for sparse training/pruning.
    width: Width multiplier applied to each FC layer's unit count.
    weight_decay: L2 regularization coefficient for all kernels.
    sparse_output_layer: If True, the final value head is sparsified too.

  Returns:
    A `sequential.Sequential` network mapping (observation, action) pairs
    to scalar Q-values.
  """
  # Split the inputs into observations and actions.
  def split_inputs(inputs):
    return {'observation': inputs[0], 'action': inputs[1]}
  # Create an observation network layers.
  obs_network_layers = (
      create_fc_layers(obs_fc_layer_units, width=width,
                       weight_decay=weight_decay)
      if obs_fc_layer_units else None)
  # Create an action network layers.
  action_network_layers = (
      create_fc_layers(action_fc_layer_units, width=width,
                       weight_decay=weight_decay)
      if action_fc_layer_units else None)
  # Create a joint network layers.
  joint_network_layers = (
      create_fc_layers(joint_fc_layer_units, width=width,
                       weight_decay=weight_decay)
      if joint_fc_layer_units else None)
  # Final layer: scalar Q-value head.
  value_layer = tf.keras.layers.Dense(
      1, kernel_initializer='glorot_uniform',
      kernel_regularizer=tf.keras.regularizers.L2(weight_decay))
  layer_list = [obs_network_layers, action_network_layers,
                joint_network_layers]
  if is_sparse:
    # We need to process all-layers together to distribute sparsities for
    # pruning.
    all_layers = []
    for layers in layer_list:
      if layers is not None:
        all_layers += layers
    if sparse_output_layer:
      all_layers.append(value_layer)
      new_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim)
      value_layer = new_layers[-1]
      new_layers = new_layers[:-1]
    else:
      new_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim)
    # Split back the layers to their own groups.
    c_index = 0
    new_layer_list = []
    for layers in layer_list:
      if layers is None:
        new_layer_list.append(None)
      else:
        # Bug fix: the slice end must be offset by the running index.
        # The previous `new_layers[c_index:len(layers)]` produced wrong
        # (often empty) slices for every non-None group after the first.
        new_layer_list.append(new_layers[c_index:c_index + len(layers)])
        c_index += len(layers)
    layer_list = new_layer_list
  # Convert layer_list to sequential or identity lambdas:
  module_list = [create_identity_layer() if layers is None else
                 sequential.Sequential(layers)
                 for layers in layer_list]
  obs_network, action_network, joint_network = module_list
  return sequential.Sequential([
      tf.keras.layers.Lambda(split_inputs),
      nest_map.NestMap({
          'observation': obs_network,
          'action': action_network
      }),
      nest_map.NestFlatten(),
      tf.keras.layers.Concatenate(),
      joint_network,
      value_layer,
      inner_reshape.InnerReshape(current_shape=[1], new_shape=[])
  ], name='sequential_critic')
class _TanhNormalProjectionNetworkWrapper(
    sparse_tanh_normal_projection_network.SparseTanhNormalProjectionNetwork):
  """Wrapper to pass predefined `outer_rank` to underlying projection net."""

  def __init__(self, sample_spec, predefined_outer_rank=1, weight_decay=0.0):
    super().__init__(sample_spec=sample_spec, weight_decay=weight_decay)
    self.predefined_outer_rank = predefined_outer_rank

  def call(self, inputs, network_state=(), **kwargs):
    # Force the outer rank expected by the projection net, and drop
    # `step_type`, which the underlying call() does not accept.
    kwargs['outer_rank'] = self.predefined_outer_rank
    kwargs.pop('step_type', None)
    return super().call(inputs, **kwargs)
def create_sequential_actor_network(actor_fc_layers,
                                    action_tensor_spec,
                                    input_dim,
                                    is_sparse = False,
                                    width = 1.0,
                                    weight_decay = 0.0,
                                    sparse_output_layer = True):
  """Create a sequential actor network.

  Args:
    actor_fc_layers: Iterable of unit counts for the hidden FC layers.
    action_tensor_spec: Nest of action specs; one tanh-normal projection
      head is created per leaf spec.
    input_dim: Flattened input dimension, used to distribute sparsities.
    is_sparse: If True, wrap the layers for sparse training/pruning.
    width: Width multiplier applied to each FC layer's unit count.
    weight_decay: L2 regularization coefficient for all kernels.
    sparse_output_layer: If True, the projection layer is sparsified too.

  Returns:
    A `sequential.Sequential` actor network emitting a tanh-normal action
    distribution per leaf of `action_tensor_spec`.
  """
  # Replicates a single tensor into the structure of the action spec so
  # every projection head receives the same backbone features.
  def tile_as_nest(non_nested_output):
    return tf.nest.map_structure(lambda _: non_nested_output,
                                 action_tensor_spec)
  dense_layers = [
      dense(tf_sparse_utils.scale_width(num_units, width=width),
            kernel_regularizer=tf.keras.regularizers.L2(weight_decay))
      for num_units in actor_fc_layers
  ]
  tanh_normal_projection_network_fn = functools.partial(
      _TanhNormalProjectionNetworkWrapper,
      weight_decay=weight_decay)
  last_layer = nest_map.NestMap(
      tf.nest.map_structure(tanh_normal_projection_network_fn,
                            action_tensor_spec))
  if is_sparse:
    if sparse_output_layer:
      # Temporarily append the projection's inner dense layer so that it is
      # wrapped (sparsified) together with the hidden layers, then put the
      # wrapped version back.
      # NOTE(review): reaches into the private `_projection_layer` attribute
      # of the projection net — assumes a single-head (non-nested) action
      # spec since only `layers[0]` is rewired; verify for nested actions.
      dense_layers.append(last_layer.layers[0]._projection_layer)
      new_layers = tf_sparse_utils.wrap_all_layers(dense_layers, input_dim)
      dense_layers = new_layers[:-1]
      last_layer.layers[0]._projection_layer = new_layers[-1]
    else:
      dense_layers = tf_sparse_utils.wrap_all_layers(dense_layers, input_dim)
  return sequential.Sequential(
      dense_layers +
      [tf.keras.layers.Lambda(tile_as_nest)] + [last_layer])
@gin.configurable
class SparseSacAgent(sac_agent.SacAgent):
  """SacAgent wrapper that supports sparse training of its networks.

  On top of the base `sac_agent.SacAgent`, this agent:
    * advances the prune step on the actor/critic networks so pruning
      layers are active during forward passes,
    * initializes sparsity masks on the actor and both critics,
    * re-copies the target critic networks so they inherit the masks,
    * creates mask updaters (for dynamic-sparsity methods such as RigL)
      and invokes them at the appropriate train steps inside `_train`.
  """

  def __init__(self,
               time_step_spec,
               action_spec,
               *args,
               actor_sparsity=None,
               critic_sparsity=None,
               **kwargs):
    """Creates the agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time steps.
      action_spec: A nest of bounded tensor specs for the actions.
      *args: Positional arguments forwarded to `sac_agent.SacAgent`.
      actor_sparsity: Optional sparsity for the actor's masks. If None,
        `init_masks` uses its own default (possibly set via gin).
      critic_sparsity: Optional sparsity for both critics' masks. If None,
        `init_masks` uses its own default (possibly set via gin).
      **kwargs: Keyword arguments forwarded to `sac_agent.SacAgent`.
    """
    super().__init__(time_step_spec,
                     action_spec,
                     *args,
                     **kwargs)
    # Pruning layer requires the pruning_step to be >1 during forward pass.
    tf_sparse_utils.update_prune_step(
        self._critic_network_1, self.train_step_counter + 1)
    tf_sparse_utils.update_prune_step(
        self._critic_network_2, self.train_step_counter + 1)
    tf_sparse_utils.update_prune_step(
        self._actor_network, self.train_step_counter + 1)
    if critic_sparsity is not None:
      _ = sparse_utils.init_masks(self._critic_network_1,
                                  sparsity=critic_sparsity)
      _ = sparse_utils.init_masks(self._critic_network_2,
                                  sparsity=critic_sparsity)
    else:  # Uses init_mask.sparsity value. Either the default or set via gin.
      _ = sparse_utils.init_masks(self._critic_network_1)
      _ = sparse_utils.init_masks(self._critic_network_2)
    if actor_sparsity is not None:
      _ = sparse_utils.init_masks(self._actor_network,
                                  sparsity=actor_sparsity)
    else:
      _ = sparse_utils.init_masks(self._actor_network)
    # Re-copy the target critics so they pick up the freshly initialized
    # masks (the base class copied them before the masks existed).
    net_observation_spec = time_step_spec.observation
    critic_spec = (net_observation_spec, action_spec)
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_1,
            None,
            input_spec=critic_spec,
            name='TargetCriticNetwork1'))
    # Bug fix: this second copy was previously assigned to
    # `self._target_critic_network_1`, clobbering the copy above and
    # leaving `self._target_critic_network_2` without a masked copy.
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_2,
            None,
            input_spec=critic_spec,
            name='TargetCriticNetwork2'))

    def critic_loss_fn(experience, weights):
      # The following is just to fit to the existing API.
      transition = self._as_transition(experience)
      time_steps, policy_steps, next_time_steps = transition
      actions = policy_steps.action
      return self._critic_loss_weight * self.critic_loss(
          time_steps,
          actions,
          next_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights,
          training=True)

    def actor_loss_fn(experience, weights):
      # The following is just to fit to the existing API.
      transition = self._as_transition(experience)
      time_steps, _, _ = transition
      return self._actor_loss_weight*self.actor_loss(
          time_steps, weights=weights, training=True)

    # Create mask updaters; `get_mask_updater` may return None (e.g. for
    # dense training), which is handled in `_train`.
    self._mask_updater_critic_1 = mask_updaters.get_mask_updater(
        self._critic_network_1, self._critic_optimizer, critic_loss_fn)
    self._mask_updater_critic_2 = mask_updaters.get_mask_updater(
        self._critic_network_2, self._critic_optimizer, critic_loss_fn)
    self._mask_updater_actor = mask_updaters.get_mask_updater(
        self._actor_network, self._actor_optimizer, actor_loss_fn)

  def _train(self, experience, weights):
    """Returns a train op to update the agent's networks.

    This method trains with the provided batched experience.

    Args:
      experience: A time-stacked trajectory object.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      A train_op.

    Raises:
      ValueError: If optimizers are None and no default value was provided to
        the constructor.
    """
    tf.summary.experimental.set_step(self.train_step_counter)
    transition = self._as_transition(experience)
    time_steps, policy_steps, next_time_steps = transition
    actions = policy_steps.action
    # The two critics may share variables; deduplicate by object identity.
    trainable_critic_variables = list(object_identity.ObjectIdentitySet(
        self._critic_network_1.trainable_variables +
        self._critic_network_2.trainable_variables))
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_critic_variables, ('No trainable critic variables to '
                                          'optimize.')
      tape.watch(trainable_critic_variables)
      critic_loss = self._critic_loss_weight*self.critic_loss(
          time_steps,
          actions,
          next_time_steps,
          td_errors_loss_fn=self._td_errors_loss_fn,
          gamma=self._gamma,
          reward_scale_factor=self._reward_scale_factor,
          weights=weights,
          training=True)
    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
    self._apply_gradients(critic_grads, trainable_critic_variables,
                          self._critic_optimizer)
    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self._actor_loss_weight*self.actor_loss(
          time_steps, weights=weights, training=True)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)
    # BEGIN sparse training mask update
    # We use the latest set of gradients to update the masks for sparse
    # training. Note, we do this before gradient clipping.
    # Define helper methods.
    def _mask_update_step(mask_updater, updater_name):
      mask_updater.set_validation_data(experience, weights)
      mask_updater.update(self.train_step_counter)
      with tf.name_scope('Drop_fraction/'):
        tf.summary.scalar(
            name=f'{updater_name}',
            data=mask_updater.last_drop_fraction)
    mask_update_step_critic_1 = functools.partial(_mask_update_step,
                                                  self._mask_updater_critic_1,
                                                  'critic_1')
    mask_update_step_critic_2 = functools.partial(_mask_update_step,
                                                  self._mask_updater_critic_2,
                                                  'critic_2')
    mask_update_step_actor = functools.partial(_mask_update_step,
                                               self._mask_updater_actor,
                                               'actor')
    # Log sparsities every 1000 train steps.
    def _log_sparsities():
      tf_sparse_utils.log_sparsities(self._critic_network_1, 'critic_1')
      tf_sparse_utils.log_sparsities(self._critic_network_2, 'critic_2')
      tf_sparse_utils.log_sparsities(self._actor_network, 'actor')
      tf_sparse_utils.log_total_params(
          [self._critic_network_1,
           self._critic_network_2,
           self._actor_network])
    tf.cond(self.train_step_counter % 1000 == 0, _log_sparsities, lambda: None)
    # Update critics' masks only on their scheduled update iterations.
    if self._mask_updater_critic_1 is not None:
      is_update_critic_1 = self._mask_updater_critic_1.is_update_iter(
          self.train_step_counter)
      tf.cond(is_update_critic_1, mask_update_step_critic_1, lambda: None)
    if self._mask_updater_critic_2 is not None:
      is_update_critic_2 = self._mask_updater_critic_2.is_update_iter(
          self.train_step_counter)
      tf.cond(is_update_critic_2, mask_update_step_critic_2, lambda: None)
    # Update actor's mask.
    if self._mask_updater_actor is not None:
      is_update_actor = self._mask_updater_actor.is_update_iter(
          self.train_step_counter)
      tf.cond(is_update_actor, mask_update_step_actor, lambda: None)
    # END sparse training mask update
    alpha_variable = [self._log_alpha]
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert alpha_variable, 'No alpha variable to optimize.'
      tape.watch(alpha_variable)
      alpha_loss = self._alpha_loss_weight * self.alpha_loss(
          time_steps, weights=weights, training=True)
    tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
    alpha_grads = tape.gradient(alpha_loss, alpha_variable)
    self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)
    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='alpha_loss', data=alpha_loss, step=self.train_step_counter)
    self.train_step_counter.assign_add(1)
    self._update_target()
    total_loss = critic_loss + actor_loss + alpha_loss
    extra = sac_agent.SacLossInfo(
        critic_loss=critic_loss, actor_loss=actor_loss, alpha_loss=alpha_loss)
    return tf_agent.LossInfo(loss=total_loss, extra=extra)
@gin.configurable
def train_eval(
    root_dir,
    strategy,
    env_name='HalfCheetah-v2',
    # Training params
    initial_collect_steps=10000,
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Agent params
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    gamma=0.99,
    target_update_tau=0.005,
    target_update_period=1,
    reward_scale_factor=0.1,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=10000,
    replay_buffer_save_interval=100000,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    sparse_output_layer = False,
    width = 1.0,
    train_mode_actor = 'dense',
    train_mode_value = 'dense',
    weight_decay = 0.0,
    actor_critic_sparsities_str = '',
    actor_critic_widths_str = ''):
  """Trains and evaluates SAC (optionally with sparse actor/critic networks).

  Builds the actor/critic networks, a `SparseSacAgent`, a Reverb replay
  buffer (with checkpointing), collect/eval actors, and a learner; then
  alternates 1 environment step of collection with 1 learner iteration for
  `num_iterations` iterations, evaluating every `eval_interval` train steps.

  Args:
    root_dir: Directory for checkpoints, saved policies, and summaries.
    strategy: A `tf.distribute.Strategy` for agent/learner construction.
    env_name: Mujoco gym environment name (requires --is_mujoco).
    initial_collect_steps: Steps collected with a random policy before
      training starts.
    num_iterations: Number of collect+train iterations.
    actor_fc_layers: Hidden FC layer units for the actor network.
    critic_obs_fc_layers: Observation-only FC units for the critic, or None.
    critic_action_fc_layers: Action-only FC units for the critic, or None.
    critic_joint_fc_layers: Joint FC units for the critic, or None.
    batch_size: Training batch size sampled from the replay buffer.
    actor_learning_rate: Adam learning rate for the actor.
    critic_learning_rate: Adam learning rate for the critics.
    alpha_learning_rate: Adam learning rate for the entropy coefficient.
    gamma: Discount factor.
    target_update_tau: Soft target-update coefficient.
    target_update_period: Steps between target-network updates.
    reward_scale_factor: Multiplier applied to environment rewards.
    reverb_port: Port for the Reverb server; None picks an unused port.
    replay_capacity: Maximum number of items in the replay table.
    policy_save_interval: Train steps between saved-policy exports.
    replay_buffer_save_interval: Train steps between Reverb checkpoints.
    eval_interval: Train steps between evaluations (0/None disables eval).
    eval_episodes: Episodes per evaluation.
    debug_summaries: Whether the agent emits debug summaries.
    summarize_grads_and_vars: Whether the agent summarizes grads/variables.
    sparse_output_layer: Whether output layers are sparsified too.
    width: Default width multiplier for both actor and critic.
    train_mode_actor: 'dense' or 'sparse' for the actor network.
    train_mode_value: 'dense' or 'sparse' for the critic network.
    weight_decay: L2 regularization coefficient for all kernels.
    actor_critic_sparsities_str: Optional '<actor>_<critic>' sparsity pair,
      e.g. '0.9_0.8'; empty string defers to init_masks defaults.
    actor_critic_widths_str: Optional '<actor>_<critic>' width pair; empty
      string uses `width` for both.
  """
  assert FLAGS.is_mujoco
  # Parse per-network widths ('<actor>_<critic>') or fall back to `width`.
  if actor_critic_widths_str:
    actor_critic_widths = [float(s) for s in actor_critic_widths_str.split('_')]
    width_actor = actor_critic_widths[0]
    width_value = actor_critic_widths[1]
  else:
    width_actor = width
    width_value = width
  # Parse per-network sparsities ('<actor>_<critic>').
  if actor_critic_sparsities_str:
    actor_critic_sparsities = [
        float(s) for s in actor_critic_sparsities_str.split('_')
    ]
  else:
    # init_mask.sparsity value will be used. Either the default or set via gin.
    actor_critic_sparsities = [None, None]
  logging.info('Training SAC on: %s', env_name)
  logging.info('SAC params: train mode actor: %s', train_mode_actor)
  logging.info('SAC params: train mode value: %s', train_mode_value)
  logging.info('SAC params: sparse_output_layer: %s', sparse_output_layer)
  logging.info('SAC params: width: %s', width)
  logging.info('SAC params: actor_critic_widths_str: %s',
               actor_critic_widths_str)
  logging.info('SAC params: width_actor: %s', width_actor)
  logging.info('SAC params: width_value: %s', width_value)
  logging.info('SAC params: weight_decay: %s', weight_decay)
  logging.info('SAC params: actor_critic_sparsities_str %s type %s',
               actor_critic_sparsities_str, type(actor_critic_sparsities_str))
  logging.info('SAC params: actor_sparsity: %s', actor_critic_sparsities[0])
  logging.info('SAC params: critic_sparsity: %s', actor_critic_sparsities[1])
  # Separate collect/eval environment instances.
  collect_env = suite_mujoco.load(env_name)
  eval_env = suite_mujoco.load(env_name)
  _, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))
  actor_net = create_sequential_actor_network(
      actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec,
      input_dim=time_step_tensor_spec.observation.shape[0],
      is_sparse=(train_mode_actor == 'sparse'),
      width=width_actor,
      weight_decay=weight_decay,
      sparse_output_layer=sparse_output_layer)
  # The critic consumes concatenated (observation, action) vectors.
  critic_input_dim = (
      action_tensor_spec.shape[0] + time_step_tensor_spec.observation.shape[0])
  critic_net = create_sequential_critic_network(
      obs_fc_layer_units=critic_obs_fc_layers,
      action_fc_layer_units=critic_action_fc_layers,
      joint_fc_layer_units=critic_joint_fc_layers,
      input_dim=critic_input_dim,
      is_sparse=(train_mode_value == 'sparse'),
      width=width_value,
      weight_decay=weight_decay,
      sparse_output_layer=sparse_output_layer)
  # Agent and train-step counter are created under the distribution strategy.
  with strategy.scope():
    train_step = train_utils.create_train_step()
    agent = SparseSacAgent(
        time_step_spec=time_step_tensor_spec,
        action_spec=action_tensor_spec,
        actor_sparsity=actor_critic_sparsities[0],
        critic_sparsity=actor_critic_sparsities[1],
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.keras.optimizers.Adam(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.keras.optimizers.Adam(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()
  # Uniform-sampling replay table backed by a checkpointed Reverb server.
  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,
                                       learner.REPLAY_BUFFER_CHECKPOINT_DIR)
  reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(
      path=reverb_checkpoint_dir)
  reverb_server = reverb.Server([table],
                                port=reverb_port,
                                checkpointer=reverb_checkpointer)
  # sequence_length=2 stores (t, t+1) transitions for 1-step TD learning.
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=2,
      stride_length=1)
  def experience_dataset_fn():
    return reverb_replay.as_dataset(
        sample_batch_size=batch_size, num_steps=2).prefetch(50)
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()
  # Triggers run inside the learner: policy export, Reverb checkpointing,
  # and steps/sec logging.
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.ReverbCheckpointTrigger(
          train_step,
          interval=replay_buffer_save_interval,
          reverb_client=reverb_replay.py_client),
      triggers.StepPerSecondLogTrigger(train_step, interval=1000),
  ]
  agent_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers,
      strategy=strategy)
  # Seed the replay buffer with random-policy experience.
  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()
  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)
  # One environment step per training iteration.
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      observers=[rb_observer, env_step_metric])
  # Greedy (deterministic) policy for evaluation.
  tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)
  eval_actor = actor.Actor(
      eval_env,
      eval_greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      summary_dir=os.path.join(root_dir, 'eval'),
  )
  average_returns = []
  if eval_interval:
    # Initial evaluation before any training.
    logging.info('Evaluating.')
    eval_actor.run_and_log()
    for metric in eval_actor.metrics:
      if isinstance(metric, py_metrics.AverageReturnMetric):
        average_returns.append(metric._buffer.mean())
  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    agent_learner.run(iterations=1)
    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()
      for metric in eval_actor.metrics:
        if isinstance(metric, py_metrics.AverageReturnMetric):
          average_returns.append(metric._buffer.mean())
  # Log last section of evaluation scores for the final metric.
  # NOTE(review): when average_last_fraction * len(average_returns) < 1,
  # idx is 0 and `[-0:]` averages ALL scores, not the last fraction —
  # confirm this edge case is intended for short runs.
  idx = int(FLAGS.average_last_fraction * len(average_returns))
  avg_return = np.mean(average_returns[-idx:])
  logging.info('Step %d, Average Return: %f', env_step_metric.result(),
               avg_return)
  rb_observer.close()
  reverb_server.stop()
def main(_):
  """Entry point: builds the strategy, parses gin, then runs `train_eval`."""
  tf.config.run_functions_eagerly(False)
  logging.set_verbosity(logging.INFO)
  tf.compat.v1.enable_v2_behavior()
  # Distribution strategy is selected from the TPU/GPU flags.
  strategy = strategy_utils.get_strategy(tpu=FLAGS.tpu, use_gpu=FLAGS.use_gpu)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)
  logging.info('Gin bindings: %s', FLAGS.gin_bindings)
  logging.info('# Gin-Config:\n %s', gin.config.operative_config_str())
  train_eval(
      root_dir=FLAGS.root_dir,
      strategy=strategy,
      reverb_port=FLAGS.reverb_port)
if __name__ == '__main__':
  # `root_dir` has no usable default; fail fast if the caller omits it.
  flags.mark_flag_as_required('root_dir')
  app.run(main)
================================================
FILE: rigl/rl/tfagents/sparse_encoding_network.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras Encoding Network.
Implements a network that will generate the following layers:
[optional]: preprocessing_layers # preprocessing_layers
[optional]: (Add | Concat(axis=-1) | ...) # preprocessing_combiner
[optional]: Conv2D # conv_layer_params
Flatten
[optional]: Dense # fc_layer_params
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import gin
from rigl.rl.tfagents import tf_sparse_utils
from six.moves import zip
import tensorflow as tf
from tf_agents.keras_layers import permanent_variable_rate_dropout
from tf_agents.networks import network
from tf_agents.networks import utils
from tf_agents.utils import nest_utils
# Identifiers for the `conv_type` argument selecting 2-D vs 1-D convolutions.
CONV_TYPE_2D = '2d'
CONV_TYPE_1D = '1d'
def _copy_layer(layer):
  """Create a copy of a Keras layer with identical parameters.

  The new layer will not share weights with the old one.

  Args:
    layer: An instance of `tf.keras.layers.Layer`.

  Returns:
    A new keras layer.

  Raises:
    TypeError: If `layer` is not a keras layer.
    ValueError: If `layer` cannot be correctly cloned.
  """
  if not isinstance(layer, tf.keras.layers.Layer):
    raise TypeError('layer is not a keras layer: %s' % str(layer))
  # The V1 DenseFeatures layer cannot be cloned; the exact class must be
  # checked, not the isinstance hierarchy.
  # pylint:disable=unidiomatic-typecheck
  if type(layer) == tf.compat.v1.keras.layers.DenseFeatures:
    raise ValueError('DenseFeatures V1 is not supported. '
                     'Use tf.compat.v2.keras.layers.DenseFeatures instead.')
  if layer.built:
    logging.warning(
        'Beware: Copying a layer that has already been built: \'%s\'. '
        'This can lead to subtle bugs because the original layer\'s weights '
        'will not be used in the copy.', layer.name)
  # Rebuild from config so the copy owns fresh (unshared) weights and the
  # incoming layer is left untouched.
  layer_cls = type(layer)
  layer_config = layer.get_config()
  return layer_cls.from_config(layer_config)
@gin.configurable
class EncodingNetwork(network.Network):
  """Feed Forward network with CNN and FNN layers."""

  def __init__(self,
               input_tensor_spec,
               preprocessing_layers=None,
               preprocessing_combiner=None,
               conv_layer_params=None,
               fc_layer_params=None,
               dropout_layer_params=None,
               activation_fn=tf.keras.activations.relu,
               weight_decay_params=None,
               kernel_initializer=None,
               batch_squash=True,
               dtype=tf.float32,
               name='EncodingNetwork',
               conv_type=CONV_TYPE_2D,
               width=1.0):
    """Creates an instance of `EncodingNetwork`.

    Network supports calls with shape outer_rank + input_tensor_spec.shape.
    Note outer_rank must be at least 1.

    For example an input tensor spec with shape `(2, 3)` will require
    inputs with at least a batch size, the input shape is `(?, 2, 3)`.

    Input preprocessing is possible via `preprocessing_layers` and
    `preprocessing_combiner` Layers. If the `preprocessing_layers` nest is
    shallower than `input_tensor_spec`, then the layers will get the subnests.
    For example, if:

    ```python
    input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5)
    preprocessing_layers = (Layer1(), Layer2())
    ```

    then preprocessing will call:

    ```python
    preprocessed = [preprocessing_layers[0](observations[0]),
                    preprocessing_layers[1](observations[1])]
    ```

    However if

    ```python
    preprocessing_layers = ([Layer1() for _ in range(2)],
                            [Layer2() for _ in range(5)])
    ```

    then preprocessing will call:

    ```python
    preprocessed = [
        layer(obs) for layer, obs in zip(flatten(preprocessing_layers),
                                         flatten(observations))
    ]
    ```

    **NOTE** `preprocessing_layers` and `preprocessing_combiner` are not
    allowed to have already been built. This ensures calls to
    `network.copy()` in the future always have an unbuilt, fresh set of
    parameters. Furthermore, a shallow copy of the layers is always created
    by the Network, so the layer objects passed to the network are never
    modified. For more details of the semantics of `copy`, see the docstring
    of `tf_agents.networks.Network.copy`.

    Args:
      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the
        input observations.
      preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`
        representing preprocessing for the different observations. All of
        these layers must not be already built.
      preprocessing_combiner: (Optional.) A keras layer that takes a flat list
        of tensors and combines them. Good options include
        `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`. This
        layer must not be already built.
      conv_layer_params: Optional list of convolution layers parameters, where
        each item is either a length-three tuple indicating
        `(filters, kernel_size, stride)` or a length-four tuple indicating
        `(filters, kernel_size, stride, dilation_rate)`.
      fc_layer_params: Optional list of fully_connected parameters, where each
        item is the number of units in the layer.
      dropout_layer_params: Optional list of dropout layer parameters, each
        item is the fraction of input units to drop or a dictionary of
        parameters according to the keras.Dropout documentation. The
        additional parameter `permanent`, if set to True, allows to apply
        dropout at inference for approximated Bayesian inference. The dropout
        layers are interleaved with the fully connected layers; there is a
        dropout layer after each fully connected layer, except if the entry in
        the list is None. This list must have the same length of
        fc_layer_params, or be None.
      activation_fn: Activation function, e.g. tf.keras.activations.relu.
      weight_decay_params: Optional list of weight decay parameters for the
        fully connected layers. The first entry, if any, is also reused for
        all convolution layers.
      kernel_initializer: Initializer to use for the kernels of the conv and
        dense layers. If none is provided a default variance_scaling_initializer
      batch_squash: If True the outer_ranks of the observation are squashed
        into the batch dimension. This allow encoding networks to be used with
        observations with shape [BxTx...].
      dtype: The dtype to use by the convolution and fully connected layers.
      name: A string representing name of the network.
      conv_type: string, '1d' or '2d'. Convolution layers will be 1d or 2D
        respectively
      width: Scaling factor to apply to the layers.

    Raises:
      ValueError: If any of `preprocessing_layers` is already built.
      ValueError: If `preprocessing_combiner` is already built.
      ValueError: If the number of dropout layer parameters does not match the
        number of fully connected layer parameters.
      ValueError: If conv_layer_params tuples do not have 3 or 4 elements each.
    """
    self._width = width
    # Preprocessing layers are not supported in this sparse variant; the
    # attribute is retained only so `call` can assert the invariant.
    flat_preprocessing_layers = None
    if (len(tf.nest.flatten(input_tensor_spec)) > 1 and
        preprocessing_combiner is None):
      raise ValueError(
          'preprocessing_combiner layer is required when more than 1 '
          'input_tensor_spec is provided.')

    if preprocessing_combiner is not None:
      # Copy so the caller's layer object is never built or modified in place.
      preprocessing_combiner = _copy_layer(preprocessing_combiner)

    if not kernel_initializer:
      kernel_initializer = tf.compat.v1.variance_scaling_initializer(
          scale=2.0, mode='fan_in', distribution='truncated_normal')

    layers = []

    if conv_layer_params:
      if conv_type == '2d':
        conv_layer_type = tf.keras.layers.Conv2D
      elif conv_type == '1d':
        conv_layer_type = tf.keras.layers.Conv1D
      else:
        raise ValueError('unsupported conv type of %s. Use 1d or 2d' % (
            conv_type))

      for config in conv_layer_params:
        if len(config) == 4:
          (filters, kernel_size, strides, dilation_rate) = config
        elif len(config) == 3:
          (filters, kernel_size, strides) = config
          dilation_rate = (1, 1) if conv_type == '2d' else (1,)
        else:
          raise ValueError(
              'only 3 or 4 elements permitted in conv_layer_params tuples')

        # We use the first weight decay param for all conv layers. Guard
        # against `weight_decay_params` being None/empty, since the argument
        # is documented as optional.
        weight_decay = weight_decay_params[0] if weight_decay_params else None
        kernel_regularizer = None
        if weight_decay is not None:
          kernel_regularizer = tf.keras.regularizers.l2(weight_decay)
        filters = tf_sparse_utils.scale_width(filters, self._width)
        layers.append(
            conv_layer_type(
                filters=filters,
                kernel_size=kernel_size,
                strides=strides,
                dilation_rate=dilation_rate,
                activation=activation_fn,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                dtype=dtype))

    layers.append(tf.keras.layers.Flatten())

    if fc_layer_params:
      if dropout_layer_params is None:
        dropout_layer_params = [None] * len(fc_layer_params)
      else:
        if len(dropout_layer_params) != len(fc_layer_params):
          raise ValueError('Dropout and fully connected layer parameter lists'
                           ' have different lengths (%d vs. %d.)' %
                           (len(dropout_layer_params), len(fc_layer_params)))
      if weight_decay_params is None:
        weight_decay_params = [None] * len(fc_layer_params)
      else:
        if len(weight_decay_params) != len(fc_layer_params):
          raise ValueError('Weight decay and fully connected layer parameter '
                           'lists have different lengths (%d vs. %d.)' %
                           (len(weight_decay_params), len(fc_layer_params)))

      for num_units, dropout_params, weight_decay in zip(
          fc_layer_params, dropout_layer_params, weight_decay_params):
        kernel_regularizer = None
        if weight_decay is not None:
          kernel_regularizer = tf.keras.regularizers.l2(weight_decay)
        layers.append(
            tf.keras.layers.Dense(
                tf_sparse_utils.scale_width(num_units, self._width),
                activation=activation_fn,
                kernel_initializer=kernel_initializer,
                kernel_regularizer=kernel_regularizer,
                dtype=dtype))
        # Normalize a scalar dropout rate into the kwargs dict expected by
        # the dropout layer; a falsy entry means no dropout after this layer.
        if not isinstance(dropout_params, dict):
          dropout_params = {'rate': dropout_params} if dropout_params else None
        if dropout_params is not None:
          layers.append(
              permanent_variable_rate_dropout.PermanentVariableRateDropout(
                  **dropout_params))

    super(EncodingNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec, state_spec=(), name=name)

    # Pull out the nest structure of the preprocessing layers. This avoids
    # saving the original kwarg layers as a class attribute which Keras would
    # then track.
    self._preprocessing_nest = tf.nest.map_structure(lambda l: None,
                                                     preprocessing_layers)
    self._flat_preprocessing_layers = flat_preprocessing_layers
    self._preprocessing_combiner = preprocessing_combiner
    self._postprocessing_layers = layers
    self._batch_squash = batch_squash
    self.built = True  # Allow access to self.variables

  def call(self, observation, step_type=None, network_state=(), training=False):
    """Encodes `observation`; returns `(encoding, network_state)`."""
    del step_type  # unused.
    if self._batch_squash:
      # Fold outer dimensions (e.g. time) into the batch dimension so the
      # Keras layers only ever see `1 + spec_rank` inputs.
      outer_rank = nest_utils.get_outer_rank(
          observation, self.input_tensor_spec)
      batch_squash = utils.BatchSquash(outer_rank)
      observation = tf.nest.map_structure(batch_squash.flatten, observation)
    if self._flat_preprocessing_layers is None:
      processed = observation
    else:
      raise ValueError('Flat preprocessing layers should be None.')
    states = processed
    if self._preprocessing_combiner is not None:
      states = self._preprocessing_combiner(states)
    for layer in self._postprocessing_layers:
      states = layer(states, training=training)
    if self._batch_squash:
      states = tf.nest.map_structure(batch_squash.unflatten, states)
    return states, network_state
================================================
FILE: rigl/rl/tfagents/sparse_ppo_actor_network.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Sequential Actor Network for PPO."""
import sys
import numpy as np
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tf_agents.keras_layers import bias_layer
from tf_agents.networks import nest_map
from tf_agents.networks import sequential
def tanh_and_scale_to_spec(inputs, spec):
  """Maps inputs with arbitrary range to range defined by spec using `tanh`."""
  # tanh maps to (-1, 1); shift by the interval midpoint and scale by the
  # half-width to land inside [spec.minimum, spec.maximum].
  midpoint = (spec.maximum + spec.minimum) / 2.0
  half_range = (spec.maximum - spec.minimum) / 2.0
  return midpoint + half_range * tf.tanh(inputs)
class PPOActorNetwork():
  """Contains the actor network structure."""

  def __init__(self,
               seed_stream_class=tfp.util.SeedStream,
               is_sparse=False,
               sparse_output_layer=False,
               weight_decay=0.0,
               width=1.0):
    """Creates the builder for a (possibly sparse) continuous PPO actor net.

    Args:
      seed_stream_class: Class used to derive per-layer seeds for the kernel
        initializers.
      is_sparse: If True, layers are wrapped for sparse training via
        `tf_sparse_utils.wrap_all_layers`.
      sparse_output_layer: If True (and `is_sparse`), the final means
        projection layer is wrapped as well; otherwise it stays dense.
      weight_decay: L2 regularization coefficient applied to Dense kernels.
      width: Multiplicative scaling factor for the hidden layer widths.
    """
    self.seed_stream_class = seed_stream_class
    self._is_sparse = is_sparse
    self._sparse_output_layer = sparse_output_layer
    self._weight_decay = weight_decay
    self._width = width

  def create_sequential_actor_net(self,
                                  fc_layer_units,
                                  action_tensor_spec,
                                  input_dim,
                                  seed=None):
    """Helper method for creating the actor network.

    Args:
      fc_layer_units: Iterable of hidden Dense layer sizes (before width
        scaling).
      action_tensor_spec: Bounded spec of the continuous action; sets the
        output size and the tanh-squashing range for the means.
      input_dim: Input dimensionality handed to the sparse wrapper to size
        per-layer masks. Presumably the flat observation dimension --
        confirm against callers.
      seed: Optional integer seed for the seed stream.

    Returns:
      A `sequential.Sequential` network whose output is a
      `tfp.distributions.MultivariateNormalDiag`.
    """
    self._seed_stream = self.seed_stream_class(
        seed=seed, salt='tf_agents_sequential_layers')

    def _get_seed():
      # Keep the stream's output within the platform integer range expected
      # by the Keras initializers.
      seed = self._seed_stream()
      if seed is not None:
        seed = seed % sys.maxsize
      return seed

    def create_dist(loc_and_scale):
      # Squash the raw means into the action spec's bounds; softplus keeps
      # the scale strictly positive.
      loc = loc_and_scale['loc']
      loc = tanh_and_scale_to_spec(loc, action_tensor_spec)
      scale = loc_and_scale['scale']
      scale = tf.math.softplus(scale)
      return tfp.distributions.MultivariateNormalDiag(
          loc=loc, scale_diag=scale, validate_args=True)

    def means_layers():
      # Projection from the last hidden layer to the raw action means.
      layer = tf.keras.layers.Dense(
          action_tensor_spec.shape.num_elements(),
          kernel_initializer=tf.keras.initializers.VarianceScaling(
              scale=0.1, seed=_get_seed()),
          kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay),
          name='means_projection_layer')
      return layer

    def std_layers():
      # Learnable bias producing the pre-softplus stddev. The initializer is
      # the inverse softplus of 0.35, so the initial stddev after softplus
      # in `create_dist` is 0.35.
      std_bias_initializer_value = np.log(np.exp(0.35) - 1)
      return bias_layer.BiasLayer(
          bias_initializer=tf.constant_initializer(
              value=std_bias_initializer_value))

    def no_op_layers():
      # Identity passthrough for the 'loc' branch of the NestMap below.
      return tf.keras.layers.Lambda(lambda x: x)

    def dense_layer(num_units):
      layer = tf.keras.layers.Dense(
          tf_sparse_utils.scale_width(num_units, self._width),
          activation=tf.nn.tanh,
          kernel_initializer=tf.keras.initializers.Orthogonal(seed=_get_seed()),
          kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay),
      )
      return layer

    all_layers = [dense_layer(n) for n in fc_layer_units]
    all_layers.append(means_layers())
    if self._is_sparse:
      if self._sparse_output_layer:
        # Wrap every layer, including the means projection.
        all_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim)
      else:
        # Wrap only the hidden layers; leave the means projection dense.
        new_layers = tf_sparse_utils.wrap_all_layers(all_layers[:-1], input_dim)
        all_layers = new_layers + all_layers[-1:]
    return sequential.Sequential(
        all_layers +
        # Fan the encoding out into a {'loc', 'scale'} dict; 'scale' starts
        # at zero and gets its value solely from the BiasLayer below.
        [tf.keras.layers.Lambda(
            lambda x: {'loc': x, 'scale': tf.zeros_like(x)})] +
        [nest_map.NestMap({
            'loc': no_op_layers(),
            'scale': std_layers(),
        })] +
        # Create the output distribution from the mean and standard deviation.
        [tf.keras.layers.Lambda(create_dist)])
================================================
FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Sparse Discrete Sequential Actor Network for PPO."""
import functools
import sys
import numpy as np
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tf_agents.networks import sequential
from tf_agents.specs import distribution_spec
from tf_agents.specs import tensor_spec
def tanh_and_scale_to_spec(inputs, spec):
  """Maps inputs with arbitrary range to range defined by spec using `tanh`."""
  # Shift by the interval center, then scale the (-1, 1) tanh output by the
  # full interval width divided by two.
  center = (spec.maximum + spec.minimum) / 2.0
  span = spec.maximum - spec.minimum
  return center + (span * tf.tanh(inputs)) / 2.0
class PPODiscreteActorNetwork():
  """Contains the actor network structure."""

  def __init__(self, seed_stream_class=tfp.util.SeedStream,
               is_sparse=False,
               sparse_output_layer=False,
               weight_decay=0,
               width=1.0):
    """Creates the builder for a discrete-action PPO actor network.

    Args:
      seed_stream_class: Class used to derive per-layer seeds for the kernel
        initializers.
      is_sparse: Whether to wrap layers for sparse training. Currently
        unsupported here and must be False.
      sparse_output_layer: Whether the logits layer would also be wrapped
        when sparse training is enabled.
      weight_decay: L2 regularization coefficient applied to Dense kernels.
      width: Multiplicative scaling factor for the hidden layer widths.

    Raises:
      ValueError: If `is_sparse` is True (wrap_all_layers support is not
        implemented for this network).
    """
    if is_sparse:
      raise ValueError('This functionality is not enabled. wrap_all_layers,'
                       'functionality needs to be implemented')
    self.seed_stream_class = seed_stream_class
    # Sparse params.
    self._is_sparse = is_sparse
    self._sparse_output_layer = sparse_output_layer
    self._width = width
    self._weight_decay = weight_decay

  def create_sequential_actor_net(self,
                                  fc_layer_units,
                                  action_tensor_spec,
                                  logits_init_output_factor=0.1,
                                  seed=None):
    """Helper method for creating the actor network.

    Args:
      fc_layer_units: Iterable of hidden Dense layer sizes (before width
        scaling).
      action_tensor_spec: Bounded spec of the discrete action. All action
        dimensions must span the same number of actions.
      logits_init_output_factor: Variance scale used to initialize the
        logits kernel.
      seed: Optional integer seed for the seed stream.

    Returns:
      A `sequential.Sequential` network whose output is a
      `tfp.distributions.Categorical` built from the logits.

    Raises:
      ValueError: If action bounds differ across dimensions or admit fewer
        than one action.
    """
    self._seed_stream = self.seed_stream_class(
        seed=seed, salt='tf_agents_sequential_layers')
    # action_tensor_spec is a BoundedArraySpec which is an array with defined
    # bounds. Maximum and minimum are arrays with the same shape as the
    # main array.
    unique_num_actions = np.unique(action_tensor_spec.maximum -
                                   action_tensor_spec.minimum + 1)
    if len(unique_num_actions) > 1 or np.any(unique_num_actions <= 0):
      raise ValueError('Bounds on discrete actions must be the same for all '
                       'dimensions and have at least 1 action. Projection '
                       'Network requires num_actions to be equal across '
                       'action dimensions. Implement a more general '
                       'categorical projection if you need more flexibility.')
    # One logit per (action dimension, action) pair.
    output_shape = action_tensor_spec.shape.concatenate(
        [int(unique_num_actions)])

    def _get_seed():
      # Keep the stream's output within the platform integer range expected
      # by the Keras initializers.
      seed = self._seed_stream()
      if seed is not None:
        seed = seed % sys.maxsize
      return seed

    def create_dist(logits):
      # Wrap the raw logits into a Categorical distribution that samples
      # values matching the action spec.
      input_param_spec = {
          'logits': tensor_spec.TensorSpec(
              shape=(1,) + output_shape, dtype=tf.float32)
      }
      dist_spec = distribution_spec.DistributionSpec(
          tfp.distributions.Categorical,
          input_param_spec,
          sample_spec=action_tensor_spec,
          dtype=action_tensor_spec.dtype)
      logits = tf.reshape(logits, [-1] + output_shape.as_list())
      return dist_spec.build_distribution(logits=logits)

    def dense_layer(num_units):
      dense = functools.partial(
          tf.keras.layers.Dense,
          activation=tf.nn.tanh,
          kernel_initializer=tf.keras.initializers.Orthogonal(seed=_get_seed()),
          kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay))
      layer = dense(tf_sparse_utils.scale_width(num_units, self._width))
      if self._is_sparse:
        # NOTE(review): currently unreachable -- __init__ rejects
        # is_sparse=True.
        return tf_sparse_utils.wrap_layer(layer)
      else:
        return layer

    output_layer = tf.keras.layers.Dense(
        output_shape.num_elements(),
        kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
            scale=logits_init_output_factor, seed=_get_seed()),
        kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay),
        bias_initializer=tf.keras.initializers.Zeros(),
        name='logits',
        dtype=tf.float32)
    if self._is_sparse and self._sparse_output_layer:
      # NOTE(review): currently unreachable -- __init__ rejects
      # is_sparse=True.
      output_layer = tf_sparse_utils.wrap_layer(output_layer)
    return sequential.Sequential(
        [dense_layer(num_units) for num_units in fc_layer_units] +
        [output_layer] +
        [tf.keras.layers.Lambda(create_dist)])
================================================
FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for sparse_ppo_discrete_actor_network."""
from absl import flags
from absl.testing import parameterized
from rigl.rl.tfagents import sparse_ppo_discrete_actor_network
import tensorflow as tf
from tf_agents.distributions import utils as distribution_utils
from tf_agents.specs import tensor_spec
from tf_agents.utils import test_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
FLAGS = flags.FLAGS
class DeterministicSeedStream(object):
  """A fake seed stream class that always generates a deterministic seed."""

  def __init__(self, seed, salt=''):
    # The salt is accepted only for interface compatibility with
    # tfp.util.SeedStream; it has no effect here.
    del salt
    self._seed = seed

  def __call__(self):
    # Every call yields the identical seed, making tests reproducible.
    return self._seed
class PpoActorNetworkTest(parameterized.TestCase, test_utils.TestCase):
  """Tests for PPODiscreteActorNetwork construction and configuration."""

  def setUp(self):
    super(PpoActorNetworkTest, self).setUp()
    # Run in full eager mode in order to inspect the content of tensors.
    tf.config.experimental_run_functions_eagerly(True)
    self.observation_tensor_spec = tf.TensorSpec(shape=[3], dtype=tf.float32)
    self.action_tensor_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3)

  def tearDown(self):
    # Restore default (non-eager) function execution for other tests.
    tf.config.experimental_run_functions_eagerly(False)
    super(PpoActorNetworkTest, self).tearDown()

  def _init_network(
      self, is_sparse=False, sparse_output_layer=False,
      width=1.0, weight_decay=0):
    # Builds a one-hidden-layer (1 unit) actor net with a deterministic seed
    # stream so that layer construction is reproducible across test cases.
    actor_net_lib = sparse_ppo_discrete_actor_network.PPODiscreteActorNetwork(
        is_sparse=is_sparse, sparse_output_layer=sparse_output_layer,
        width=width, weight_decay=weight_decay)
    actor_net_lib.seed_stream_class = DeterministicSeedStream
    return actor_net_lib.create_sequential_actor_net(
        fc_layer_units=(1,), action_tensor_spec=self.action_tensor_spec, seed=1)

  def test_no_mismatched_shape(self):
    # The network's output distribution spec must be compatible with the
    # bounded action spec it was built from.
    actor_net = self._init_network()
    actor_output_spec = actor_net.create_variables(self.observation_tensor_spec)
    distribution_utils.assert_specs_are_compatible(
        actor_output_spec, self.action_tensor_spec,
        'actor_network output spec does not match action spec')

  # NOTE(review): PPODiscreteActorNetwork.__init__ currently raises
  # ValueError for is_sparse=True, which appears to conflict with the
  # sparse cases parameterized below -- confirm these are not stale.
  @parameterized.named_parameters(
      ('dense-output-F', False, False,
       (tf.keras.layers.Dense, tf.keras.layers.Dense)),
      ('dense-output-T', False, True,
       (tf.keras.layers.Dense, tf.keras.layers.Dense)),
      ('sparse-all', True, True,
       (pruning_wrapper.PruneLowMagnitude, pruning_wrapper.PruneLowMagnitude)),
      ('sparse-outp-dense', True, False,
       (pruning_wrapper.PruneLowMagnitude, tf.keras.layers.Dense)),
  )
  def test_is_sparse(self, is_sparse, sparse_output_layer, expected_layers):
    # The hidden layer has 1 unit; the logits layer has 4 (actions 0..3).
    expected_units = (1, 4)
    actor_net = self._init_network(
        is_sparse=is_sparse, sparse_output_layer=sparse_output_layer)
    for i, (expected_layer, exp_units) in enumerate(
        zip(expected_layers, expected_units)):
      layer = actor_net.layers[i]
      self.assertIsInstance(layer, expected_layer)
      if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
        # Pruning wrappers hold the real layer in `.layer`.
        self.assertEqual(layer.layer.units, exp_units)
      else:
        self.assertEqual(layer.units, exp_units)

  def test_width_scaling(self):
    # width=2.0 doubles the hidden layer (1 -> 2) but leaves the logits
    # layer at 4 units (it is sized by the action spec, not by width).
    with self.subTest('dense'):
      actor_net = self._init_network(width=2.0)
      self.assertEqual(actor_net.layers[0].units, 2)
      self.assertEqual(actor_net.layers[1].units, 4)
    with self.subTest('sparse'):
      actor_net = self._init_network(
          is_sparse=True, sparse_output_layer=True, width=2.0)
      self.assertEqual(actor_net.layers[0].layer.units, 2)
      self.assertEqual(actor_net.layers[1].layer.units, 4)

  @parameterized.named_parameters(
      ('no-wd-d-d', False, False, 0),
      ('no-wd-s-d', True, False, 0),
      ('no-wd-s-s', True, True, 0),
      ('wd-d-d', False, False, 0.1),
      ('wd-s-d', True, False, 0.1),
      ('wd-s-s', True, True, 0.1))
  def test_weight_decay(self, is_sparse, sparse_output_layer,
                        expected_weight_decay):
    # Both the hidden and the logits layer must carry the configured L2
    # coefficient, regardless of sparsity wrapping.
    actor_net = self._init_network(is_sparse=is_sparse,
                                   sparse_output_layer=sparse_output_layer,
                                   weight_decay=expected_weight_decay)
    for i in range(2):
      layer = actor_net.layers[i]
      if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
        l2_weight_decay = layer.layer.kernel_regularizer.get_config()['l2']
      else:
        l2_weight_decay = layer.kernel_regularizer.get_config()['l2']
      self.assertAlmostEqual(l2_weight_decay, expected_weight_decay)
# Script entry point: run this module's test suite.
if __name__ == '__main__':
  tf.test.main()
================================================
FILE: rigl/rl/tfagents/sparse_tanh_normal_projection_network.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Project inputs to a tanh-squashed MultivariateNormalDiag distribution.
This network reproduces the Soft Actor-Critic reference implementation in:
https://github.com/rail-berkeley/softlearning/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Callable, Optional, Text
import gin
import tensorflow as tf
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.typing import types
@gin.configurable
class SparseTanhNormalProjectionNetwork(
    tanh_normal_projection_network.TanhNormalProjectionNetwork):
  """Generates a tanh-squashed MultivariateNormalDiag distribution.

  Note: Due to the nature of the `tanh` function, values near the spec bounds
  cannot be returned.
  """

  def __init__(self,
               sample_spec,
               activation_fn = None,
               std_transform = tf.exp,
               name = 'SparseTanhNormalProjectionNetwork',
               weight_decay=0.0):
    """Creates an instance of SparseTanhNormalProjectionNetwork.

    Args:
      sample_spec: A `tensor_spec.BoundedTensorSpec` detailing the shape and
        dtypes of samples pulled from the output distribution.
      activation_fn: Activation function to use in dense layer.
      std_transform: Transformation function to apply to the stddevs.
      name: A string representing name of the network.
      weight_decay: Weight decay for L2 regularization.
    """
    super(SparseTanhNormalProjectionNetwork, self).__init__(
        sample_spec=sample_spec,
        activation_fn=activation_fn,
        std_transform=std_transform,
        name=name)
    # Replace the base class's projection layer with one that carries L2
    # regularization. The layer emits `2 * num_elements` units -- presumably
    # the mean and (pre-transform) stddev parameters consumed by the base
    # class; confirm against TanhNormalProjectionNetwork.
    # NOTE(review): despite the class name, no sparsity wrapper is applied
    # here -- wrapping is presumably done by the caller; confirm.
    self._projection_layer = tf.keras.layers.Dense(
        sample_spec.shape.num_elements() * 2,
        activation=activation_fn,
        kernel_regularizer=tf.keras.regularizers.L2(weight_decay),
        name='projection_layer')
================================================
FILE: rigl/rl/tfagents/sparse_value_network.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample Keras Value Network.
Implements a network that will generate the following layers:
[optional]: preprocessing_layers # preprocessing_layers
[optional]: (Add | Concat(axis=-1) | ...) # preprocessing_combiner
[optional]: Conv2D # conv_layer_params
Flatten
[optional]: Dense # fc_layer_params
Dense -> 1 # Value output
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gin
from rigl.rl.tfagents import sparse_encoding_network
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow as tf
from tf_agents.networks import network
@gin.configurable
class ValueNetwork(network.Network):
  """Feed Forward value network. Reduces to 1 value output per batch item."""

  def __init__(self,
               input_tensor_spec,
               preprocessing_combiner=None,
               conv_layer_params=None,
               fc_layer_params=(75, 40),
               dropout_layer_params=None,
               weight_decay=0.0,
               activation_fn=tf.keras.activations.relu,
               kernel_initializer=None,
               batch_squash=True,
               dtype=tf.float32,
               name='ValueNetwork',
               is_sparse=False,
               sparse_output_layer=False,
               width=1.0):
    """Creates an instance of `ValueNetwork`.

    Network supports calls with shape outer_rank + observation_spec.shape.
    Note outer_rank must be at least 1.

    Args:
      input_tensor_spec: A `tensor_spec.TensorSpec` or a tuple of specs
        representing the input observations.
      preprocessing_combiner: (Optional.) A keras layer that takes a flat list
        of tensors and combines them. Good options include
        `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`.
        This layer must not be already built. For more details see
        the documentation of `networks.EncodingNetwork`.
      conv_layer_params: Optional list of convolution layers parameters, where
        each item is a length-three tuple indicating (filters, kernel_size,
        stride).
      fc_layer_params: Optional list of fully_connected parameters, where each
        item is the number of units in the layer.
      dropout_layer_params: Optional list of dropout layer parameters, each item
        is the fraction of input units to drop or a dictionary of parameters
        according to the keras.Dropout documentation. The additional parameter
        `permanent`, if set to True, allows to apply dropout at inference for
        approximated Bayesian inference. The dropout layers are interleaved with
        the fully connected layers; there is a dropout layer after each fully
        connected layer, except if the entry in the list is None. This list must
        have the same length of fc_layer_params, or be None.
      weight_decay: L2 weight decay regularization parameter.
      activation_fn: Activation function, e.g. tf.keras.activations.relu.
      kernel_initializer: Initializer to use for the kernels of the conv and
        dense layers. If none is provided a default variance_scaling_initializer
      batch_squash: If True the outer_ranks of the observation are squashed into
        the batch dimension. This allow encoding networks to be used with
        observations with shape [BxTx...].
      dtype: The dtype to use by the convolution and fully connected layers.
      name: A string representing name of the network.
      is_sparse: Whether the network is sparse.
      sparse_output_layer: Whether the output layer should be sparse. Only
        applied when is_sparse=True.
      width: Scaling factor to apply to the layers.

    Raises:
      ValueError: If input_tensor_spec is not an instance of network.InputSpec.
    """
    super(ValueNetwork, self).__init__(
        input_tensor_spec=input_tensor_spec,
        state_spec=(),
        name=name)
    self._is_sparse = is_sparse
    self._sparse_output_layer = sparse_output_layer
    self._width = width
    if not kernel_initializer:
      kernel_initializer = tf.compat.v1.keras.initializers.glorot_uniform()
    # Shared encoder producing the features consumed by the value head; the
    # single weight_decay value is replicated for every fc layer.
    self._encoder = sparse_encoding_network.EncodingNetwork(
        input_tensor_spec,
        preprocessing_layers=None,
        preprocessing_combiner=preprocessing_combiner,
        conv_layer_params=conv_layer_params,
        fc_layer_params=fc_layer_params,
        dropout_layer_params=dropout_layer_params,
        activation_fn=activation_fn,
        weight_decay_params=[weight_decay] * len(fc_layer_params),
        kernel_initializer=kernel_initializer,
        batch_squash=batch_squash,
        dtype=dtype,
        width=self._width)
    # Scalar value head (1 unit, no activation).
    self._postprocessing_layers = tf.keras.layers.Dense(
        1,
        activation=None,
        kernel_initializer=tf.random_uniform_initializer(
            minval=-0.03, maxval=0.03),
        kernel_regularizer=tf.keras.regularizers.L2(weight_decay))
    if is_sparse:
      # Select only the encoder layers eligible for a pruning wrapper
      # (layers such as Flatten/dropout are left untouched).
      # NOTE(review): this reaches into the encoder's private
      # `_postprocessing_layers` attribute.
      layers_to_wrap = [l for l in self._encoder._postprocessing_layers
                        if tf_sparse_utils.is_valid_layer_to_wrap(l)]
      # NOTE(review): assumes a flat observation spec -- shape[0] would not
      # be the fan-in for image observations; confirm with callers.
      input_dim = input_tensor_spec.shape[0]
      if sparse_output_layer:
        layers_to_wrap.append(self._postprocessing_layers)
        wrapped_layers = tf_sparse_utils.wrap_all_layers(
            layers_to_wrap, input_dim)
        # The last wrapped layer is the value head; the rest belong to the
        # encoder and are spliced back in below.
        self._postprocessing_layers = wrapped_layers[-1]
        wrapped_layers = wrapped_layers[:-1]
      else:
        wrapped_layers = tf_sparse_utils.wrap_all_layers(
            layers_to_wrap, input_dim)
      # We need to recreate the original layer list after wrapping the layers.
      new_layer_list = []
      i = 0
      for unwrapped_layer in self._encoder._postprocessing_layers:
        if tf_sparse_utils.is_valid_layer_to_wrap(unwrapped_layer):
          new_layer_list.append(wrapped_layers[i])
          i += 1
        else:
          new_layer_list.append(unwrapped_layer)
      self._encoder._postprocessing_layers = new_layer_list

  def call(self, observation, step_type=None, network_state=(), training=False):
    """Encodes `observation` and returns `(value, network_state)`."""
    state, network_state = self._encoder(
        observation, step_type=step_type, network_state=network_state,
        training=training)
    value = self._postprocessing_layers(state, training=training)
    # Drop the trailing singleton unit dimension: one scalar per batch item.
    return tf.squeeze(value, -1), network_state
================================================
FILE: rigl/rl/tfagents/tf_sparse_utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for sparse tf agents training."""
import re
from absl import logging
import gin
from rigl import sparse_utils as sparse_utils_rigl
from rigl.rl import sparse_utils
import tensorflow.compat.v2 as tf
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper
PRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude
_LAYER_TYPES_TO_WRAP = (tf.keras.layers.Dense, tf.keras.layers.Conv2D,
tf.keras.layers.Conv1D)
def log_total_params(networks):
  """Writes the combined parameter count of `networks` as a scalar summary."""
  # `get_total_params` returns a (total, ...) pair; only the total is needed.
  combined_total = sum(
      sparse_utils.get_total_params(net)[0] for net in networks)
  with tf.name_scope('Params/'):
    tf.compat.v2.summary.scalar('total', combined_total)
def scale_width(num_units, width):
  """Scales a layer width by `width`, keeping at least one unit.

  Args:
    num_units: Original number of units in the layer.
    width: Strictly positive multiplicative scaling factor.

  Returns:
    int, `int(max(1, num_units * width))` -- the scaled width, floored at 1.

  Raises:
    ValueError: If `width` is not strictly positive.
  """
  # Validate explicitly rather than with `assert`, which is stripped when
  # Python runs with -O.
  if width <= 0:
    raise ValueError('width must be positive, got: %s' % width)
  return int(max(1, num_units * width))
@gin.configurable
def wrap_all_layers(layers,
                    input_dim,
                    mode='constant',
                    mask_init_method='erdos_renyi_kernel',
                    initial_sparsity=0.0,
                    final_sparsity=0.9,
                    begin_step=200000,
                    end_step=600000,
                    frequency=10000):
  """Wraps a list of dense keras layers to be used by sparse training.

  Args:
    layers: list of keras Dense layers (each must expose `.units`); assumes
      they are chained, i.e. layer i feeds layer i+1 -- TODO confirm.
    input_dim: int, input dimension of the first layer.
    mode: str, 'constant' wraps with an all-ones static mask; 'prune' wraps
      with a polynomial-decay magnitude-pruning schedule whose per-layer
      sparsities follow `mask_init_method`.
    mask_init_method: str, method used to distribute `final_sparsity` across
      layers (e.g. 'erdos_renyi_kernel').
    initial_sparsity: float, starting sparsity of the pruning schedule.
    final_sparsity: float, overall target sparsity of the pruning schedule.
    begin_step: int, first step of the pruning schedule.
    end_step: int, last step of the pruning schedule.
    frequency: int, steps between schedule-driven mask updates.

  Returns:
    List of `PruneLowMagnitude`-wrapped layers, in the same order.

  Raises:
    ValueError: if `mode` is neither 'constant' nor 'prune'.
  """
  # We only need to define static masks here, we will update them through
  # mask updater later.
  new_layers = []
  if mode == 'constant':
    for layer in layers:
      # Zero target sparsity with a far-future begin step keeps the mask
      # effectively all-ones.
      schedule = pruning_schedule.ConstantSparsity(
          target_sparsity=0, begin_step=1000000000)
      new_layers.append(PRUNING_WRAPPER(layer, pruning_schedule=schedule))
  elif mode == 'prune':
    logging.info('Pruning schedule: initial sparsity: %f', initial_sparsity)
    logging.info('Pruning schedule: mask_init_method: %s', mask_init_method)
    logging.info('Pruning schedule: final sparsity: %f', final_sparsity)
    logging.info('Pruning schedule: begin step: %f', begin_step)
    logging.info('Pruning schedule: end step: %f', end_step)
    logging.info('Pruning schedule: frequency: %f', frequency)
    # Create dummy masks to get layer-wise sparsities. This is because the
    # get_sparsities function expects mask variables to calculate the
    # sparsities.
    dummy_masks_dict = {}
    layer_input_dim = input_dim
    for layer in layers:
      # Mask shape assumes a Dense kernel of [in_dim, units].
      mask = tf.Variable(tf.ones([layer_input_dim, layer.units]),
                         trainable=False, name=f'dummymask_{layer.name}')
      layer_input_dim = layer.units
      dummy_masks_dict[layer.name] = mask
    # Get layer-wise sparsities; strip the ':0' suffix from variable names.
    extract_name_fn = lambda x: re.findall('(.+):0', x)[0]
    reverse_dict = {v.name: k
                    for k, v in dummy_masks_dict.items()}
    sparsity_dict = sparse_utils_rigl.get_sparsities(
        list(dummy_masks_dict.values()),
        mask_init_method,
        final_sparsity,
        custom_sparsity_map={},
        extract_name_fn=extract_name_fn)
    # This dict will have {layer_name: layer_sparsity}
    renamed_sparsity_dict = {reverse_dict[k]: float(v)
                             for k, v in sparsity_dict.items()}
    # Wrap layers with possibly non-uniform pruning schedule.
    for layer in layers:
      sparsity = renamed_sparsity_dict[layer.name]
      logging.info('Layer: %s, sparsity: %f', layer.name, sparsity)
      schedule = pruning_schedule.PolynomialDecay(
          initial_sparsity=initial_sparsity,
          final_sparsity=sparsity,
          begin_step=begin_step,
          end_step=end_step,
          frequency=frequency)
      new_layers.append(PRUNING_WRAPPER(layer, pruning_schedule=schedule))
  else:
    # Previously an unknown mode silently returned an empty list.
    raise ValueError('Unknown mode: %s. Expected "constant" or "prune".' % mode)
  return new_layers
@gin.configurable
def wrap_layer(layer,
               mode='constant',
               initial_sparsity=0.0,
               final_sparsity=0.9,
               begin_step=200000,
               end_step=600000,
               frequency=10000):
  """Wraps a keras layer to be used by sparse training.

  Args:
    layer: keras layer to wrap.
    mode: str, 'constant' keeps an effectively all-ones static mask; 'prune'
      applies a polynomial-decay magnitude-pruning schedule.
    initial_sparsity: float, starting sparsity of the 'prune' schedule.
    final_sparsity: float, final sparsity of the 'prune' schedule.
    begin_step: int, first step of the pruning schedule.
    end_step: int, last step of the pruning schedule.
    frequency: int, steps between schedule-driven mask updates.

  Returns:
    The layer wrapped with `PruneLowMagnitude`.

  Raises:
    ValueError: if `mode` is neither 'constant' nor 'prune'.
  """
  # We only need to define static masks here, we will update them through
  # mask updater later.
  if mode == 'constant':
    # Zero target sparsity with a far-future begin step: mask stays all-ones.
    schedule = pruning_schedule.ConstantSparsity(
        target_sparsity=0, begin_step=1000000000)
  elif mode == 'prune':
    logging.info('Pruning schedule: initial sparsity: %f', initial_sparsity)
    logging.info('Pruning schedule: final sparsity: %f', final_sparsity)
    logging.info('Pruning schedule: begin step: %f', begin_step)
    logging.info('Pruning schedule: end step: %f', end_step)
    logging.info('Pruning schedule: frequency: %f', frequency)
    schedule = pruning_schedule.PolynomialDecay(
        initial_sparsity=initial_sparsity,
        final_sparsity=final_sparsity,
        begin_step=begin_step,
        end_step=end_step,
        frequency=frequency)
  else:
    # Previously an unknown mode crashed later with an unbound `schedule`.
    raise ValueError('Unknown mode: %s. Expected "constant" or "prune".' % mode)
  return PRUNING_WRAPPER(layer, pruning_schedule=schedule)
def is_valid_layer_to_wrap(layer):
  """Returns True if `layer` is of a type eligible for pruning wrapping."""
  # isinstance accepts a tuple of types; no need for a manual loop.
  return isinstance(layer, _LAYER_TYPES_TO_WRAP)
@gin.configurable
def log_sparsities(model, model_name='q_net', log_images=False):
  """Logs relevant sparsity stats to tensorboard."""
  for pruned_layer in sparse_utils.get_all_pruning_layers(model):
    for _, mask, threshold in pruned_layer.pruning_vars:
      if log_images:
        # Add batch and channel axes so the 2-D mask renders as an image.
        mask_image = tf.expand_dims(tf.expand_dims(mask, 0), -1)
        with tf.name_scope('Masks/'):
          tf.compat.v2.summary.image(f'{model_name}/{mask.name}', mask_image)
      with tf.name_scope('Sparsity/'):
        # Fraction of zeroed entries: masks are 0/1 valued.
        fraction_pruned = 1 - tf.reduce_mean(mask)
        tf.compat.v2.summary.scalar(f'{model_name}/{mask.name}',
                                    fraction_pruned)
      with tf.name_scope('Threshold/'):
        tf.compat.v2.summary.scalar(f'{model_name}/{threshold.name}',
                                    threshold)
  total_params, nparam_dict = sparse_utils.get_total_params(model)
  with tf.name_scope('Params/'):
    tf.compat.v2.summary.scalar(f'{model_name}/total', total_params)
    for param_name, param_count in nparam_dict.items():
      tf.compat.v2.summary.scalar(f'{model_name}/' + param_name, param_count)
def update_prune_step(model, step):
  """Assigns `step` to the pruning-step counter of every pruning layer."""
  for pruned_layer in sparse_utils.get_all_pruning_layers(model):
    # The pruning wrapper consults this counter to decide when to act.
    pruned_layer.pruning_step.assign(step)
def flatten_list_of_vars(var_list):
  """Concatenates all variables in `var_list` into a single flat tensor."""
  flattened = [tf.reshape(variable, [-1]) for variable in var_list]
  return tf.concat(flattened, axis=-1)
@gin.configurable
def log_snr(tape, loss, step, variables_to_train, freq=1000):
  """Given a gradient tape and loss, it logs signal-to-noise ratio.

  Every `freq` steps, computes per-sample gradients via `tape.jacobian` and
  logs mean and std of |mean_grad / (std_grad + eps)| across all trainable
  parameters under the 'SNR/' scope.

  Args:
    tape: tf.GradientTape (must be persistent enough to take a jacobian).
    loss: per-sample loss tensor; assumes the leading axis is the sample
      axis -- TODO confirm against callers.
    step: integer step tensor used both for the modulo check and the summary.
    variables_to_train: list of variables to compute gradients for.
    freq: int, logging period in steps.
  """
  def true_fn():
    # One gradient per sample for each variable; sample axis is axis 0.
    grads_per_sample = tape.jacobian(loss, variables_to_train)
    list_of_snrs = []
    for grad in grads_per_sample:
      # Variables not reached by `loss` yield None; skip them.
      if grad is not None:
        if isinstance(grad, tf.IndexedSlices):
          grad_values = grad.values
        else:
          grad_values = grad
        # Signal = mean over samples, noise = std over samples.
        grad_mean = tf.math.reduce_mean(grad_values, axis=0)
        grad_std = tf.math.reduce_std(grad_values, axis=0)
        # 1e-10 guards against division by zero for constant gradients.
        list_of_snrs.append(tf.abs(grad_mean / (grad_std + 1e-10)))
    snr_mean = tf.reduce_mean(flatten_list_of_vars(list_of_snrs))
    snr_std = tf.math.reduce_std((flatten_list_of_vars(list_of_snrs)))
    with tf.name_scope('SNR/'):
      tf.compat.v2.summary.scalar(name='mean', data=snr_mean, step=step)
      tf.compat.v2.summary.scalar(name='std', data=snr_std, step=step)
  # Only pay the jacobian cost every `freq` steps.
  tf.cond(step % freq == 0, true_fn, lambda: None)
================================================
FILE: rigl/rl/train.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""The entry point for training a sparse DQN agent."""
import os
from absl import app
from absl import flags
import gin
from rigl.rl import run_experiment
import tensorflow as tf
# Root directory under which all experiment sub-directories (checkpoints,
# summaries, operative config) are created. Marked required in __main__.
flags.DEFINE_string('base_dir', None,
                    'Base directory to host all required sub-directories.')
flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "DQNAgent.epsilon_train=0.1",'
    ' "create_atari_environment.game_name="Pong"").')
FLAGS = flags.FLAGS
def create_sparsetrain_runner(base_dir):
  """Creates a SparseTrainRunner rooted at `base_dir`.

  Args:
    base_dir: str, directory hosting all required sub-directories.

  Returns:
    A `run_experiment.SparseTrainRunner` instance.

  Raises:
    ValueError: if `base_dir` is None.
  """
  # An `assert` would be stripped under `python -O`; validate explicitly.
  if base_dir is None:
    raise ValueError('base_dir must be provided.')
  return run_experiment.SparseTrainRunner(base_dir)
def main(unused_argv):
  """Parses gin config, runs the experiment, then dumps the operative config."""
  gin.parse_config_files_and_bindings(FLAGS.gin_files, FLAGS.gin_bindings)
  experiment_runner = create_sparsetrain_runner(FLAGS.base_dir)
  experiment_runner.run_experiment()
  # The operative config is written after the run, matching original behavior.
  config_path = os.path.join(FLAGS.base_dir, 'operative_config.gin')
  with tf.io.gfile.GFile(config_path, 'w') as config_file:
    config_file.write('# Gin-Config:\n %s' % gin.config.operative_config_str())


if __name__ == '__main__':
  flags.mark_flag_as_required('base_dir')
  app.run(main)
================================================
FILE: rigl/sparse_optimizers.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements some common and new sparse training algorithms."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import numpy as np
from rigl import sparse_optimizers_base as sparse_opt_base
from rigl import sparse_utils
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import moving_averages
from tensorflow.python.training import optimizer as tf_optimizer
from tensorflow.python.training import training_util
class PruningGetterTf1Mixin:
  """Tf1 model_pruning library specific variable retrieval.

  Mixed into sparse optimizers to satisfy the abstract getters of
  `SparseSETOptimizerBase` using the contrib model_pruning collections.
  """

  def get_weights(self):
    # Weight variables registered with the contrib model_pruning library.
    return pruning.get_weights()

  def get_masks(self):
    # Mask variables paired one-to-one with the weights above.
    return pruning.get_masks()

  def get_masked_weights(self):
    # The mask*weight tensors used in the forward pass.
    return pruning.get_masked_weights()
class SparseSETOptimizer(PruningGetterTf1Mixin,
                         sparse_opt_base.SparseSETOptimizerBase):
  """SET optimizer wired to the TF1 model_pruning variable getters."""
  pass
class SparseRigLOptimizer(PruningGetterTf1Mixin,
                          sparse_opt_base.SparseRigLOptimizerBase):
  """RigL optimizer wired to the TF1 model_pruning variable getters."""
  pass
class SparseStaticOptimizer(SparseSETOptimizer):
  """Sparse optimizer that keeps the connectivity pattern fixed.

  At each mask-update step it re-initializes the weakest connections in
  place: drop scores are (noisy) weight magnitudes and grow scores are the
  mask itself, so the same positions are re-selected and the topology never
  changes.

  Attributes:
    optimizer: tf.train.Optimizer
    begin_step: int, first iteration where masks are updated.
    end_step: int, iteration after which no mask is updated.
    frequency: int, of mask update operations.
    drop_fraction: float, of connections to drop during each update.
    drop_fraction_anneal: str or None, if supplied used to anneal the drop
      fraction.
    use_locking: bool, passed to the super.
    grow_init: str, name of the method used to initialize new connections.
    name: str, passed to the super.
    stateless_seed_offset: int, passed to the super.
  """

  def __init__(self,
               optimizer,
               begin_step,
               end_step,
               frequency,
               drop_fraction=0.1,
               drop_fraction_anneal='constant',
               use_locking=False,
               grow_init='zeros',
               name='SparseStaticOptimizer',
               stateless_seed_offset=0):
    super(SparseStaticOptimizer, self).__init__(
        optimizer,
        begin_step,
        end_step,
        frequency,
        drop_fraction=drop_fraction,
        drop_fraction_anneal=drop_fraction_anneal,
        grow_init=grow_init,
        use_locking=use_locking,
        name=name,
        stateless_seed_offset=stateless_seed_offset)

  def generic_mask_update(self, mask, weights, noise_std=1e-5):
    """True branch of the mask-update cond; re-initializes weak connections."""
    # Score by masked magnitude, with tiny noise to break ties between
    # equal-magnitude entries.
    magnitudes = math_ops.abs(mask * weights)
    score_drop = magnitudes + self._random_normal(
        magnitudes.shape,
        stddev=noise_std,
        dtype=magnitudes.dtype,
        seed=hash(weights.name + 'drop'))
    # Grow score is the mask itself: currently-active positions win, so the
    # dropped connections are regrown (and re-initialized) where they were.
    return self._get_update_op(
        score_drop, mask, mask, weights, reinit_when_same=True)
class SparseMomentumOptimizer(SparseSETOptimizer):
  """Sparse optimizer that grows connections with the expected gradients.

  A simplified implementation of Momentum based sparse optimizer. No
  redistribution of sparsity. Dropped connections are chosen by (noisy)
  weight magnitude; grown connections by the magnitude of an exponential
  moving average of masked gradients.

  Original implementation:
  https://github.com/TimDettmers/sparse_learning/blob/master/mnist_cifar/main.py

  Attributes:
    optimizer: tf.train.Optimizer
    begin_step: int, first iteration where masks are updated.
    end_step: int, iteration after which no mask is updated.
    frequency: int, of mask update operations.
    drop_fraction: float, of connections to drop during each update.
    drop_fraction_anneal: str or None, if supplied used to anneal the drop
      fraction.
    use_locking: bool, passed to the super.
    grow_init: str, name of the method used to initialize new connections.
    momentum: float, for the exponentialy moving average.
    use_tpu: bool, if true the masked_gradients are aggregated.
    name: str, passed to the super.
    stateless_seed_offset: int, passed to the super.
  """

  def __init__(self,
               optimizer,
               begin_step,
               end_step,
               frequency,
               drop_fraction=0.1,
               drop_fraction_anneal='constant',
               use_locking=False,
               grow_init='zeros',
               momentum=0.9,
               use_tpu=False,
               name='SparseMomentumOptimizer',
               stateless_seed_offset=0):
    # Bug fix: `name` used to be ignored -- a hard-coded
    # 'SparseMomentumOptimizer' string was forwarded to the super constructor.
    super(SparseMomentumOptimizer, self).__init__(
        optimizer,
        begin_step,
        end_step,
        frequency,
        drop_fraction=drop_fraction,
        drop_fraction_anneal=drop_fraction_anneal,
        grow_init=grow_init,
        use_locking=use_locking,
        name=name,
        stateless_seed_offset=stateless_seed_offset)
    self._ema_grads = moving_averages.ExponentialMovingAverage(decay=momentum)
    self._use_tpu = use_tpu

  def set_masked_grads(self, grads, weights):
    """Caches masked gradients, aggregating across TPU replicas if needed."""
    if self._use_tpu:
      grads = [tpu_ops.cross_replica_sum(g) for g in grads]
    self._masked_grads = grads
    # Using names since better to hash.
    self._weight2masked_grads = {w.name: m for w, m in zip(weights, grads)}

  def compute_gradients(self, loss, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    grads_and_vars = self._optimizer.compute_gradients(loss, **kwargs)
    # Need to update the EMA of the masked_weights. This is a bit hacky and
    # might not work as expected if the gradients are not applied after every
    # calculation. However, it should be fine if only .minimize() call is used.
    masked_grads_vars = self._optimizer.compute_gradients(
        loss, var_list=self.get_masked_weights())
    masked_grads = [g for g, _ in masked_grads_vars]
    self.set_masked_grads(masked_grads, self.get_weights())
    return grads_and_vars

  def _before_apply_gradients(self, grads_and_vars):
    """Updates momentum before updating the weights with gradient."""
    return self._ema_grads.apply(self._masked_grads)

  def generic_mask_update(self, mask, weights, noise_std=1e-5):
    """True branch of the condition, updates the mask."""
    # Ensure that the weights are masked.
    casted_mask = math_ops.cast(mask, dtypes.float32)
    masked_weights = casted_mask * weights
    score_drop = math_ops.abs(masked_weights)
    # Add noise for slight bit of randomness (tie breaking).
    score_drop += self._random_normal(
        score_drop.shape,
        stddev=noise_std,
        dtype=score_drop.dtype,
        seed=hash(weights.name + 'drop'))
    # Revive n_prune many connections using the gradient EMA magnitude.
    masked_grad = self._weight2masked_grads[weights.name]
    score_grow = math_ops.abs(self._ema_grads.average(masked_grad))
    return self._get_update_op(score_drop, score_grow, mask, weights)
class SparseSnipOptimizer(tf_optimizer.Optimizer):
  """Implementation of dynamic sparsity optimizers.

  Implementation of Snip
  https://arxiv.org/abs/1810.02340

  On the very first step (global_step == 0) the masks are initialized from
  saliency scores |grad * weight|; afterwards it behaves like the wrapped
  optimizer with fixed masks.

  Attributes:
    optimizer: tf.train.Optimizer
    default_sparsity: float, between 0 and 1.
    mask_init_method: str, used to determine mask initializations.
    custom_sparsity_map: dict, key/value pairs where the mask
      correspond whose name is '{key}/mask:0' is set to the corresponding
      sparsity value.
    use_locking: bool, passed to the super.
    use_tpu: bool, if true the masked_gradients are aggregated.
    name: str, passed to the super.
  """

  def __init__(self,
               optimizer,
               default_sparsity,
               mask_init_method,
               custom_sparsity_map=None,
               use_locking=False,
               use_tpu=False,
               name='SparseSnipOptimizer'):
    super(SparseSnipOptimizer, self).__init__(use_locking, name)
    # Normalize None to an empty map so downstream lookups are safe.
    if not custom_sparsity_map:
      custom_sparsity_map = {}
    self._optimizer = optimizer
    self._use_tpu = use_tpu
    self._default_sparsity = default_sparsity
    self._mask_init_method = mask_init_method
    self._custom_sparsity_map = custom_sparsity_map
    # Flipped to True after snip initialization so it runs exactly once.
    self.is_snipped = variable_scope.get_variable(
        'is_snipped', initializer=lambda: False, trainable=False)

  def compute_gradients(self, loss, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    return self._optimizer.compute_gradients(loss, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Wraps the original apply_gradient of the optimizer.

    On the first iteration this snips (initializes) the masks instead of
    applying gradients; on later iterations it delegates to the wrapped
    optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    def apply_gradient_op():
      return self._optimizer.apply_gradients(
          grads_and_vars, global_step=global_step, name=name)

    # Aggregate saliency scores across TPU replicas when requested.
    maybe_reduce = lambda x: x
    if self._use_tpu:
      maybe_reduce = tpu_ops.cross_replica_sum
    # Only variables named '.../weights:0' participate in snip scoring.
    grads_and_vars_dict = {
        re.findall('(.+)/weights:0', var.name)[0]: (maybe_reduce(grad), var)
        for grad, var in grads_and_vars
        if var.name.endswith('weights:0')
    }

    def snip_fn(mask, sparsity, dtype):
      """Creates a random sparse mask with deterministic sparsity.

      Args:
        mask: tf.Tensor, used to obtain correct corresponding gradient.
        sparsity: float, between 0 and 1.
        dtype: tf.dtype, type of the return value.

      Returns:
        tf.Tensor
      """
      del dtype
      var_name = sparse_utils.mask_extract_name_fn(mask.name)
      g, v = grads_and_vars_dict[var_name]
      # SNIP saliency: |gradient * weight|.
      score_drop = math_ops.abs(g * v)
      n_total = np.prod(score_drop.shape.as_list())
      n_prune = sparse_utils.get_n_zeros(n_total, sparsity)
      n_keep = n_total - n_prune
      # Sort the entire array since the k needs to be constant for TPU.
      _, sorted_indices = nn_ops.top_k(
          array_ops.reshape(score_drop, [-1]), k=n_total)
      sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
      # We will have zeros after having `n_keep` many ones.
      new_values = array_ops.where(
          math_ops.range(n_total) < n_keep,
          array_ops.ones_like(sorted_indices, dtype=mask.dtype),
          array_ops.zeros_like(sorted_indices, dtype=mask.dtype))
      # Scatter the kept/pruned values back to their original positions.
      new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                      new_values.shape)
      return array_ops.reshape(new_mask, mask.shape)

    def snip_op():
      all_masks = pruning.get_masks()
      assigner = sparse_utils.get_mask_init_fn(
          all_masks,
          self._mask_init_method,
          self._default_sparsity,
          self._custom_sparsity_map,
          mask_fn=snip_fn)
      # Record that snipping happened only after all masks are assigned.
      with ops.control_dependencies([assigner]):
        assign_op = state_ops.assign(
            self.is_snipped, True, name='assign_true_after_snipped')
      return assign_op

    # NOTE(review): `global_step` is assumed non-None here (compared to 0);
    # confirm all callers pass it.
    maybe_snip_op = control_flow_ops.cond(
        math_ops.logical_and(
            math_ops.equal(global_step, 0),
            math_ops.logical_not(self.is_snipped)), snip_op, apply_gradient_op)
    return maybe_snip_op
class SparseDNWOptimizer(tf_optimizer.Optimizer):
  """Implementation of DNW optimizer.

  Implementation of DNW.
  See https://arxiv.org/pdf/1906.00586.pdf

  This optimizer ensures the mask is updated at every iteration, according to
  the current set of weights. It uses dense gradient to update weights.

  Attributes:
    optimizer: tf.train.Optimizer
    default_sparsity: float, between 0 and 1.
    mask_init_method: str, used to determine mask initializations.
    custom_sparsity_map: dict, key/value pairs where the mask
      correspond whose name is '{key}/mask:0' is set to the corresponding
      sparsity value.
    use_tpu: bool, if true the masked_gradients are aggregated.
    use_locking: bool, passed to the super.
    name: str, passed to the super.
  """

  def __init__(self,
               optimizer,
               default_sparsity,
               mask_init_method,
               custom_sparsity_map=None,
               use_tpu=False,
               use_locking=False,
               name='SparseDNWOptimizer'):
    super(SparseDNWOptimizer, self).__init__(use_locking, name)
    self._optimizer = optimizer
    self._use_tpu = use_tpu
    self._default_sparsity = default_sparsity
    self._mask_init_method = mask_init_method
    # NOTE(review): unlike SparseSnipOptimizer, a None custom_sparsity_map is
    # not normalized to {} here -- confirm get_mask_init_fn tolerates None.
    self._custom_sparsity_map = custom_sparsity_map

  def compute_gradients(self, loss, var_list=None, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    # Replace masked variables with masked_weights so that the gradient is
    # dense and not masked.
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    var_list = self.replace_with_masked_weights(var_list)
    grads_and_vars = self._optimizer.compute_gradients(
        loss, var_list=var_list, **kwargs)
    # Map the gradients back onto the underlying weight variables.
    return self.replace_masked_weights(grads_and_vars)

  def replace_with_masked_weights(self, var_list):
    """Replaces masked variables with masked weights."""
    weight2masked_weights = {
        w.name: mw
        for w, mw in zip(self.get_weights(), self.get_masked_weights())
    }
    # Variables without a mask pass through unchanged.
    updated_var_list = [weight2masked_weights.get(w.name, w) for w in var_list]
    return updated_var_list

  def replace_masked_weights(self, grads_and_vars):
    """Replaces masked weight tensors with weight variables."""
    masked_weights2weight = {
        mw.name: w
        for w, mw in zip(self.get_weights(), self.get_masked_weights())
    }
    updated_grads_and_vars = [
        (g, masked_weights2weight.get(w.name, w)) for g, w in grads_and_vars
    ]
    return updated_grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Wraps the original apply_gradient of the optimizer.

    After the (dense) gradient update, every mask is recomputed from the
    current weight magnitudes, so connectivity follows the largest weights.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    optimizer_update = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
    # Only variables named '.../weights:0' are eligible for DNW scoring.
    vars_dict = {
        re.findall('(.+)/weights:0', var.name)[0]: var
        for var in self.get_weights()
    }

    def dnw_fn(mask, sparsity, dtype):
      """Creates a mask with smallest magnitudes with deterministic sparsity.

      Args:
        mask: tf.Tensor, used to obtain correct corresponding gradient.
        sparsity: float, between 0 and 1.
        dtype: tf.dtype, type of the return value.

      Returns:
        tf.Tensor
      """
      del dtype
      var_name = sparse_utils.mask_extract_name_fn(mask.name)
      v = vars_dict[var_name]
      # DNW score: current weight magnitude.
      score_drop = math_ops.abs(v)
      n_total = np.prod(score_drop.shape.as_list())
      n_prune = sparse_utils.get_n_zeros(n_total, sparsity)
      n_keep = n_total - n_prune
      # Sort the entire array since the k needs to be constant for TPU.
      _, sorted_indices = nn_ops.top_k(
          array_ops.reshape(score_drop, [-1]), k=n_total)
      sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
      # We will have zeros after having `n_keep` many ones.
      new_values = array_ops.where(
          math_ops.range(n_total) < n_keep,
          array_ops.ones_like(sorted_indices, dtype=mask.dtype),
          array_ops.zeros_like(sorted_indices, dtype=mask.dtype))
      # Scatter the kept/pruned values back to their original positions.
      new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                      new_values.shape)
      return array_ops.reshape(new_mask, mask.shape)

    # Masks are refreshed only after the weight update has been applied.
    with ops.control_dependencies([optimizer_update]):
      all_masks = self.get_masks()
      mask_update_op = sparse_utils.get_mask_init_fn(
          all_masks,
          self._mask_init_method,
          self._default_sparsity,
          self._custom_sparsity_map,
          mask_fn=dnw_fn)
    return mask_update_op

  def get_weights(self):
    # Weight variables registered with the contrib model_pruning library.
    return pruning.get_weights()

  def get_masks(self):
    # Mask variables paired one-to-one with the weights above.
    return pruning.get_masks()

  def get_masked_weights(self):
    # The mask*weight tensors used in the forward pass.
    return pruning.get_masked_weights()
================================================
FILE: rigl/sparse_optimizers_base.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements some common and new sparse training algorithms."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import learning_rate_decay
from tensorflow.python.training import optimizer as tf_optimizer
from tensorflow.python.training import training_util
def extract_number(token):
  """Strips the number from the end of the token if it exists.

  Args:
    token: str, s or s_d where d is a number: a float or int. `foo_.5`,
      `foo_0.5`, `foo_4` are all valid strings.

  Returns:
    float, d if a valid numeric suffix exists, otherwise 1.
  """
  # Search once instead of twice (the original called .search twice).
  match = re.search(r'.*_(\d*\.?\d*)$', token)
  if match:
    suffix = match.group(1)
    # Guard against `foo_` / `foo_.`: the group matches an empty or lone-dot
    # string, which previously crashed float() with a ValueError.
    if suffix and suffix != '.':
      return float(suffix)
  return 1.
class SparseSETOptimizerBase(tf_optimizer.Optimizer):
"""Implementation of dynamic sparsity optimizers.
Implementation of SET.
See https://www.nature.com/articles/s41467-018-04316-3
This optimizer wraps a regular optimizer and performs updates on the masks
according to schedule given.
Attributes:
optimizer: tf.train.Optimizer
begin_step: int, first iteration where masks are updated.
end_step: int, iteration after which no mask is updated.
frequency: int, of mask update operations.
drop_fraction: float, of connections to drop during each update.
drop_fraction_anneal: str or None, if supplied used to anneal the drop
fraction.
use_locking: bool, passed to the super.
grow_init: str, name of the method used to initialize new connections.
name: bool, passed to the super.
use_stateless: bool, if True stateless operations are used. This is
important for multi-worker jobs not to diverge.
stateless_seed_offset: int, added to the seed of stateless operations. Use
this to create randomness without divergence across workers.
"""
def __init__(self,
             optimizer,
             begin_step,
             end_step,
             frequency,
             drop_fraction=0.1,
             drop_fraction_anneal='constant',
             use_locking=False,
             grow_init='zeros',
             name='SparseSETOptimizer',
             use_stateless=True,
             stateless_seed_offset=0):
  """Wraps `optimizer`; see the class docstring for argument semantics."""
  super(SparseSETOptimizerBase, self).__init__(use_locking, name)
  self._optimizer = optimizer
  self._grow_init = grow_init
  self._use_stateless = use_stateless
  self._stateless_seed_offset = stateless_seed_offset
  # Annealing mode must be recorded first: it names the drop-fraction tensor.
  self._drop_fraction_anneal = drop_fraction_anneal
  self._drop_fraction_initial_value = ops.convert_to_tensor(
      float(drop_fraction),
      name='%s_drop_fraction' % self._drop_fraction_anneal)
  # Schedule boundaries as tensors; keep the raw python `frequency` too,
  # since it is needed to initialize the last-update-step variable.
  self._begin_step = ops.convert_to_tensor(begin_step, name='begin_step')
  self._end_step = ops.convert_to_tensor(end_step, name='end_step')
  self._frequency = ops.convert_to_tensor(frequency, name='frequency')
  self._frequency_val = frequency
def compute_gradients(self, loss, **kwargs):
  """Delegates gradient computation to the wrapped optimizer."""
  return self._optimizer.compute_gradients(loss, **kwargs)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Wraps the original apply_gradient of the optimizer.

  Ordering: subclass pre-op (e.g. EMA update) -> weight update -> conditional
  mask update, enforced via control dependencies.

  Args:
    grads_and_vars: List of (gradient, variable) pairs as returned by
      `compute_gradients()`.
    global_step: Optional `Variable` to increment by one after the variables
      have been updated.
    name: Optional name for the returned operation. Default to the name
      passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients. If `global_step`
    was not None, that operation also increments `global_step`.
  """
  # Subclass hook must run before the weights move.
  pre_op = self._before_apply_gradients(grads_and_vars)
  with ops.control_dependencies([pre_op]):
    optimizer_update = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
  # We get the default one after calling the super.apply_gradient(), since
  # we want to preserve original behavior of the optimizer: don't increment
  # anything if no global_step is passed. But we need the global step for
  # the mask_update.
  global_step = (
      global_step if global_step is not None else
      training_util.get_or_create_global_step())
  self._global_step = global_step
  # Masks are (possibly) updated only after the weight update completes.
  with ops.control_dependencies([optimizer_update]):
    return self.cond_mask_update_op(global_step, control_flow_ops.no_op)
def _before_apply_gradients(self, grads_and_vars):
  """Hook run before gradients are applied; subclasses may override."""
  del grads_and_vars  # Unused by the base implementation.
  return control_flow_ops.no_op('before_apply_grad')
def cond_mask_update_op(self, global_step, false_branch):
  """Creates the conditional mask update operation.

  All masks are updated when it is an update iteration
  (checked by self.is_mask_update_iter()).

  Arguments:
    global_step: tf.Variable, current training iteration.
    false_branch: function, called when it is not a mask update iteration.

  Returns:
    conditional update operation
  """
  # Initializing to -freq so that last_update_step+freq=0. This enables early
  # mask_updates.
  last_update_step = variable_scope.get_variable(
      'last_mask_update_step', [],
      initializer=init_ops.constant_initializer(
          -self._frequency_val, dtype=global_step.dtype),
      trainable=False,
      dtype=global_step.dtype)

  def mask_update_op():
    # One update op per (mask, weight) pair; subclasses define the scoring.
    update_ops = []
    for mask, weights in zip(self.get_masks(), self.get_weights()):
      update_ops.append(self.generic_mask_update(mask, weights))
    # Record the step only after every mask update has executed.
    with ops.control_dependencies(update_ops):
      assign_op = state_ops.assign(
          last_update_step, global_step, name='last_mask_update_step_assign')
    with ops.control_dependencies([assign_op]):
      return control_flow_ops.no_op('mask_update')

  maybe_update = control_flow_ops.cond(
      self.is_mask_update_iter(global_step, last_update_step), mask_update_op,
      false_branch)
  return maybe_update
def get_weights(self):
  """Returns the list of weight variables; subclasses must implement."""
  raise NotImplementedError

def get_masks(self):
  """Returns the list of mask variables; subclasses must implement."""
  raise NotImplementedError

def get_masked_weights(self):
  """Returns the mask*weight tensors; subclasses must implement."""
  raise NotImplementedError
def is_mask_update_iter(self, global_step, last_update_step):
  """Function for checking if the current step is a mask update step.

  It also creates the drop_fraction op and assigns it to the self object.

  Side effects: re-casts and overwrites self._begin_step / _end_step /
  _frequency to global_step's dtype, and sets self.drop_fraction.

  Args:
    global_step: tf.Variable(int), current training step.
    last_update_step: tf.Variable(int), holding the last iteration the mask is
      updated. Used to determine whether current iteration is a mask update
      step.

  Returns:
    bool, whether the current iteration is a mask_update step.
  """
  gs_dtype = global_step.dtype
  # Cast schedule tensors to the step dtype so the comparisons below work.
  self._begin_step = math_ops.cast(self._begin_step, gs_dtype)
  self._end_step = math_ops.cast(self._end_step, gs_dtype)
  self._frequency = math_ops.cast(self._frequency, gs_dtype)
  is_step_within_update_range = math_ops.logical_and(
      math_ops.greater_equal(global_step, self._begin_step),
      math_ops.logical_or(
          math_ops.less_equal(global_step, self._end_step),
          # If _end_step is negative, we never stop updating the mask.
          # In other words we update the mask with given frequency until the
          # training ends.
          math_ops.less(self._end_step, 0)))
  # At least `frequency` steps must have passed since the last update.
  is_update_step = math_ops.less_equal(
      math_ops.add(last_update_step, self._frequency), global_step)
  is_mask_update_iter_op = math_ops.logical_and(is_step_within_update_range,
                                                is_update_step)
  # Cached so generic_mask_update/_get_update_op can read it later.
  self.drop_fraction = self.get_drop_fraction(global_step,
                                              is_mask_update_iter_op)
  return is_mask_update_iter_op
def get_drop_fraction(self, global_step, is_mask_update_iter_op):
  """Returns a constant or annealing drop_fraction op (zero off-update)."""
  anneal = self._drop_fraction_anneal
  initial = self._drop_fraction_initial_value
  if anneal == 'constant':
    drop_frac = initial
  elif anneal == 'cosine':
    # Cosine-anneal from the initial value down over the update window.
    drop_frac = learning_rate_decay.cosine_decay(
        initial,
        global_step,
        self._end_step - self._begin_step,
        name='cosine_drop_fraction')
  elif anneal.startswith('exponential'):
    # 'exponential_<p>' decays as initial * (1 - progress)^p.
    exponent = extract_number(anneal)
    div_dtype = initial.dtype
    progress = math_ops.divide(
        math_ops.cast(global_step - self._begin_step, div_dtype),
        math_ops.cast(self._end_step - self._begin_step, div_dtype),
    )
    drop_frac = math_ops.multiply(
        initial,
        math_ops.pow(1 - progress, exponent),
        name='%s_drop_fraction' % anneal)
  else:
    raise ValueError('drop_fraction_anneal: %s is not valid' % anneal)
  # Outside mask-update iterations the drop fraction is forced to zero.
  return array_ops.where(is_mask_update_iter_op, drop_frac,
                         array_ops.zeros_like(drop_frac))
def generic_mask_update(self, mask, weights, noise_std=1e-5):
"""True branch of the condition, updates the mask."""
# Ensure that the weights are masked.
masked_weights = mask * weights
score_drop = math_ops.abs(masked_weights)
# Add noise for slight bit of randomness.
score_drop += self._random_normal(
score_drop.shape,
stddev=noise_std,
dtype=score_drop.dtype,
seed=(hash(weights.name + 'drop')))
# Randomly revive n_prune many connections from non-existing connections.
score_grow = self._random_uniform(
weights.shape, seed=hash(weights.name + 'grow'))
return self._get_update_op(score_drop, score_grow, mask, weights)
def _get_update_op(self,
                   score_drop,
                   score_grow,
                   mask,
                   weights,
                   reinit_when_same=False):
  """Prunes+grows connections, all tensors same shape.

  Drops the active connections with the lowest `score_drop` and grows an
  equal number of inactive connections with the highest `score_grow`, so
  the number of active connections is preserved.

  Args:
    score_drop: Tensor, same shape as `weights`; higher score -> kept.
    score_grow: Tensor, same shape as `weights`; higher score -> grown.
    mask: Variable, the binary mask to update in place.
    weights: Variable, the weights; newly grown entries are re-initialized.
    reinit_when_same: bool, if True a connection that is dropped and
      immediately regrown is also re-initialized.

  Returns:
    The assign op that writes the combined (keep + grow) mask.
  """
  old_dtype = mask.dtype
  mask_casted = math_ops.cast(mask, dtypes.float32)
  n_total = array_ops.size(score_drop)
  n_ones = math_ops.cast(math_ops.reduce_sum(mask_casted), dtype=dtypes.int32)
  # Number of connections to drop, derived from the (possibly annealed)
  # drop fraction computed in is_mask_update_iter_op.
  n_prune = math_ops.cast(
      math_ops.cast(n_ones, dtype=dtypes.float32) * self.drop_fraction,
      dtypes.int32)
  n_keep = n_ones - n_prune
  # Sort the entire array since the k needs to be constant for TPU.
  _, sorted_indices = nn_ops.top_k(
      array_ops.reshape(score_drop, [-1]), k=n_total)
  sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
  # We will have zeros after having `n_keep` many ones.
  new_values = array_ops.where(
      math_ops.range(n_total) < n_keep,
      array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
      array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
  # mask1: flat 0/1 vector marking the surviving (kept) connections.
  mask1 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                               new_values.shape)
  # Flatten the scores
  score_grow = array_ops.reshape(score_grow, [-1])
  # Set scores of the enabled connections(ones) to min(s) - 1, so that they
  # have the lowest scores.
  score_grow_lifted = array_ops.where(
      math_ops.equal(mask1, 1),
      array_ops.ones_like(mask1) * (math_ops.reduce_min(score_grow) - 1),
      score_grow)
  _, sorted_indices = nn_ops.top_k(score_grow_lifted, k=n_total)
  sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
  new_values = array_ops.where(
      math_ops.range(n_total) < n_prune,
      array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
      array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
  # mask2: flat 0/1 vector marking the newly grown connections.
  mask2 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                               new_values.shape)
  # Ensure masks are disjoint.
  assert_op = control_flow_ops.Assert(
      math_ops.equal(math_ops.reduce_sum(mask1 * mask2), 0.), [mask1, mask2])
  with ops.control_dependencies([assert_op]):
    # Let's set the weights of the growed connections.
    mask2_reshaped = array_ops.reshape(mask2, mask.shape)
    # Set the values of the new connections.
    grow_tensor = self.get_grow_tensor(weights, self._grow_init)
    if reinit_when_same:
      # If dropped and grown, we re-initialize.
      new_connections = math_ops.equal(mask2_reshaped, 1)
    else:
      new_connections = math_ops.logical_and(
          math_ops.equal(mask2_reshaped, 1), math_ops.equal(mask_casted, 0))
    new_weights = array_ops.where(new_connections, grow_tensor, weights)
    weights_update = state_ops.assign(weights, new_weights)
    # Ensure there is no momentum value for new connections
    reset_op = self.reset_momentum(weights, new_connections)
    # The mask write must happen after both the weight write and the slot
    # reset, so a training step never sees a grown mask with stale weights.
    with ops.control_dependencies([weights_update, reset_op]):
      mask_combined = array_ops.reshape(mask1 + mask2, mask.shape)
      mask_combined = math_ops.cast(mask_combined, dtype=old_dtype)
      new_mask = state_ops.assign(mask, mask_combined)
  return new_mask
def reset_momentum(self, weights, new_connections):
  """Zeroes optimizer slot values (e.g. momentum) for newly grown weights."""
  assign_ops = []
  for slot_name in self._optimizer.get_slot_names():
    slot_var = self._optimizer.get_slot(weights, slot_name)
    # A freshly grown connection should not inherit the stale accumulator
    # state left over from before it was pruned.
    zeroed = array_ops.where(new_connections,
                             array_ops.zeros_like(slot_var), slot_var)
    assign_ops.append(state_ops.assign(slot_var, zeroed))
  return control_flow_ops.group(assign_ops)
def get_grow_tensor(self, weights, method):
  """Different ways to initialize new connections.

  Args:
    weights: tf.Tensor or Variable.
    method: str, available options: 'zeros', 'random_normal', 'random_uniform'
      and 'initial_value'
  Returns:
    tf.Tensor same shape and type as weights.
  Raises:
    ValueError, when the method is not valid.
  """
  if not isinstance(method, six.string_types):
    raise ValueError('Grow-Init: %s is not a string' % method)
  if method == 'zeros':
    return array_ops.zeros_like(weights, dtype=weights.dtype)
  if method.startswith('initial_dist'):
    # Sample from the empirical initial-value distribution by shuffling the
    # flattened initial values, then scale down by the suffix divisor.
    divisor = extract_number(method)
    flat_init = array_ops.reshape(weights.initial_value, [-1])
    shuffled = random_ops.random_shuffle(flat_init)
    return array_ops.reshape(
        shuffled, weights.initial_value.shape) / divisor
  if method.startswith('random_normal'):
    # Normal noise matched to the current stddev of the weights.
    divisor = extract_number(method)
    sampled = self._random_normal(
        weights.shape,
        stddev=math_ops.reduce_std(weights),
        dtype=weights.dtype,
        seed=hash(weights.name + 'grow_init_n'))
    return sampled / divisor
  if method.startswith('random_uniform'):
    # Uniform noise in [-mean(|w|), mean(|w|)].
    divisor = extract_number(method)
    scale = math_ops.reduce_mean(math_ops.abs(weights))
    sampled = self._random_uniform(
        weights.shape,
        minval=-scale,
        maxval=scale,
        dtype=weights.dtype,
        seed=hash(weights.name + 'grow_init_u'))
    return sampled / divisor
  raise ValueError('Grow-Init: %s is not a valid option.' % method)
def _random_uniform(self, *args, **kwargs):
  """Uniform sampling; stateless (reproducible per step) when enabled."""
  if not self._use_stateless:
    return random_ops.random_uniform(*args, **kwargs)
  # Mix the per-tensor seed with the global step so each update iteration
  # draws fresh yet deterministic values.
  combined_seed = self._stateless_seed_offset + kwargs['seed']
  kwargs['seed'] = math_ops.cast(
      array_ops.stack([combined_seed, self._global_step]), dtypes.int32)
  return stateless_random_ops.stateless_random_uniform(*args, **kwargs)
def _random_normal(self, *args, **kwargs):
  """Normal sampling; stateless (reproducible per step) when enabled."""
  if not self._use_stateless:
    return random_ops.random_normal(*args, **kwargs)
  # Mix the per-tensor seed with the global step so each update iteration
  # draws fresh yet deterministic values.
  combined_seed = self._stateless_seed_offset + kwargs['seed']
  kwargs['seed'] = math_ops.cast(
      array_ops.stack([combined_seed, self._global_step]), dtypes.int32)
  return stateless_random_ops.stateless_random_normal(*args, **kwargs)
class SparseRigLOptimizerBase(SparseSETOptimizerBase):
  """Sparse optimizer that grows connections with the pre-removal gradients.

  Attributes:
    optimizer: tf.train.Optimizer
    begin_step: int, first iteration where masks are updated.
    end_step: int, iteration after which no mask is updated.
    frequency: int, of mask update operations.
    drop_fraction: float, of connections to drop during each update.
    drop_fraction_anneal: str or None, if supplied used to anneal the drop
      fraction.
    use_locking: bool, passed to the super.
    grow_init: str, name of the method used to initialize new connections.
    initial_acc_scale: float, used to scale the gradient when initializing the
      momentum values of new connections. We hope this will improve training,
      compare to starting from 0 for the new connections. Set this to something
      between 0 and 1 / (1 - momentum). This is because in the current
      implementation of MomentumOptimizer, aggregated values converge to 1 / (1
      - momentum) with constant gradients.
    use_tpu: bool, if true the masked_gradients are aggregated.
    name: str, passed to the super.
  """

  def __init__(self,
               optimizer,
               begin_step,
               end_step,
               frequency,
               drop_fraction=0.1,
               drop_fraction_anneal='constant',
               use_locking=False,
               grow_init='zeros',
               initial_acc_scale=0.,
               use_tpu=False,
               name='SparseRigLOptimizer',
               stateless_seed_offset=0):
    super(SparseRigLOptimizerBase, self).__init__(
        optimizer,
        begin_step,
        end_step,
        frequency,
        drop_fraction=drop_fraction,
        drop_fraction_anneal=drop_fraction_anneal,
        grow_init=grow_init,
        use_locking=use_locking,
        # Forward the caller-supplied name. Previously this was hard-coded to
        # 'SparseRigLOptimizer', silently ignoring the `name` argument.
        name=name,
        stateless_seed_offset=stateless_seed_offset)
    self._initial_acc_scale = initial_acc_scale
    self._use_tpu = use_tpu

  def set_masked_grads(self, grads, weights):
    """Caches masked-weight gradients, aggregating across replicas on TPU."""
    if self._use_tpu:
      grads = [tpu_ops.cross_replica_sum(g) for g in grads]
    self._masked_grads = grads
    # Using names since better to hash.
    self._weight2masked_grads = {w.name: m for w, m in zip(weights, grads)}

  def compute_gradients(self, loss, **kwargs):
    """Wraps the compute gradient of passed optimizer.

    A second gradient pass over the masked weights only is taken and cached;
    these are the gradients used as grow scores in `generic_mask_update`.

    Args:
      loss: Tensor, the training loss.
      **kwargs: forwarded to the wrapped optimizer's `compute_gradients`.

    Returns:
      List of (gradient, variable) pairs from the wrapped optimizer.
    """
    grads_and_vars = self._optimizer.compute_gradients(loss, **kwargs)
    masked_grads_vars = self._optimizer.compute_gradients(
        loss, var_list=self.get_masked_weights())
    masked_grads = [g for g, _ in masked_grads_vars]
    self.set_masked_grads(masked_grads, self.get_weights())
    return grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Wraps the original apply_gradient of the optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    pre_op = self._before_apply_gradients(grads_and_vars)
    with ops.control_dependencies([pre_op]):
      # Call this to create slots.
      _ = self._optimizer.apply_gradients(
          grads_and_vars, global_step=global_step, name=name)

      def apply_gradient_op():
        optimizer_update = self._optimizer.apply_gradients(
            grads_and_vars, global_step=global_step, name=name)
        return optimizer_update

      # We get the default one after calling the super.apply_gradient(), since
      # we want to preserve original behavior of the optimizer: don't increment
      # anything if no global_step is passed. But we need the global step for
      # the mask_update.
      global_step = (
          global_step if global_step is not None else
          training_util.get_or_create_global_step())
      self._global_step = global_step
      return self.cond_mask_update_op(global_step, apply_gradient_op)

  def generic_mask_update(self, mask, weights, noise_std=1e-5):
    """True branch of the condition, updates the mask."""
    # Ensure that the weights are masked.
    casted_mask = math_ops.cast(mask, dtype=dtypes.float32)
    masked_weights = casted_mask * weights
    score_drop = math_ops.abs(masked_weights)
    # Add noise for slight bit of randomness.
    score_drop += self._random_normal(
        score_drop.shape,
        stddev=noise_std,
        dtype=score_drop.dtype,
        seed=hash(weights.name + 'drop'))
    # Revive n_prune many connections using the cached (pre-removal)
    # gradient magnitudes — the core RigL growth criterion.
    score_grow = math_ops.abs(self._weight2masked_grads[weights.name])
    with ops.control_dependencies([score_grow]):
      return self._get_update_op(score_drop, score_grow, mask, weights)

  def get_grow_tensor(self, weights, method):
    """Returns initialization for grown weights.

    In addition to the super's options, supports 'grad_scale' (scaled masked
    gradient) and 'grad_sign' (scaled sign of the masked gradient).
    """
    if method.startswith('grad_scale'):
      masked_grad = self._weight2masked_grads[weights.name]
      divisor = extract_number(method)
      grow_tensor = masked_grad / divisor
    elif method.startswith('grad_sign'):
      masked_grad_sign = math_ops.sign(self._weight2masked_grads[weights.name])
      divisor = extract_number(method)
      grow_tensor = masked_grad_sign / divisor
    else:
      grow_tensor = super(SparseRigLOptimizerBase,
                          self).get_grow_tensor(weights, method)
    return grow_tensor

  def reset_momentum(self, weights, new_connections):
    """Seeds slot values of new connections with scaled masked gradients."""
    reset_ops = []
    for s_name in self._optimizer.get_slot_names():
      # Momentum variable for example, we reset the aggregated values to the
      # scaled gradient instead of zero (see initial_acc_scale above).
      optim_var = self._optimizer.get_slot(weights, s_name)
      accum_grad = (
          self._weight2masked_grads[weights.name] * self._initial_acc_scale)
      new_values = array_ops.where(new_connections, accum_grad, optim_var)
      reset_ops.append(state_ops.assign(optim_var, new_values))
    return control_flow_ops.group(reset_ops)
================================================
FILE: rigl/sparse_optimizers_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the sparse_optimizers file."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl import flags
from absl.testing import parameterized
import numpy as np
from rigl import sparse_optimizers
from rigl import sparse_utils
import tensorflow.compat.v1 as tf # tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
FLAGS = flags.FLAGS
class SparseSETOptimizerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for sparse_optimizers.SparseSETOptimizer."""

  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.random.uniform((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # get_or_create_global_step is idempotent; a single call suffices
    # (previously it was called twice).
    global_step = tf.train.get_or_create_global_step()
    weight = pruning.get_weights()[0]
    # There is one masked layer to be trained.
    mask = pruning.get_masks()[0]
    # Around half of the values of the mask is set to zero with `mask_update`.
    mask_update = tf.assign(
        mask,
        tf.constant(
            np.random.choice([0, 1], size=(n_inp, n_out), p=[1./2, 1./2]),
            dtype=tf.float32))
    loss = tf.reduce_mean(y)
    train_op = sparse_optim.minimize(loss, global_step)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run([mask_update])
    return sess, train_op, mask, weight, global_step

  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.2), (3, 5, 0.2))
  def testMaskNonUpdateIterations(self, n_inp, n_out, drop_frac):
    """Training a layer for 5 iterations and see whether mask is kept intact.

    The mask should be updated only in iterations 1 and 3 (since start_iter=1,
    end_iter=4, freq_iter=2).

    Args:
      n_inp: int, number of input channels.
      n_out: int, number of output channels
      drop_frac: float, passed to the sparse optimizer.
    """
    sess, train_op, mask, _, _ = self._setup_graph(
        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)
    expected_updates = [1, 3]
    # Running 5 times to make sure the mask is not updated after end_iter.
    for i in range(1, 6):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      if i not in expected_updates:
        self.assertAllEqual(c_mask, c_mask2)

  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.7), (30, 10, 0.9))
  def testUpdateIterations(self, n_inp, n_out, drop_frac):
    """Checking whether the mask is updated during correct iterations.

    The mask should be updated only in iterations 1 and 3 (since start_iter=1,
    end_iter=4, freq_iter=2). Number of 1's in the mask should be equal.

    Args:
      n_inp: int, number of input channels.
      n_out: int, number of output channels
      drop_frac: float, passed to the sparse optimizer.
    """
    sess, train_op, mask, _, _ = self._setup_graph(
        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)
    expected_updates = [1, 3]
    # Running 4 times since last update is at 3.
    for i in range(1, 5):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      if i in expected_updates:
        # Number of ones (connections) should be same.
        self.assertEqual(c_mask.sum(), c_mask2.sum())
        # Assert there is some change in the mask.
        self.assertNotAllClose(c_mask, c_mask2)

  @parameterized.parameters((3, 7, 2), (1, 5, 3), (0, 4, 1))
  def testNoDrop(self, start_iter, end_iter, freq_iter):
    """Checks when the drop fraction is 0, no update is made.

    With drop_fraction=0 the mask must stay identical at every iteration,
    regardless of the update schedule.

    Args:
      start_iter: int, start iteration for sparse training.
      end_iter: int, final iteration for sparse training.
      freq_iter: int, mask update frequency.
    """
    # Setting drop_fraction to 0; so there is nothing dropped, nothing changed.
    sess, train_op, mask, _, _ = self._setup_graph(
        3, 5, 0, start_iter=start_iter, end_iter=end_iter, freq_iter=freq_iter)
    for _ in range(end_iter+2):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      self.assertAllEqual(c_mask, c_mask2)

  def testNewConnectionZeroInit(self):
    """Checks whether the new connections are initialized correctly to zeros.
    """
    end_iter = 4
    sess, train_op, mask, weight, _ = self._setup_graph(
        n_inp=3, n_out=5, drop_frac=0.5, start_iter=0, end_iter=end_iter,
        freq_iter=1)
    # Let's iterate until the mask updates are done.
    for _ in range(end_iter + 1):
      mask_tensor, = sess.run([mask])
      sess.run([train_op])
    new_mask_tensor, new_weight_tensor = sess.run([mask, weight])
    # Let's sum the values of the new connections
    new_weights = new_weight_tensor[np.logical_and(mask_tensor == 0,
                                                   new_mask_tensor == 1)]
    self.assertTrue(np.all(new_weights == 0))

  @parameterized.parameters(itertools.product(
      ((3, 7, 2), (5, 3), (1,)), ('zeros', 'random_normal', 'random_uniform')))
  def testShapeOfGetGrowTensor(self, shape, init_type):
    """Checks whether the new tensor is created with correct shape."""
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(optim, 0, 0, 1,
                                                        use_stateless=False)
    weights = tf.random_uniform(shape)
    grow_tensor = sparse_optim.get_grow_tensor(weights, init_type)
    self.assertAllEqual(weights.shape, grow_tensor.shape)

  @parameterized.parameters(itertools.product(
      (tf.float32, tf.float64),
      ('zeros', 'random_normal', 'random_uniform')))
  def testDtypeOfGetGrowTensor(self, dtype, init_type):
    """Checks whether the new tensor is created with correct data type."""
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(optim, 0, 0, 1,
                                                        use_stateless=False)
    weights = tf.random_uniform((3, 4), dtype=dtype, maxval=5)
    grow_tensor = sparse_optim.get_grow_tensor(weights, init_type)
    self.assertEqual(grow_tensor.dtype, weights.dtype)

  @parameterized.parameters('ones', 'zero', None, 0)
  def testValueErrorOfGetGrowTensor(self, method):
    """Checks that invalid grow-init methods raise ValueError."""
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(optim, 0, 0, 1,
                                                        use_stateless=False)
    weights = tf.random_uniform((3, 4))
    with self.assertRaises(ValueError):
      sparse_optim.get_grow_tensor(weights, method)
class SparseStaticOptimizerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for sparse_optimizers.SparseStaticOptimizer."""

  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseStaticOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.random.uniform((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # get_or_create_global_step is idempotent; a single call suffices
    # (previously it was called twice).
    global_step = tf.train.get_or_create_global_step()
    weight = pruning.get_weights()[0]
    # There is one masked layer to be trained.
    mask = pruning.get_masks()[0]
    # Around half of the values of the mask is set to zero with `mask_update`.
    mask_update = tf.assign(
        mask,
        tf.constant(
            np.random.choice([0, 1], size=(n_inp, n_out), p=[1./2, 1./2]),
            dtype=tf.float32))
    loss = tf.reduce_mean(y)
    train_op = sparse_optim.minimize(loss, global_step)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run([mask_update])
    return sess, train_op, mask, weight, global_step

  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.2), (3, 5, 0.2))
  def testMaskStatic(self, n_inp, n_out, drop_frac):
    """Trains for 5 iterations and checks the mask never changes.

    The static optimizer must preserve the initial sparse connectivity at
    every iteration, including the scheduled update steps.

    Args:
      n_inp: int, number of input channels.
      n_out: int, number of output channels
      drop_frac: float, passed to the sparse optimizer.
    """
    sess, train_op, mask, _, _ = self._setup_graph(
        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)
    # Running 5 times to make sure the mask is not updated after end_iter.
    for _ in range(5):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      self.assertAllEqual(c_mask, c_mask2)
class SparseMomentumOptimizerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for sparse_optimizers.SparseMomentumOptimizer."""

  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2, momentum=0.5):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseMomentumOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac,
        momentum=momentum)
    x = tf.ones((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # Multiplying the output with range of constants to have constant but
    # different gradients at the masked weights.
    y = y * tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape)
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    weight = pruning.get_weights()[0]
    # Reach into the optimizer's internals to observe the EMA of the masked
    # gradient that the momentum-based growth criterion uses.
    masked_grad = sparse_optim._weight2masked_grads[weight.name]
    masked_grad_ema = sparse_optim._ema_grads.average(masked_grad)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    return sess, train_op, masked_grad_ema

  @parameterized.parameters((3, 4, 0.5), (5, 2, 0.), (2, 5, 1.))
  def testMomentumUpdate(self, n_inp, n_out, momentum):
    """Checking whether momentum applied correctly."""
    sess, train_op, masked_grad_ema = self._setup_graph(
        n_inp, n_out, 0.5, start_iter=1, end_iter=4, freq_iter=2,
        momentum=momentum)
    # Running 6 times to make sure the momentum is always updated.
    current_momentum = np.zeros((n_inp, n_out))
    for _ in range(6):
      ema_masked_grad, = sess.run([masked_grad_ema])
      self.assertAllEqual(ema_masked_grad, current_momentum)
      sess.run([train_op])
      # This is since we multiply the output values with range(n_out)
      # Note the broadcast from n_out vector to (n_inp, n_out) matrix.
      current_momentum = (current_momentum * momentum +
                          (1 - momentum) * np.arange(n_out))
    ema_masked_grad, = sess.run([masked_grad_ema])
    self.assertAllEqual(ema_masked_grad, current_momentum)
class SparseRigLOptimizerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for sparse_optimizers.SparseRigLOptimizer."""

  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(1e-3)
    # Create the global step before it is used in the loss scaling below;
    # get_or_create_global_step is idempotent, so one call suffices
    # (previously it was called a second time before minimize).
    global_step = tf.train.get_or_create_global_step()
    sparse_optim = sparse_optimizers.SparseRigLOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.ones((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # Multiplying the output with range of constants to have constant but
    # different gradients at the masked weights. We also multiply the loss with
    # global_step to increase the gradient linearly with time.
    scale_vector = (
        tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape) *
        tf.cast(global_step, dtype=y.dtype))
    y = y * scale_vector
    loss = tf.reduce_sum(y)
    train_op = sparse_optim.minimize(loss, global_step)
    weight = pruning.get_weights()[0]
    expected_gradient = tf.broadcast_to(scale_vector, weight.shape)
    masked_grad = sparse_optim._weight2masked_grads[weight.name]
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    return sess, train_op, masked_grad, expected_gradient

  @parameterized.parameters((3, 4), (5, 2), (2, 5))
  def testMaskedGradientCalculation(self, n_inp, n_out):
    """Checking whether masked_grad is calculated after apply_gradients."""
    # No drop since we don't want to change the mask but check whether the grad
    # is calculated after the gradient step.
    sess, train_op, masked_grad, expected_gradient = self._setup_graph(
        n_inp, n_out, 0., start_iter=0, end_iter=3, freq_iter=1)
    # Mask-update steps do not increment global_step, so updates land on
    # every other sess.run; iterate 6 times to cover several of them.
    for i in range(6):
      is_mask_update = i % 2 == 0
      if is_mask_update:
        expected_gradient_tensor, = sess.run([expected_gradient])
        _, masked_grad_tensor = sess.run([train_op, masked_grad])
        self.assertAllEqual(masked_grad_tensor,
                            expected_gradient_tensor)
      else:
        sess.run([train_op])

  @parameterized.parameters(
      (3, 7, 2, [1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1]),
      (1, 5, 3, [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1]),
      (0, 4, 1, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]))
  def testApplyGradients(self, start_iter, end_iter, freq_iter, is_incremented):
    """Checking apply_gradient is called in non mask update iterations."""
    sess, train_op, _, _ = self._setup_graph(
        3, 5, .5, start_iter=start_iter, end_iter=end_iter, freq_iter=freq_iter)
    global_step = tf.train.get_or_create_global_step()
    # `is_incremented` encodes, per sess.run, whether global_step should
    # advance (1) or the step is consumed by a mask update (0).
    for one_if_incremented in is_incremented:
      before, = sess.run([global_step])
      sess.run([train_op])
      after, = sess.run([global_step])
      if one_if_incremented == 1:
        self.assertEqual(before + 1, after)
      else:
        # Mask update step.
        self.assertEqual(before, after)
class SparseSnipOptimizerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for sparse_optimizers.SparseSnipOptimizer."""

  def _setup_graph(self, default_sparsity, mask_init_method,
                   custom_sparsity_map, n_inp=3, n_out=5):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(1e-3)
    sparse_optim = sparse_optimizers.SparseSnipOptimizer(
        optim, default_sparsity, mask_init_method,
        custom_sparsity_map=custom_sparsity_map)
    inp_values = np.arange(1, n_inp+1)
    scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5
    # The gradient is the outer product of input and the output gradients.
    # Since the loss is sample sum the output gradient is equal to the scale
    # vector.
    expected_grads = np.outer(inp_values, scale_vector_values)
    x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)
    y = y * scale_vector
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    mask = pruning.get_masks()[0]
    weights = pruning.get_weights()[0]
    return sess, train_op, expected_grads, sparse_optim, mask, weights

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testSnipSparsity(self, n_inp, n_out, default_sparsity):
    """Checks the mask reaches the requested sparsity after the snip step."""
    # One training step triggers the snip; afterwards the mask should have
    # exactly the expected number of zeros.
    sess, train_op, _, _, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    _ = sess.run([train_op])
    snipped_mask, = sess.run([mask])
    n_ones = np.sum(snipped_mask)
    n_zeros = snipped_mask.size - n_ones
    n_zeros_expected = sparse_utils.get_n_zeros(snipped_mask.size,
                                                default_sparsity)
    self.assertEqual(n_zeros, n_zeros_expected)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testGradientUsed(self, n_inp, n_out, default_sparsity):
    """Checks snip keeps the connections with largest |grad * weight|."""
    # Every kept connection must score at least as high as every snipped one.
    sess, train_op, expected_grads, _, mask, weights = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    # Calculate sensitivity scores.
    weights, = sess.run([weights])
    expected_scores = np.abs(expected_grads*weights)
    _ = sess.run([train_op])
    snipped_mask, = sess.run([mask])
    kept_connection_scores = expected_scores[snipped_mask == 1]
    min_score_kept = np.min(kept_connection_scores)
    snipped_connection_scores = expected_scores[snipped_mask == 0]
    max_score_snipped = np.max(snipped_connection_scores)
    self.assertLessEqual(max_score_snipped, min_score_kept)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testInitialMaskIsDense(self, n_inp, n_out, default_sparsity):
    """Checks the mask starts fully dense before the snip step runs."""
    sess, _, _, _, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    mask_start, = sess.run([mask])
    self.assertEqual(np.sum(mask_start), mask_start.size)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testAfterSnipTraining(self, n_inp, n_out, default_sparsity):
    """Checks the mask is snipped once and then stays fixed while training."""
    sess, train_op, _, sparse_optim, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    global_step = tf.train.get_or_create_global_step()
    # First run performs the snip; it does not advance global_step.
    is_snip_iter = sess.run([train_op])
    self.assertTrue(is_snip_iter)
    # On other iterations mask should stay same. Let's do 3 more iterations.
    for i in range(3):
      mask_before, c_iter = sess.run([mask, global_step])
      self.assertEqual(i, c_iter)
      is_snip_iter, is_snipped = sess.run([train_op, sparse_optim.is_snipped])
      self.assertTrue(is_snipped)
      self.assertFalse(is_snip_iter)
      mask_after, = sess.run([mask])
      self.assertAllEqual(mask_after, mask_before)
class SparseDNWOptimizerTest(tf.test.TestCase, parameterized.TestCase):
def _setup_graph(self,
                 default_sparsity,
                 mask_init_method,
                 custom_sparsity_map,
                 n_inp=3,
                 n_out=5):
  """Setups a trivial training procedure for sparse training.

  Builds a single masked fully-connected layer with deterministic inputs,
  wraps a GradientDescentOptimizer in SparseDNWOptimizer, and returns the
  session plus handles needed by the tests.
  """
  tf.reset_default_graph()
  optim = tf.train.GradientDescentOptimizer(1e-3)
  sparse_optim = sparse_optimizers.SparseDNWOptimizer(
      optim,
      default_sparsity,
      mask_init_method,
      custom_sparsity_map=custom_sparsity_map)
  inp_values = np.arange(1, n_inp + 1)
  scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5
  # The gradient is the outer product of input and the output gradients.
  # Since the loss is sample sum the output gradient is equal to the scale
  # vector.
  expected_grads = np.outer(inp_values, scale_vector_values)
  x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
  y = layers.masked_fully_connected(x, n_out, activation_fn=None)
  scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)
  y = y * scale_vector
  loss = tf.reduce_sum(y)
  global_step = tf.train.get_or_create_global_step()
  grads_and_vars = sparse_optim.compute_gradients(loss)
  train_op = sparse_optim.apply_gradients(
      grads_and_vars, global_step=global_step)
  # Init
  sess = tf.Session()
  init = tf.global_variables_initializer()
  sess.run(init)
  mask = pruning.get_masks()[0]
  weights = pruning.get_weights()[0]
  return (sess, train_op, (expected_grads, grads_and_vars), mask, weights)
@parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
def testDNWSparsity(self, n_inp, n_out, default_sparsity):
  """Checks the mask has the requested sparsity after one training step."""
  sess, train_op, _, mask, _ = self._setup_graph(
      default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
  _ = sess.run([train_op])
  dnw_mask, = sess.run([mask])
  # Count zeros and compare against the sparsity target.
  num_ones = np.sum(dnw_mask)
  num_zeros = dnw_mask.size - num_ones
  expected_zeros = sparse_utils.get_n_zeros(dnw_mask.size, default_sparsity)
  self.assertEqual(num_zeros, expected_zeros)
@parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
def testWeightsUsed(self, n_inp, n_out, default_sparsity):
  """Checks DNW keeps the connections with the largest weight magnitudes."""
  sess, train_op, _, mask, weights = self._setup_graph(
      default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
  # Magnitude scores before the step decide which connections survive.
  weight_values, = sess.run([weights])
  magnitude_scores = np.abs(weight_values)
  _ = sess.run([train_op])
  dnw_mask, = sess.run([mask])
  # Every kept connection must score at least as high as every removed one.
  min_score_kept = np.min(magnitude_scores[dnw_mask == 1])
  max_score_removed = np.max(magnitude_scores[dnw_mask == 0])
  self.assertLessEqual(max_score_removed, min_score_kept)
@parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
def testGradientIsDense(self, n_inp, n_out, default_sparsity):
  """Verifies the computed gradient equals the dense expected gradient."""
  sess, _, grad_info, _, _ = self._setup_graph(
      default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
  expected_grad, grads_and_vars = grad_info
  # The first (gradient, variable) pair corresponds to the masked weights;
  # its gradient must match the dense outer product built in _setup_graph.
  actual_grad, = sess.run([grads_and_vars[0][0]])
  self.assertAllClose(expected_grad, actual_grad)
@parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
def testDNWUpdates(self, n_inp, n_out, default_sparsity):
  """Verifies the mask tracks the largest-magnitude weights every step."""
  sess, train_op, _, mask, weights = self._setup_graph(
      default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
  for _ in range(5):
    sess.run([train_op])
    current_mask, current_weights = sess.run([mask, weights])
    # After each step, every kept connection must be at least as large in
    # magnitude as every removed connection.
    kept_magnitudes = np.abs(current_weights[current_mask == 1])
    removed_magnitudes = np.abs(current_weights[current_mask == 0])
    self.assertLessEqual(np.max(removed_magnitudes), np.min(kept_magnitudes))
@parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
def testSparsityAfterDNWUpdates(self, n_inp, n_out, default_sparsity):
  """Verifies the mask sparsity stays at the target across training steps."""
  sess, train_op, _, mask, _ = self._setup_graph(
      default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
  for _ in range(5):
    sess.run([train_op])
    mask_value, = sess.run([mask])
    # The number of pruned connections must match the target on every step.
    zeros_found = mask_value.size - np.sum(mask_value)
    self.assertEqual(
        zeros_found,
        sparse_utils.get_n_zeros(mask_value.size, default_sparsity))
if __name__ == '__main__':
tf.test.main()
================================================
FILE: rigl/sparse_utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module has helper functions for the interpolation experiments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import numpy as np
from rigl import str_sparsities
import tensorflow.compat.v1 as tf
from google_research.micronet_challenge import counting
DEFAULT_ERK_SCALE = 1.0
def mask_extract_name_fn(mask_name):
  """Extracts the layer name from a mask variable name like '<name>/mask:0'."""
  extracted = re.findall('(.+)/mask:0', mask_name)
  return extracted[0]
def get_n_zeros(size, sparsity):
  """Returns the number of zeros for a layer of `size` at the given sparsity."""
  return int(np.floor(size * sparsity))
def calculate_sparsity(masks):
  """Returns a scalar tf.Tensor with the overall sparsity of `masks`.

  Computes 1 - sum(mask entries) / total entries; for binary 0/1 masks this
  is the fraction of pruned connections.
  """
  n_total = tf.constant(0.)
  n_ones = tf.constant(0.)
  for mask in masks:
    n_total += tf.cast(tf.size(mask), dtype=n_total.dtype)
    n_ones += tf.cast(tf.reduce_sum(mask), dtype=n_ones.dtype)
  return 1. - n_ones / n_total
def get_mask_random_numpy(mask_shape, sparsity, random_state=None):
  """Creates a random sparse mask with deterministic sparsity.

  Args:
    mask_shape: list, used to obtain shape of the random mask.
    sparsity: float, between 0 and 1.
    random_state: np.random.RandomState, if given the shuffle call is made
      using the RandomState.

  Returns:
    numpy.ndarray, binary mask of the requested shape.
  """
  flat_mask = np.ones(mask_shape).flatten()
  # Zero out exactly floor(sparsity * size) entries, then shuffle so the
  # zeroed positions end up at random locations.
  num_zeros = int(np.floor(sparsity * flat_mask.size))
  flat_mask[:num_zeros] = 0
  shuffler = random_state if random_state else np.random
  shuffler.shuffle(flat_mask)
  return flat_mask.reshape(mask_shape)
def get_mask_random(mask, sparsity, dtype, random_state=None):
  """Creates a random sparse mask with deterministic sparsity.

  Args:
    mask: tf.Tensor, used to obtain shape of the random mask.
    sparsity: float, between 0 and 1.
    dtype: tf.dtype, type of the return value.
    random_state: np.random.RandomState, if given the shuffle call is made
      using the RandomState.

  Returns:
    tf.Tensor, constant binary mask of the requested dtype.
  """
  numpy_mask = get_mask_random_numpy(
      mask.shape.as_list(), sparsity, random_state=random_state)
  return tf.constant(numpy_mask, dtype=dtype)
def get_sparsities_erdos_renyi(all_masks,
                               default_sparsity,
                               custom_sparsity_map,
                               include_kernel,
                               extract_name_fn=mask_extract_name_fn,
                               erk_power_scale=DEFAULT_ERK_SCALE):
  """Given the method, returns the sparsity of individual layers as a dict.

  It ensures that the non-custom layers have a total parameter count as the one
  with uniform sparsities. In other words for the layers which are not in the
  custom_sparsity_map the following equation should be satisfied.

  # eps * (p_1 * N_1 + p_2 * N_2) = (1 - default_sparsity) * (N_1 + N_2)

  Args:
    all_masks: list, of all mask Variables.
    default_sparsity: float, between 0 and 1.
    custom_sparsity_map: dict, key/value pairs where the mask
      correspond whose name is '{key}/mask:0' is set to the corresponding
      sparsity value.
    include_kernel: bool, if True kernel dimension are included in the scaling.
    extract_name_fn: function, extracts the variable name.
    erk_power_scale: float, if given used to take power of the ratio. Use
      scale<1 to make the erdos_renyi softer.

  Returns:
    sparsities, dict of where keys() are equal to all_masks and individual
      masks are mapped to their sparsities.
  """
  # We have to enforce custom sparsities and then find the correct scaling
  # factor.
  is_eps_valid = False
  # The following loop will terminate worst case when all masks are in the
  # custom_sparsity_map. This should probably never happen though, since once
  # we have a single variable or more with the same constant, we have a valid
  # epsilon. Note that for each iteration we add at least one variable to the
  # custom_sparsity_map and therefore this while loop should terminate.
  dense_layers = set()
  while not is_eps_valid:
    # We will start with all layers and try to find right epsilon. However if
    # any probability exceeds 1, we will make that layer dense and repeat the
    # process (finding epsilon) with the non-dense layers.
    # We want the total number of connections to be the same. Let's say we
    # have four layers with N_1, ..., N_4 parameters each. Let's say after
    # some iterations the probability of some dense layers (3, 4) exceeded 1
    # and therefore we added them to the dense_layers set. Those layers will
    # not scale with erdos_renyi, however we need to count them so that the
    # target parameter count is achieved. See below.
    # eps * (p_1 * N_1 + p_2 * N_2) + (N_3 + N_4) =
    #    (1 - default_sparsity) * (N_1 + N_2 + N_3 + N_4)
    # eps * (p_1 * N_1 + p_2 * N_2) =
    #    (1 - default_sparsity) * (N_1 + N_2) - default_sparsity * (N_3 + N_4)
    # eps = rhs / (\sum_i p_i * N_i) = rhs / divisor.
    divisor = 0
    rhs = 0
    raw_probabilities = {}
    for mask in all_masks:
      var_name = extract_name_fn(mask.name)
      shape_list = mask.shape.as_list()
      n_param = np.prod(shape_list)
      n_zeros = get_n_zeros(n_param, default_sparsity)
      if var_name in dense_layers:
        # See `- default_sparsity * (N_3 + N_4)` part of the equation above.
        rhs -= n_zeros
      elif var_name in custom_sparsity_map:
        # We ignore custom_sparsities in erdos-renyi calculations.
        pass
      else:
        # Corresponds to `(1 - default_sparsity) * (N_1 + N_2)` part of the
        # equation above.
        n_ones = n_param - n_zeros
        rhs += n_ones
        # Erdos-Renyi probability: epsilon * ((n_in + n_out) / (n_in * n_out)),
        # or (sum of all dims / product of all dims) when the kernel spatial
        # dimensions are included.
        if include_kernel:
          raw_probabilities[mask.name] = (np.sum(shape_list) /
                                          np.prod(shape_list))**erk_power_scale
        else:
          n_in, n_out = shape_list[-2:]
          raw_probabilities[mask.name] = (n_in + n_out) / (n_in * n_out)
        # Note that raw_probabilities[mask] * n_param gives the individual
        # elements of the divisor.
        divisor += raw_probabilities[mask.name] * n_param
    # By multiplying individual probabilities with epsilon, we should get the
    # number of parameters per layer correctly.
    eps = rhs / divisor
    # If eps * raw_probabilities[mask.name] > 1, we set the sparsities of that
    # mask to 0., so it becomes part of the dense_layers set.
    max_prob = np.max(list(raw_probabilities.values()))
    max_prob_one = max_prob * eps
    if max_prob_one > 1:
      is_eps_valid = False
      for mask_name, mask_raw_prob in raw_probabilities.items():
        if mask_raw_prob == max_prob:
          var_name = extract_name_fn(mask_name)
          tf.logging.info('Sparsity of var: %s had to be set to 0.', var_name)
          dense_layers.add(var_name)
    else:
      is_eps_valid = True
  sparsities = {}
  # With the valid epsilon, we can set sparsities of the remaining layers.
  for mask in all_masks:
    var_name = extract_name_fn(mask.name)
    shape_list = mask.shape.as_list()
    n_param = np.prod(shape_list)
    if var_name in custom_sparsity_map:
      sparsities[mask.name] = custom_sparsity_map[var_name]
      tf.logging.info('layer: %s has custom sparsity: %f', var_name,
                      sparsities[mask.name])
    elif var_name in dense_layers:
      sparsities[mask.name] = 0.
    else:
      probability_one = eps * raw_probabilities[mask.name]
      sparsities[mask.name] = 1. - probability_one
    tf.logging.info('layer: %s, shape: %s, sparsity: %f', var_name, mask.shape,
                    sparsities[mask.name])
  return sparsities
def get_sparsities_uniform(all_masks,
                           default_sparsity,
                           custom_sparsity_map,
                           extract_name_fn=mask_extract_name_fn):
  """Assigns a uniform sparsity to every mask, honoring custom overrides.

  Args:
    all_masks: list, of all mask Variables.
    default_sparsity: float, between 0 and 1, used for every layer without a
      custom override.
    custom_sparsity_map: dict, key/value pairs where the mask whose name is
      '{key}/mask:0' is set to the corresponding sparsity value.
    extract_name_fn: function, extracts the layer name from a variable name.

  Returns:
    dict, mapping each mask's name to its sparsity.
  """
  sparsities = {}
  for mask in all_masks:
    layer_name = extract_name_fn(mask.name)
    # dict.get falls back to the default only when no override exists.
    sparsities[mask.name] = custom_sparsity_map.get(layer_name,
                                                    default_sparsity)
  return sparsities
def get_sparsities_str(all_masks, default_sparsity):
  """Looks up per-layer sparsities from the STR reported distributions.

  Args:
    all_masks: list, of all mask Variables.
    default_sparsity: float, overall sparsity used as the lookup key into the
      parsed STR table.

  Returns:
    dict, mapping each mask's name to its sparsity.

  Raises:
    ValueError: if no STR distribution exists for `default_sparsity`.
  """
  parsed = str_sparsities.read_all()
  if default_sparsity not in parsed:
    raise ValueError('sparsity: %f is not defined' % default_sparsity)
  per_layer = parsed[default_sparsity]
  return {mask.name: per_layer[mask.name] for mask in all_masks}
def get_sparsities(all_masks,
                   method,
                   default_sparsity,
                   custom_sparsity_map,
                   extract_name_fn=mask_extract_name_fn,
                   erk_power_scale=DEFAULT_ERK_SCALE):
  """Returns per-mask sparsities for the given initialization method.

  Args:
    all_masks: list, of all mask Variables.
    method: str, one of 'random', 'erdos_renyi', 'erdos_renyi_kernel' or
      'str'.
    default_sparsity: float, between 0 and 1.
    custom_sparsity_map: dict, key/value pairs where the mask whose name is
      '{key}/mask:0' is set to the corresponding sparsity value.
    extract_name_fn: function, extracts the layer name from a variable name.
    erk_power_scale: float, passed to the erdos_renyi function.

  Returns:
    dict, mapping each mask's name to its sparsity.

  Raises:
    ValueError: when a key from custom_sparsity_map is not found in all_masks,
      or when an invalid initialization method is given.
  """
  # (1) Every custom key must match at least one mask; anything left over is
  # most likely a misspelled layer name, so fail loudly.
  mask_names = set()
  for mask in all_masks:
    mask_names.add(extract_name_fn(mask.name))
  missing = set(custom_sparsity_map.keys()) - mask_names
  if missing:
    raise ValueError('No masks are found for the following names: %s' %
                     str(missing))
  # (2) Dispatch on the requested method.
  if method in ('erdos_renyi', 'erdos_renyi_kernel'):
    return get_sparsities_erdos_renyi(
        all_masks,
        default_sparsity,
        custom_sparsity_map,
        include_kernel=(method == 'erdos_renyi_kernel'),
        extract_name_fn=extract_name_fn,
        erk_power_scale=erk_power_scale)
  if method == 'random':
    return get_sparsities_uniform(
        all_masks,
        default_sparsity,
        custom_sparsity_map,
        extract_name_fn=extract_name_fn)
  if method == 'str':
    return get_sparsities_str(all_masks, default_sparsity)
  raise ValueError('Method: %s is not valid mask initialization method' %
                   method)
def get_mask_init_fn(all_masks,
                     method,
                     default_sparsity,
                     custom_sparsity_map,
                     mask_fn=get_mask_random,
                     erk_power_scale=DEFAULT_ERK_SCALE,
                     extract_name_fn=mask_extract_name_fn):
  """Builds an op that (re)initializes all masks at their target sparsities.

  Args:
    all_masks: list, of all masks to be updated.
    method: str, method to compute per-layer sparsities, passed to
      get_sparsities().
    default_sparsity: float, if 0 mask left intact, if greater than zero, a
      fraction of the ones in each mask is flipped to 0.
    custom_sparsity_map: dict, sparsity of individual variables can be
      overridden here. Key should point to the correct variable name, and
      value should be in [0, 1].
    mask_fn: function, creates a new mask tensor with a given sparsity.
    erk_power_scale: float, passed to get_sparsities.
    extract_name_fn: function, used to grab names from the variable.

  Returns:
    A tf op grouping one assign per mask; run it after variable
    initialization (e.g. as the `init_fn` of a `tf.train.Scaffold`).

  Raises:
    ValueError: when there is no mask corresponding to a key in the
      custom_sparsity_map.
  """
  sparsities = get_sparsities(
      all_masks,
      method,
      default_sparsity,
      custom_sparsity_map,
      erk_power_scale=erk_power_scale,
      extract_name_fn=extract_name_fn)
  tf.logging.info('Per layer sparsities are like the following: %s',
                  str(sparsities))
  # One assign per mask, grouped so callers can run a single init op.
  assign_ops = [
      tf.assign(mask, mask_fn(mask, sparsities[mask.name], mask.dtype))
      for mask in all_masks
  ]
  return tf.group(assign_ops)
## Calculating flops and parameters using a list of Keras layers.
def _get_kernel(layer):
  """Returns the weight (kernel) variable of the given Keras layer.

  DepthwiseConv2D stores its weights under `depthwise_kernel`; every other
  supported layer exposes them as `kernel`.
  """
  if isinstance(layer, tf.keras.layers.DepthwiseConv2D):
    return layer.depthwise_kernel
  return layer.kernel
def get_stats(masked_layers,
              default_sparsity=0.8,
              method='erdos_renyi',
              custom_sparsities=None,
              is_debug=False,
              width=1.,
              first_layer_name='conv1',
              last_layer_name='conv_preds',
              param_size=32,
              erk_power_scale=DEFAULT_ERK_SCALE):
  """Given the Keras layers returns the size and FLOPS of the model.

  Args:
    masked_layers: list, of tf.keras.Layer.
    default_sparsity: float, if 0 mask left intact, if greater than zero, a
      fraction of the ones in each mask is flipped to 0.
    method: str, passed to the `get_sparsities()` function.
    custom_sparsities: dict or None, sparsity of individual variables can be
      overridden here. Key should point to the correct variable name, and
      value should be in [0, 1].
    is_debug: bool, if True prints individual stats for given layers.
    width: float, multiplier for the individual layer widths.
    first_layer_name: str, to scale the width correctly.
    last_layer_name: str, to scale the width correctly.
    param_size: int, number of bits to represent a single parameter.
    erk_power_scale: float, passed to the get_sparsities function.

  Returns:
    total_flops, sum of multiply and add operations.
    total_param_bits, total bits to represent the model during the inference.
    real_sparsity, calculated independently omitting bias parameters.

  Raises:
    ValueError: when a layer is not a Conv2D/DepthwiseConv2D/Dense layer.
  """
  if custom_sparsities is None:
    custom_sparsities = {}
  # Identity extract_name_fn: kernels are keyed by their full variable name.
  sparsities = get_sparsities([_get_kernel(l) for l in masked_layers],
                              method,
                              default_sparsity,
                              custom_sparsities,
                              lambda a: a,
                              erk_power_scale=erk_power_scale)
  total_flops = 0
  total_param_bits = 0
  total_params = 0.
  n_zeros = 0.
  for layer in masked_layers:
    kernel = _get_kernel(layer)
    k_shape = kernel.shape.as_list()
    d_in, d_out = 2, 3
    # If fully connected change indices.
    if len(k_shape) == 2:
      d_in, d_out = 0, 1
    # Width-scale all channel dims except the network input (first layer's
    # d_in) and the logits (last layer's d_out). Dims equal to 1 are skipped
    # since they belong to depthwise kernels.
    if not kernel.name.startswith(first_layer_name) and k_shape[d_in] != 1:
      k_shape[d_in] = int(k_shape[d_in] * width)
    if not kernel.name.startswith(last_layer_name) and k_shape[d_out] != 1:
      k_shape[d_out] = int(k_shape[d_out] * width)
    if is_debug:
      print(kernel.name, layer.input_shape, k_shape, sparsities[kernel.name])
    # NOTE: DepthwiseConv2D must be checked BEFORE Conv2D. In Keras versions
    # where DepthwiseConv2D subclasses Conv2D, checking Conv2D first would
    # silently count depthwise layers as full convolutions, inflating the
    # FLOP/parameter estimates.
    if isinstance(layer, tf.keras.layers.DepthwiseConv2D):
      layer_op = counting.DepthWiseConv2D(layer.input_shape[1], k_shape,
                                          layer.strides, 'same', True, 'relu')
    elif isinstance(layer, tf.keras.layers.Conv2D):
      layer_op = counting.Conv2D(layer.input_shape[1], k_shape, layer.strides,
                                 'same', True, 'relu')
    elif isinstance(layer, tf.keras.layers.Dense):
      layer_op = counting.FullyConnected(k_shape, True, 'relu')
    else:
      raise ValueError('Should not happen.')
    param_count, n_mults, n_adds = counting.count_ops(layer_op,
                                                      sparsities[kernel.name],
                                                      param_size)
    total_param_bits += param_count
    total_flops += n_mults + n_adds
    n_param = np.prod(k_shape)
    total_params += n_param
    n_zeros += int(n_param * sparsities[kernel.name])
  return total_flops, total_param_bits, n_zeros / total_params
================================================
FILE: rigl/sparse_utils_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the data_helper input pipeline and the training process.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from rigl import sparse_utils
import tensorflow.compat.v1 as tf
class GetMaskRandomTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for sparse_utils.get_mask_random."""

  def _setup_session(self):
    """Resets the default graph and returns a fresh session."""
    tf.reset_default_graph()
    return tf.Session()

  @parameterized.parameters(((30, 40), 0.5), ((1, 2, 1, 4), 0.8), ((3,), 0.1))
  def testMaskConnectionDeterminism(self, shape, sparsity):
    """Two random masks of the same shape keep the same number of ones."""
    session = self._setup_session()
    source_mask = tf.ones(shape)
    first = sparse_utils.get_mask_random(source_mask, sparsity, tf.int32)
    second = sparse_utils.get_mask_random(source_mask, sparsity, tf.int32)
    first_array, = session.run([first])
    second_array, = session.run([second])
    self.assertEqual(np.sum(first_array), np.sum(second_array))

  @parameterized.parameters(((30, 4), 0.5, 60), ((1, 2, 1, 4), 0.8, 2),
                            ((30,), 0.1, 27))
  def testMaskFraction(self, shape, sparsity, expected_ones):
    """The number of kept connections matches the expected count exactly."""
    session = self._setup_session()
    source_mask = tf.ones(shape)
    random_mask = sparse_utils.get_mask_random(source_mask, sparsity, tf.int32)
    mask_array, = session.run([random_mask])
    self.assertEqual(np.sum(mask_array), expected_ones)

  @parameterized.parameters(tf.int32, tf.float32, tf.int64, tf.float64)
  def testMaskDtype(self, dtype):
    """The returned mask tensor has the requested dtype."""
    self._setup_session()
    source_mask = tf.ones((3, 2))
    random_mask = sparse_utils.get_mask_random(source_mask, 0.5, dtype)
    self.assertEqual(random_mask.dtype, dtype)
class GetSparsitiesTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for sparse_utils.get_sparsities."""

  def _setup_session(self):
    """Resets the graph and returns a fresh session."""
    tf.reset_default_graph()
    sess = tf.Session()
    return sess

  @parameterized.parameters(0., 0.4, 0.9)
  def testSparsityDictRandom(self, default_sparsity):
    """Uniform method uses the default sparsity except for custom layers."""
    _ = self._setup_session()
    all_masks = [tf.get_variable(shape=(2, 3), name='var1/mask'),
                 tf.get_variable(shape=(2, 3), name='var2/mask'),
                 tf.get_variable(shape=(1, 1, 3), name='var3/mask')]
    custom_sparsity = {'var1': 0.8}
    sparsities = sparse_utils.get_sparsities(
        all_masks, 'random', default_sparsity, custom_sparsity)
    self.assertEqual(sparsities[all_masks[0].name], 0.8)
    self.assertEqual(sparsities[all_masks[1].name], default_sparsity)
    self.assertEqual(sparsities[all_masks[2].name], default_sparsity)

  @parameterized.parameters(0.1, 0.4, 0.9)
  def testSparsityDictErdosRenyiCustom(self, default_sparsity):
    """Erdos-Renyi method honors custom per-layer sparsity overrides."""
    _ = self._setup_session()
    all_masks = [tf.get_variable(shape=(2, 4), name='var1/mask'),
                 tf.get_variable(shape=(2, 3), name='var2/mask'),
                 tf.get_variable(shape=(1, 1, 3), name='var3/mask')]
    custom_sparsity = {'var3': 0.8}
    sparsities = sparse_utils.get_sparsities(
        all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)
    self.assertEqual(sparsities[all_masks[2].name], 0.8)

  @parameterized.parameters(0.1, 0.4, 0.9)
  def testSparsityDictErdosRenyiError(self, default_sparsity):
    """A custom sparsity key with no matching mask raises ValueError.

    This test was previously an exact copy of testSparsityDictErdosRenyiCustom
    and never exercised the error path.
    """
    _ = self._setup_session()
    all_masks = [tf.get_variable(shape=(2, 4), name='var1/mask'),
                 tf.get_variable(shape=(2, 3), name='var2/mask'),
                 tf.get_variable(shape=(1, 1, 3), name='var3/mask')]
    # 'var4' does not correspond to any mask, so get_sparsities must fail.
    custom_sparsity = {'var4': 0.8}
    with self.assertRaises(ValueError):
      sparse_utils.get_sparsities(
          all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)

  @parameterized.parameters(((2, 3), (2, 3), 0.5),
                            ((1, 1, 2, 3), (1, 1, 2, 3), 0.3),
                            ((8, 6), (4, 3), 0.7),
                            ((80, 4), (20, 20), 0.8),
                            ((2, 6), (2, 3), 0.8))
  def testSparsityDictErdosRenyiSparsitiesScale(
      self, shape1, shape2, default_sparsity):
    """ER keeps the total parameter budget and the ER per-layer proportions."""
    _ = self._setup_session()
    all_masks = [tf.get_variable(shape=shape1, name='var1/mask'),
                 tf.get_variable(shape=shape2, name='var2/mask')]
    custom_sparsity = {}
    sparsities = sparse_utils.get_sparsities(
        all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)
    sparsity1 = sparsities[all_masks[0].name]
    size1 = np.prod(shape1)
    sparsity2 = sparsities[all_masks[1].name]
    size2 = np.prod(shape2)
    # Ensure that total number of connections are similar.
    expected_zeros_uniform = (
        sparse_utils.get_n_zeros(size1, default_sparsity) +
        sparse_utils.get_n_zeros(size2, default_sparsity))
    expected_zeros_current = (
        sparse_utils.get_n_zeros(size1, sparsity1) +
        sparse_utils.get_n_zeros(size2, sparsity2))
    # Due to rounding we can have some difference. This is expected but should
    # be less than number of rounding operations we make.
    diff = abs(expected_zeros_uniform - expected_zeros_current)
    tolerance = 2
    self.assertLessEqual(diff, tolerance)
    # Ensure that ErdosRenyi proportions are preserved.
    factor1 = (shape1[-1] + shape1[-2]) / float(shape1[-1] * shape1[-2])
    factor2 = (shape2[-1] + shape2[-2]) / float(shape2[-1] * shape2[-2])
    self.assertAlmostEqual((1 - sparsity1) / factor1,
                           (1 - sparsity2) / factor2)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: rigl/str_sparsities.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reads ResNet-50 sparsity distributions found by STR.
[STR]: https://arxiv.org/abs/2002.03231
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
# Table copied from the STR paper's reported ResNet-50 results.
# Row format: '<label> - <layer_name> <n_params> <n_flops> <s_1> ... <s_16>',
# where each s_i column is the layer's sparsity (percent) for one trained
# model. The 'Overall' row (which must come first — see read_all()) holds each
# model's overall sparsity, used as the lookup key.
REPORTED_SPARSITIES = """
Overall - Overall 25502912 4089284608 79.55 81.27 87.70 90.23 90.55 94.80 95.03 95.15 96.11 96.53 97.78 98.05 98.22 98.79 98.98 99.10
Layer 1 - conv1 9408 118013952 51.46 51.40 63.02 59.80 59.83 64.87 67.36 66.96 72.11 69.46 73.29 73.47 72.05 75.12 76.12 77.75
Layer 2 - layer1.0.conv1 4096 12845056 69.36 73.24 87.57 83.28 85.18 89.60 91.41 91.11 92.38 91.75 94.46 94.51 94.60 95.95 96.53 96.51
Layer 3 - layer1.0.conv2 36864 115605504 77.85 76.26 90.87 89.48 87.31 94.79 94.27 95.04 95.69 96.07 97.36 97.77 98.35 98.51 98.59 98.84
Layer 4 - layer1.0.conv3 16384 51380224 74.81 74.65 86.52 85.80 85.25 91.85 92.78 93.67 94.13 94.69 96.61 97.03 97.37 98.04 98.21 98.47
Layer 5 - layer1.0.downsample.0 16384 51380224 70.95 72.96 83.53 83.34 82.56 89.13 90.62 90.17 91.83 92.69 95.48 94.89 95.68 96.98 97.56 97.72
Layer 6 - layer1.1.conv1 16384 51380224 80.27 79.58 89.82 89.89 88.51 94.56 96.64 95.78 95.81 96.81 98.79 98.90 98.98 99.13 99.62 99.47
Layer 7 - layer1.1.conv2 36864 115605504 81.36 80.95 91.75 90.60 89.61 94.70 95.78 96.18 96.42 97.26 98.65 99.07 99.40 99.11 99.31 99.56
Layer 8 - layer1.1.conv3 16384 51380224 84.45 80.11 91.22 91.70 90.21 95.17 97.05 95.81 96.34 97.23 98.68 98.76 98.90 99.16 99.57 99.46
Layer 9 - layer1.2.conv1 16384 51380224 78.23 79.79 90.12 88.07 89.36 94.62 95.94 94.74 96.23 96.75 97.96 98.41 98.72 99.38 99.35 99.46
Layer 10 - layer1.2.conv2 36864 115605504 76.01 81.53 91.06 87.03 88.27 93.90 95.63 94.26 96.24 96.11 97.54 98.27 98.44 99.32 99.19 99.39
Layer 11 - layer1.2.conv3 16384 51380224 84.47 83.28 94.95 90.99 92.64 95.76 96.95 96.01 96.87 97.31 98.38 98.60 98.72 99.38 99.27 99.51
Layer 12 - layer2.0.conv1 32768 102760448 73.74 73.96 86.78 85.95 85.90 92.32 94.79 93.86 94.62 95.64 97.19 98.22 98.52 98.48 98.84 98.92
Layer 13 - layer2.0.conv2 147456 115605504 82.56 85.70 91.31 93.91 94.03 97.54 97.43 97.65 98.38 98.62 99.24 99.23 99.40 99.61 99.67 99.63
Layer 14 - layer2.0.conv3 65536 51380224 84.70 83.55 93.04 93.13 92.13 96.61 97.37 97.21 97.59 98.14 98.80 98.95 99.18 99.29 99.47 99.43
Layer 15 - layer2.0.downsample.0 131072 102760448 85.10 87.66 92.78 94.96 95.13 98.07 97.97 98.15 98.70 98.88 99.37 99.35 99.40 99.69 99.68 99.71
Layer 16 - layer2.1.conv1 65536 51380224 85.42 85.79 94.04 95.31 94.94 97.92 98.53 98.21 98.84 99.06 99.46 99.53 99.72 99.78 99.81 99.80
Layer 17 - layer2.1.conv2 147456 115605504 76.95 82.75 87.63 91.50 91.76 95.59 97.22 96.07 97.32 97.80 98.24 98.24 98.60 99.24 99.66 99.33
Layer 18 - layer2.1.conv3 65536 51380224 84.76 84.71 93.10 93.66 93.23 97.00 98.18 97.35 98.06 98.41 98.96 99.21 99.32 99.55 99.58 99.59
Layer 19 - layer2.2.conv1 65536 51380224 84.30 85.34 92.70 94.61 94.76 97.72 97.91 98.21 98.54 98.98 99.24 99.35 99.50 99.62 99.63 99.77
Layer 20 - layer2.2.conv2 147456 115605504 84.28 85.43 92.99 94.86 94.90 97.52 97.21 98.11 98.19 99.04 99.28 99.37 99.46 99.63 99.59 99.72
Layer 21 - layer2.2.conv3 65536 51380224 82.19 84.21 91.12 93.38 93.53 96.89 97.14 97.59 97.77 98.66 98.96 99.15 99.25 99.49 99.51 99.57
Layer 22 - layer2.3.conv1 65536 51380224 83.37 84.41 90.46 93.26 93.50 96.71 97.89 96.99 98.14 98.36 99.10 99.23 99.33 99.53 99.75 99.60
Layer 23 - layer2.3.conv2 147456 115605504 82.83 84.03 91.44 93.21 93.25 96.83 98.02 96.96 98.45 98.30 98.97 99.06 99.26 99.31 99.81 99.68
Layer 24 - layer2.3.conv3 65536 51380224 82.93 85.65 91.02 94.14 93.56 97.20 97.97 97.04 98.16 98.36 98.88 98.97 99.20 99.32 99.67 99.62
Layer 25 - layer3.0.conv1 131072 102760448 76.63 77.98 85.99 88.85 88.60 94.26 95.07 94.97 96.21 96.59 97.75 98.04 98.30 98.72 99.11 99.06
Layer 26 - layer3.0.conv2 589824 115605504 87.35 88.68 94.39 96.14 96.19 98.51 98.77 98.72 99.11 99.23 99.53 99.59 99.64 99.73 99.80 99.81
Layer 27 - layer3.0.conv3 262144 51380224 81.22 83.22 90.58 93.19 93.05 96.82 97.38 97.32 97.98 98.28 98.88 99.03 99.16 99.39 99.55 99.53
Layer 28 - layer3.0.downsample.0 524288 102760448 89.75 90.99 96.05 97.20 97.16 98.96 99.21 99.20 99.50 99.58 99.78 99.82 99.86 99.91 99.94 99.93
Layer 29 - layer3.1.conv1 262144 51380224 85.88 87.35 93.43 95.36 96.12 98.64 98.77 98.87 99.22 99.33 99.64 99.67 99.72 99.82 99.88 99.84
Layer 30 - layer3.1.conv2 589824 115605504 85.06 86.24 92.74 95.06 95.30 98.09 98.28 98.36 98.75 99.08 99.46 99.48 99.54 99.69 99.76 99.76
Layer 31 - layer3.1.conv3 262144 51380224 84.34 86.79 92.15 94.84 94.90 97.75 98.15 98.11 98.56 98.94 99.30 99.36 99.45 99.65 99.79 99.70
Layer 32 - layer3.2.conv1 262144 51380224 87.51 89.15 94.15 96.77 96.46 98.81 98.83 98.96 99.19 99.44 99.67 99.71 99.74 99.82 99.85 99.89
Layer 33 - layer3.2.conv2 589824 115605504 87.15 88.67 94.09 95.59 96.14 98.86 98.69 98.91 99.21 99.20 99.64 99.72 99.76 99.85 99.84 99.90
Layer 34 - layer3.2.conv3 262144 51380224 84.86 86.90 92.40 94.99 94.99 98.19 98.19 98.42 98.76 98.97 99.42 99.56 99.62 99.76 99.75 99.88
Layer 35 - layer3.3.conv1 262144 51380224 86.62 89.46 94.06 96.08 95.88 98.70 98.71 98.77 99.01 99.27 99.58 99.66 99.69 99.83 99.87 99.87
Layer 36 - layer3.3.conv2 589824 115605504 86.52 87.97 93.56 96.10 96.11 98.70 98.82 98.89 99.19 99.31 99.68 99.73 99.77 99.88 99.87 99.93
Layer 37 - layer3.3.conv3 262144 51380224 84.19 86.81 92.32 94.94 94.91 98.20 98.37 98.43 98.82 99.00 99.51 99.57 99.64 99.81 99.81 99.87
Layer 38 - layer3.4.conv1 262144 51380224 85.85 88.40 93.55 95.49 95.86 98.35 98.44 98.55 98.79 98.96 99.54 99.59 99.60 99.82 99.86 99.87
Layer 39 - layer3.4.conv2 589824 115605504 85.96 87.38 93.27 95.66 95.63 98.41 98.58 98.56 99.19 99.26 99.64 99.69 99.67 99.87 99.90 99.92
Layer 40 - layer3.4.conv3 262144 51380224 83.45 85.76 91.75 94.49 94.35 97.67 98.09 97.99 98.65 98.94 99.49 99.52 99.48 99.77 99.86 99.85
Layer 41 - layer3.5.conv1 262144 51380224 83.33 85.77 91.79 95.09 94.24 97.46 97.89 97.92 98.71 98.90 99.35 99.52 99.58 99.76 99.79 99.83
Layer 42 - layer3.5.conv2 589824 115605504 84.98 86.67 92.48 94.92 95.13 97.88 98.14 98.32 98.91 99.00 99.44 99.58 99.69 99.80 99.83 99.87
Layer 43 - layer3.5.conv3 262144 51380224 79.78 82.23 89.39 93.14 92.76 96.59 97.04 97.30 98.10 98.41 99.03 99.25 99.44 99.61 99.71 99.75
Layer 44 - layer4.0.conv1 524288 102760448 77.83 79.61 87.11 90.32 90.64 95.39 95.84 95.92 97.17 97.35 98.36 98.60 98.83 99.20 99.37 99.42
Layer 45 - layer4.0.conv2 2359296 115605504 86.18 88.00 93.53 95.66 95.78 98.31 98.47 98.55 99.08 99.16 99.54 99.63 99.69 99.81 99.85 99.86
Layer 46 - layer4.0.conv3 1048576 51380224 78.43 80.48 87.85 91.14 91.27 96.00 96.40 96.47 97.53 97.92 98.81 99.00 99.15 99.45 99.57 99.61
Layer 47 - layer4.0.downsample.0 2097152 102760448 88.49 89.98 95.03 96.79 96.90 98.91 99.06 99.11 99.45 99.51 99.77 99.82 99.85 99.92 99.94 99.94
Layer 48 - layer4.1.conv1 1048576 51380224 82.07 84.02 90.34 93.69 93.72 97.15 97.56 97.76 98.45 98.75 99.27 99.36 99.54 99.67 99.76 99.80
Layer 49 - layer4.1.conv2 2359296 115605504 83.42 85.23 91.16 93.98 93.93 97.26 97.58 97.71 98.36 98.67 99.25 99.34 99.50 99.68 99.76 99.80
Layer 50 - layer4.1.conv3 1048576 51380224 78.08 79.96 86.66 90.48 90.22 95.22 95.76 95.89 96.88 97.65 98.70 98.85 99.13 99.45 99.58 99.66
Layer 51 - layer4.2.conv1 1048576 51380224 76.34 77.93 84.98 87.57 88.47 93.90 93.87 94.16 95.55 95.91 97.66 97.97 98.15 98.88 99.08 99.22
Layer 52 - layer4.2.conv2 2359296 115605504 73.57 74.97 82.32 84.37 86.01 91.92 91.66 92.22 94.02 94.16 96.65 97.13 97.29 98.44 98.74 99.00
Layer 53 - layer4.2.conv3 1048576 51380224 68.78 70.38 78.11 80.29 81.73 89.64 89.43 89.65 91.40 92.65 96.02 96.72 96.93 98.47 98.83 99.15
Layer 54 - fc 2048000 2048000 50.65 52.46 60.48 64.50 65.12 75.20 75.73 75.80 78.57 80.69 85.96 87.26 88.03 91.11 92.15 92.87"""
def _name_map_str(k):
"""Maps the naming of the layers."""
if k == 'conv1':
new_key = 'initial_conv'
elif k == 'fc':
new_key = 'final_dense'
else:
if 'downsample' in k:
group_id = re.search(r'layer(\d)\.0\.downsample\.0', k).group(1)
new_key = 'bottleneck_projection_block_group_projection_block_group%s' % group_id
else:
res = re.search(r'layer(\d)\.(\d)\.conv(\d)', k)
group_id, block_id, layer_id = (int(res.group(1)), int(res.group(2)),
int(res.group(3)))
if block_id == 0:
new_key = 'bottleneck_%d_block_group_projection_block_group%d' % (
layer_id, group_id)
else:
new_key = 'bottleneck_%d_block_group%d_%d_1' % (layer_id, group_id,
block_id)
return 'resnet_model/%s/mask:0' % new_key
def read_all():
  """Reads and returns sparsity distributions.

  Parses the `REPORTED_SPARSITIES` table. The 'Overall' row supplies the
  overall sparsity for each column; every other row supplies the per-layer
  value for those same columns.

  Returns:
    defaultdict mapping overall sparsity (a fraction in [0, 1]) to a dict
    that maps mask variable names (see `_name_map_str`) to the layer's
    reported value, also rescaled to [0, 1].
  """
  str_sparsities_parsed = collections.defaultdict(dict)
  overall_sparsities = []
  for line in REPORTED_SPARSITIES.strip().split('\n'):
    fields = line.split('-')[1].strip().split(' ')
    if fields[0] == 'Overall':
      overall_sparsities = list(map(float, fields[3:]))
    else:
      # The mapped variable name is column-independent; compute it once per
      # row instead of once per column.
      new_key = _name_map_str(fields[0])
      for i, layer_value in enumerate(fields[3:]):
        # Reported numbers are percentages; divide by 100 to get fractions.
        s = overall_sparsities[i] / 100
        str_sparsities_parsed[s][new_key] = float(layer_value) / 100.
  return str_sparsities_parsed
================================================
FILE: run.sh
================================================
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
# Smoke test: builds a fresh Python 3 virtualenv, installs the project's
# dependencies and runs the core rigl unit tests.
# Abort on the first failing command.
set -e
# Echo each command before executing it, for easier CI log reading.
set -x
virtualenv -p python3 env
source env/bin/activate
pip install -r rigl/requirements.txt
python -m rigl.sparse_optimizers_test
python -m rigl.sparse_utils_test