Repository: google-research/rigl
Branch: master
Commit: d39fc7d46505
Files: 141
Total size: 902.2 KB
Directory structure:
gitextract_rhnta46k/
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── rigl/
│ ├── __init__.py
│ ├── cifar_resnet/
│ │ ├── data_helper.py
│ │ ├── data_helper_test.py
│ │ ├── resnet_model.py
│ │ └── resnet_train_eval.py
│ ├── experimental/
│ │ └── jax/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── analysis/
│ │ │ └── plot_summary_json.ipynb
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── cifar10.py
│ │ │ ├── cifar10_test.py
│ │ │ ├── dataset_base.py
│ │ │ ├── dataset_base_test.py
│ │ │ ├── dataset_factory.py
│ │ │ ├── dataset_factory_test.py
│ │ │ ├── mnist.py
│ │ │ └── mnist_test.py
│ │ ├── fixed_param.py
│ │ ├── fixed_param_test.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── cifar10_cnn.py
│ │ │ ├── cifar10_cnn_test.py
│ │ │ ├── mnist_cnn.py
│ │ │ ├── mnist_cnn_test.py
│ │ │ ├── mnist_fc.py
│ │ │ ├── mnist_fc_test.py
│ │ │ ├── model_factory.py
│ │ │ └── model_factory_test.py
│ │ ├── prune.py
│ │ ├── prune_test.py
│ │ ├── pruning/
│ │ │ ├── __init__.py
│ │ │ ├── init.py
│ │ │ ├── init_test.py
│ │ │ ├── mask_factory.py
│ │ │ ├── mask_factory_test.py
│ │ │ ├── masked.py
│ │ │ ├── masked_test.py
│ │ │ ├── pruning.py
│ │ │ ├── pruning_test.py
│ │ │ ├── symmetry.py
│ │ │ └── symmetry_test.py
│ │ ├── random_mask.py
│ │ ├── random_mask_test.py
│ │ ├── requirements.txt
│ │ ├── run.sh
│ │ ├── shuffled_mask.py
│ │ ├── shuffled_mask_test.py
│ │ ├── train.py
│ │ ├── train_test.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── training.py
│ │ │ └── training_test.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── imagenet_resnet/
│ │ ├── colabs/
│ │ │ ├── MobileNet_Counting.ipynb
│ │ │ └── Resnet_50_Param_Flops_Counting.ipynb
│ │ ├── imagenet_train_eval.py
│ │ ├── mobilenetv1_model.py
│ │ ├── mobilenetv2_model.py
│ │ ├── pruning_layers.py
│ │ ├── resnet_model.py
│ │ ├── train_test.py
│ │ ├── utils.py
│ │ └── vgg.py
│ ├── mnist/
│ │ ├── mnist_train_eval.py
│ │ └── visualize_mask_records.py
│ ├── requirements.txt
│ ├── rigl_tf2/
│ │ ├── README.md
│ │ ├── colabs/
│ │ │ └── MnistProp.ipynb
│ │ ├── configs/
│ │ │ ├── dense.gin
│ │ │ ├── grasp.gin
│ │ │ ├── hessian.gin
│ │ │ ├── interpolate.gin
│ │ │ ├── lottery.gin
│ │ │ ├── prune.gin
│ │ │ ├── rigl.gin
│ │ │ ├── scratch.gin
│ │ │ ├── set.gin
│ │ │ ├── small_dense.gin
│ │ │ └── snip.gin
│ │ ├── init_utils.py
│ │ ├── interpolate.py
│ │ ├── mask_updaters.py
│ │ ├── metainit.py
│ │ ├── mlp_configs/
│ │ │ ├── dense.gin
│ │ │ ├── lottery.gin
│ │ │ ├── prune.gin
│ │ │ ├── rigl.gin
│ │ │ ├── scratch.gin
│ │ │ ├── set.gin
│ │ │ └── small_dense.gin
│ │ ├── networks.py
│ │ ├── train.py
│ │ └── utils.py
│ ├── rl/
│ │ ├── README.md
│ │ ├── dqn_agents.py
│ │ ├── requirements.txt
│ │ ├── run.sh
│ │ ├── run_experiment.py
│ │ ├── sparse_utils.py
│ │ ├── sparsetrain_configs/
│ │ │ ├── dqn_atari_dense.gin
│ │ │ ├── dqn_atari_dense_impala_net.gin
│ │ │ ├── dqn_atari_prune.gin
│ │ │ ├── dqn_atari_prune_impala_net.gin
│ │ │ ├── dqn_atari_rigl.gin
│ │ │ ├── dqn_atari_rigl_impala_net.gin
│ │ │ ├── dqn_atari_set.gin
│ │ │ ├── dqn_atari_set_impala_net.gin
│ │ │ ├── dqn_atari_static.gin
│ │ │ └── dqn_atari_static_impala_net.gin
│ │ ├── tfagents/
│ │ │ ├── configs/
│ │ │ │ ├── dqn_gym_dense_config.gin
│ │ │ │ ├── dqn_gym_pruning_config.gin
│ │ │ │ ├── dqn_gym_sparse_config.gin
│ │ │ │ ├── ppo_mujoco_dense_config.gin
│ │ │ │ ├── ppo_mujoco_pruning_config.gin
│ │ │ │ ├── ppo_mujoco_sparse_config.gin
│ │ │ │ ├── sac_mujoco_dense_config.gin
│ │ │ │ ├── sac_mujoco_pruning_config.gin
│ │ │ │ └── sac_mujoco_sparse_config.gin
│ │ │ ├── dqn_train_eval.py
│ │ │ ├── ppo_train_eval.py
│ │ │ ├── sac_train_eval.py
│ │ │ ├── sparse_encoding_network.py
│ │ │ ├── sparse_ppo_actor_network.py
│ │ │ ├── sparse_ppo_discrete_actor_network.py
│ │ │ ├── sparse_ppo_discrete_actor_network_test.py
│ │ │ ├── sparse_tanh_normal_projection_network.py
│ │ │ ├── sparse_value_network.py
│ │ │ └── tf_sparse_utils.py
│ │ └── train.py
│ ├── sparse_optimizers.py
│ ├── sparse_optimizers_base.py
│ ├── sparse_optimizers_test.py
│ ├── sparse_utils.py
│ ├── sparse_utils_test.py
│ └── str_sparsities.py
└── run.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute
We'd love to accept your patches and contributions to this project.
- If you want to contribute to the library, please check the `Issues` tab and
feel free to take on any problem/issue you find interesting.
- If your issue is not reported yet, please create a new one. It is
important to discuss the problem/request before implementing a solution.
- Reach us at rigl.authors@gmail.com any time!
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Rigging the Lottery: Making All Tickets Winners
<img src="https://github.com/google-research/rigl/blob/master/imgs/flops8.jpg" alt="80% Sparse Resnet-50" width="45%" align="middle">
**Paper**: [https://arxiv.org/abs/1911.11134](https://arxiv.org/abs/1911.11134)
**15min Presentation**: [[pml4dc](https://pml4dc.github.io/iclr2020/program/pml4dc_7.html)] [[icml](https://icml.cc/virtual/2020/paper/5808)]
**ML Reproducibility Challenge 2020** [report](https://openreview.net/forum?id=riCIeP6LzEE)
## Colabs for Calculating FLOPs of Sparse Models
[MobileNet-v1](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb)
[ResNet-50](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb)
## Best Sparse Models
Parameters are stored as float32, so each parameter takes 4 bytes. The uniform
sparsity distribution keeps the first layer dense and therefore yields slightly
larger model sizes and parameter counts. ERK applies to all layers except in
the 99% sparse model, where we set the first layer to be dense, since otherwise
we observe much worse performance.
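The ERK (Erdős–Rényi-Kernel) distribution scales each layer's density proportionally to `(n_in + n_out + k_h + k_w) / (n_in * n_out * k_h * k_w)`, with a global factor chosen so that the overall density hits the target. A minimal NumPy sketch of this idea (not the repository's implementation; `erdos_renyi_kernel_densities` is a name chosen here, and the sketch omits the refinement where layers whose density would exceed 1 are fixed dense and the factor is re-solved):

```python
import numpy as np

def erdos_renyi_kernel_densities(kernel_shapes, target_density):
    """Per-layer densities under a simplified ERK scaling.

    kernel_shapes: list of (k_h, k_w, n_in, n_out) conv-kernel shapes.
    target_density: desired fraction of parameters kept over all layers.
    """
    params = np.array([kh * kw * c_in * c_out
                       for kh, kw, c_in, c_out in kernel_shapes], dtype=float)
    # Raw ERK scale for each layer: (n_in + n_out + k_h + k_w) / n_params.
    raw = np.array([(c_in + c_out + kh + kw) / n
                    for (kh, kw, c_in, c_out), n in zip(kernel_shapes, params)])
    # Solve for eps so that the total kept parameters match the budget.
    eps = target_density * params.sum() / (raw * params).sum()
    return eps * raw

shapes = [(3, 3, 16, 32), (3, 3, 32, 64)]
densities = erdos_renyi_kernel_densities(shapes, target_density=0.1)
```

Note how the smaller layer receives a higher density than the larger one, which is the qualitative behavior ERK is designed for.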
### Extended Training Results
Performance of RigL improves significantly with extended training. In this
section we extend the training of sparse models by 5x. Note that sparse models
require far fewer FLOPs per training iteration, so most of these extended
trainings still cost fewer FLOPs than the baseline dense training.
Observing this continued improvement, we wanted to understand where the
performance of sparse networks saturates. The longest training we ran was 100x
the length of the original 100-epoch ImageNet training. It costs 5.8x the FLOPs
of the original dense training, and the resulting 99% sparse Resnet-50 achieves
an impressive 68.15% test accuracy (vs. 61.86% after 5x training).
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| - (DENSE) | 0 | 3.2e18 | 8.2e9 | 102.122 | 76.8 | - |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 77.17 | [link](https://storage.googleapis.com/gresearch/rigl/s80erk5x.tar.gz) |
| Uniform | 0.8 | 1.14x | 0.23x | 23.685 | 76.71 | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform5x.tar.gz) |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 76.42 | [link](https://storage.googleapis.com/gresearch/rigl/s90erk5x.tar.gz) |
| Uniform | 0.9 | 0.66x | 0.13x | 13.532 | 75.73 | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform5x.tar.gz) |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.63 | [link](https://storage.googleapis.com/gresearch/rigl/s95erk5x.tar.gz) |
| Uniform | 0.95 | 0.42x | 0.08x | 8.433 | 73.22 | [link](https://storage.googleapis.com/gresearch/rigl/s95uniform5x.tar.gz) |
| ERK | 0.965 | 0.45x | 0.09x | 6.904 | 72.77 | [link](https://storage.googleapis.com/gresearch/rigl/s965erk5x.tar.gz) |
| Uniform | 0.965 | 0.34x | 0.07x | 6.904 | 71.31 | [link](https://storage.googleapis.com/gresearch/rigl/s965uniform5x.tar.gz) |
| ERK | 0.99 | 0.29x | 0.05x | 4.354 | 61.86 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk5x.tar.gz) |
| ERK | 0.99 | 0.58x | 0.05x | 4.354 | 63.89 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk10x.tar.gz) |
| ERK | 0.99 | 2.32x | 0.05x | 4.354 | 66.94 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk40x.tar.gz) |
| ERK | **0.99** | 5.8x | 0.05x | 4.354 | **68.15** | [link](https://storage.googleapis.com/gresearch/rigl/s99erk100x.tar.gz) |
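The training-FLOPs multipliers for the 99% sparse ERK rows are mutually consistent: 5x training costs 0.29x dense FLOPs, i.e. roughly 0.058x per 1x of training length, which lines up with the 10x, 40x, and 100x entries. A quick arithmetic check:

```python
# Consistency check on the 99% sparse ERK training-FLOPs multipliers.
flops_5x = 0.29            # 5x training, relative to 1x dense training
per_1x = flops_5x / 5      # approximate cost of one 1x-length sparse run

for length, expected in [(10, 0.58), (40, 2.32), (100, 5.8)]:
    assert abs(per_1x * length - expected) < 0.01
```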
We also ran extended training runs with MobileNet-v1. Even when training 100x
longer, we were not able to saturate the performance: training longer
consistently achieved better results.
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| - (DENSE) | 0 | 4.5e17 | 1.14e9 | 16.864 | 72.1 | - |
| ERK | 0.89 | 1.39x | 0.21x | 2.392 | 69.31 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk10x.tar.gz) |
| ERK | 0.89 | 2.79x | 0.21x | 2.392 | 70.63 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk50x.tar.gz) |
| Uniform | 0.89 | 1.25x | 0.09x | 2.392 | 69.28 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform10x.tar.gz) |
| Uniform | 0.89 | 6.25x | 0.09x | 2.392 | 70.25 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform50x.tar.gz) |
| Uniform | 0.89 | 12.5x | 0.09x | 2.392 | 70.59 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform100x.tar.gz) |
### 1x Training Results
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.12 | [link](https://storage.googleapis.com/gresearch/rigl/s80erk1x.tar.gz) |
| Uniform | 0.8 | 0.23x | 0.23x | 23.685 | 74.60 | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform1x.tar.gz) |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.07 | [link](https://storage.googleapis.com/gresearch/rigl/s90erk1x.tar.gz) |
| Uniform | 0.9 | 0.13x | 0.13x | 13.532 | 72.02 | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform1x.tar.gz) |
### Results w/o label smoothing
| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.02 | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 76.17 | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_5x.tar.gz) |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.4 | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 75.9 | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_5x.tar.gz) |
| ERK | 0.95 | 0.13x | 0.12x | 8.399 | 70.39 | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.36 | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_5x.tar.gz) |
### Evaluating checkpoints
Download the checkpoints and run the evaluation on ERK checkpoints with the
following:
```bash
python imagenet_train_eval.py --mode=eval_once --output_dir=path/to/ckpt/folder \
--eval_once_ckpt_prefix=model.ckpt-3200000 --use_folder_stub=False \
--training_method=rigl --mask_init_method=erdos_renyi_kernel \
--first_layer_sparsity=-1
```
When evaluating checkpoints with the uniform sparsity distribution, use
`--mask_init_method=random` and `--first_layer_sparsity=0`. Set
`--model_architecture=mobilenet_v1` when evaluating MobileNet checkpoints.
## Sparse Training Algorithms
In this repository we implement the following dynamic sparsity strategies:
1. [SET](https://www.nature.com/articles/s41467-018-04316-3): Implements Sparse
Evolutionary Training (SET), which replaces low-magnitude connections randomly
with new ones.
2. [SNFS](https://arxiv.org/abs/1907.04840): Implements momentum-based training
*without* sparsity re-distribution.
3. [RigL](https://arxiv.org/abs/1911.11134): Our method, RigL, removes a
fraction of connections based on weight magnitudes and activates new ones
using instantaneous gradient information.
And the following one-shot pruning algorithm:
1. [SNIP](https://arxiv.org/abs/1810.02340): Single-shot Network Pruning based
on connection sensitivity prunes the least salient connections before training.
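The drop/grow step that RigL performs on a weight tensor can be sketched in a few lines of NumPy. This is a simplified illustration, not the repository's implementation (`rigl_update` is a name chosen here, and growth is restricted to positions that were inactive before the drop); new connections are initialized to zero, as in the paper:

```python
import numpy as np

def rigl_update(weights, mask, grads, drop_fraction=0.3):
    """One RigL connectivity update on a single weight tensor (sketch)."""
    n_update = int(drop_fraction * mask.sum())
    # Drop: among active weights, remove the n_update smallest magnitudes.
    drop_scores = np.where(mask == 1, np.abs(weights), np.inf)
    drop_idx = np.argsort(drop_scores, axis=None)[:n_update]
    new_mask = mask.ravel().copy()
    new_mask[drop_idx] = 0
    # Grow: among previously inactive positions, enable the n_update
    # connections with the largest gradient magnitude.
    grow_scores = np.where(mask == 0, np.abs(grads), -np.inf)
    grow_idx = np.argsort(grow_scores, axis=None)[-n_update:]
    new_mask[grow_idx] = 1
    new_weights = weights.ravel().copy()
    new_weights[grow_idx] = 0.0  # new connections start at zero
    return new_weights.reshape(weights.shape), new_mask.reshape(mask.shape)

w = np.array([[1.0, 0.0, 3.0], [0.0, 0.5, 2.0]])
m = np.array([[1, 0, 1], [0, 1, 1]])
g = np.array([[0.0, 10.0, 0.0], [5.0, 0.0, 0.0]])
new_w, new_m = rigl_update(w, m, g, drop_fraction=0.25)
```

In this toy example the smallest active weight (0.5) is dropped and the inactive position with the largest gradient magnitude (10.0) is grown, keeping the number of active connections constant.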
We have code for the following settings:
- [Imagenet2012](https://github.com/google-research/rigl/tree/master/rigl/imagenet_resnet):
TPU compatible code with Resnet-50 and MobileNet-v1/v2.
- [CIFAR-10](https://github.com/google-research/rigl/tree/master/rigl/cifar_resnet)
with WideResNets.
- [MNIST](https://github.com/google-research/rigl/tree/master/rigl/mnist) with
2 layer fully connected network.
## Setup
First clone this repo.
```bash
git clone https://github.com/google-research/rigl.git
cd rigl
```
We use the [Neurips 2019 MicroNet Challenge](https://micronet-challenge.github.io/)
code for counting the operations and size of our networks. Let's clone the
google_research repo and add the current folder to the Python path.
```bash
git clone https://github.com/google-research/google-research.git
mv google-research/ google_research/
export PYTHONPATH=$PYTHONPATH:$PWD
```
Now we can run some tests. The following script creates a virtual environment,
installs the necessary libraries, and runs a few tests.
```bash
bash run.sh
```
We need to activate the virtual environment before running an experiment. With
that, we are ready to run some trivial MNIST experiments.
```bash
source env/bin/activate
python rigl/mnist/mnist_train_eval.py
```
You can load and verify the performance of the Resnet-50 checkpoints as
follows.
```bash
python rigl/imagenet_resnet/imagenet_train_eval.py --mode=eval_once --training_method=baseline --eval_batch_size=100 --output_dir=/path/to/folder --eval_once_ckpt_prefix=s80_model.ckpt-1280000 --use_folder_stub=False
```
We use the [Official TPU Code](https://github.com/tensorflow/tpu/tree/master/models/official/resnet)
for loading ImageNet data. First clone the
tensorflow/tpu repo and then add models/ folder to the python path.
```bash
git clone https://github.com/tensorflow/tpu.git
export PYTHONPATH=$PYTHONPATH:$PWD/tpu/models/
```
## Other Implementations
- [Graphcore-TF-MNIST](https://github.com/graphcore/examples/tree/master/applications/tensorflow/dynamic_sparsity/mnist_rigl): with sparse matrix ops!
- [Pytorch implementation](https://github.com/McCrearyD/rigl-torch) by Dyllan McCreary.
- [Micrograd-Pure Python](https://evcu.github.io/ml/sparse-micrograd/): a toy
example with a pure-Python sparse implementation. Caution: very slow, but fun.
## Citation
```
@incollection{rigl,
author = {Evci, Utku and Gale, Trevor and Menick, Jacob and Castro, Pablo Samuel and Elsen, Erich},
booktitle = {Proceedings of Machine Learning and Systems 2020},
pages = {471--481},
title = {Rigging the Lottery: Making All Tickets Winners},
year = {2020}
}
```
## Disclaimer
This is not an official Google product.
================================================
FILE: rigl/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This repo involves the code for training sparse neural networks."""
name = 'rigl'
================================================
FILE: rigl/cifar_resnet/data_helper.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for CIFAR10 data input pipeline.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
IMG_SIZE = 32
def pad_input(x, crop_dim=4):
"""Concatenates sides of image with pixels cropped from the border of image.
Args:
x: Input image float32 tensor.
crop_dim: Number of pixels to crop from the edge of the image.
Cropped pixels are then concatenated to the original image.
Returns:
x: input image float32 tensor. Transformed by padding edges with cropped
pixels.
"""
x = tf.concat(
[x[:crop_dim, :, :][::-1], x, x[-crop_dim:, :, :][::-1]], axis=0)
x = tf.concat(
[x[:, :crop_dim, :][:, ::-1], x, x[:, -crop_dim:, :][:, ::-1]], axis=1)
return x
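The concatenation above implements symmetric (mirror) padding. A NumPy sketch of the same logic (`pad_input_np` is a hypothetical stand-in written only to illustrate the behavior) agrees with `np.pad(..., mode='symmetric')`:

```python
import numpy as np

def pad_input_np(x, crop_dim=4):
    # Same concatenation logic as pad_input above, in NumPy.
    x = np.concatenate([x[:crop_dim][::-1], x, x[-crop_dim:][::-1]], axis=0)
    x = np.concatenate(
        [x[:, :crop_dim][:, ::-1], x, x[:, -crop_dim:][:, ::-1]], axis=1)
    return x

img = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)
padded = pad_input_np(img, crop_dim=4)
# Symmetric padding of 4 pixels on each spatial border: (32+8, 32+8, 3).
assert padded.shape == (40, 40, 3)
assert np.array_equal(
    padded, np.pad(img, ((4, 4), (4, 4), (0, 0)), mode='symmetric'))
```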
def preprocess_train(x, width, height):
"""Pre-processing applied to training data set.
Args:
x: Input image float32 tensor.
width: int specifying intended width in pixels of image after preprocessing.
height: int specifying intended height in pixels of image after
preprocessing.
Returns:
x: transformed input with random crops, flips and reflection.
"""
x = pad_input(x, crop_dim=4)
x = tf.random_crop(x, [width, height, 3])
x = tf.image.random_flip_left_right(x)
return x
def input_fn(params):
"""Provides batches of CIFAR data.
Args:
params: A dictionary with a set of arguments, namely:
* batch_size (int32), specifies the number of data points in a batch
* data_split (string), designates train or eval
* data_directory (string), specifies the directory location of the input dataset
Returns:
images: A float32 `Tensor` of size [batch_size, 32, 32, 3].
labels: An int32 `Tensor` of size [batch_size].
"""
def parse_serialized_example(record):
"""Parses a CIFAR10 example."""
image = record['image']
label = tf.cast(record['label'], tf.int32)
image = tf.cast(image, tf.float32)
image = tf.image.per_image_standardization(image)
if data_split == 'train':
image = preprocess_train(image, IMG_SIZE, IMG_SIZE)
return image, label
data_split = params['data_split']
batch_size = params['batch_size']
if data_split == 'eval':
data_split = 'test'
dataset = tfds.load('cifar10:3.*.*', split=data_split)
# we only repeat an example and shuffle inputs during training
if data_split == 'train':
dataset = dataset.repeat().shuffle(buffer_size=50000)
# deserialize record into tensors and apply pre-processing.
dataset = dataset.map(parse_serialized_example).prefetch(batch_size)
# at test time, for the final batch we drop remaining examples so that no
# example is seen twice.
dataset = dataset.batch(batch_size)
images_batch, labels_batch = tf.data.make_one_shot_iterator(
dataset).get_next()
return (tf.reshape(images_batch, [batch_size, IMG_SIZE, IMG_SIZE, 3]),
tf.reshape(labels_batch, [batch_size]))
================================================
FILE: rigl/cifar_resnet/data_helper_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tests for the data_helper input pipeline and the training process.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from absl import logging
import absl.testing.parameterized as parameterized
from rigl.cifar_resnet import resnet_train_eval
from rigl.cifar_resnet.data_helper import input_fn
import tensorflow.compat.v1 as tf
from tensorflow.contrib.model_pruning.python import pruning
FLAGS = flags.FLAGS
BATCH_SIZE = 1
NUM_IMAGES = 1
JITTER_MULTIPLIER = 2
class DataHelperTest(tf.test.TestCase, parameterized.TestCase):
def get_next(self):
data_directory = FLAGS.data_directory
# we pass the updated eval and train string to the params dictionary.
params = {
'mode': 'test',
'data_split': 'eval',
'batch_size': BATCH_SIZE,
'data_directory': data_directory
}
test_inputs, test_labels = input_fn(params)
return test_inputs, test_labels
def testInputPipeline(self):
tf.reset_default_graph()
g = tf.Graph()
with g.as_default():
test_inputs, test_labels = self.get_next()
with self.test_session() as sess:
test_images_out, test_labels_out = sess.run([test_inputs, test_labels])
self.assertAllEqual(test_images_out.shape, [BATCH_SIZE, 32, 32, 3])
self.assertAllEqual(test_labels_out.shape, [BATCH_SIZE])
@parameterized.parameters(
{
'training_method': 'baseline',
},
{
'training_method': 'threshold',
},
{
'training_method': 'rigl',
},
)
def testTrainingStep(self, training_method):
tf.reset_default_graph()
g = tf.Graph()
with g.as_default():
images, labels = self.get_next()
global_step, _, _, logits = resnet_train_eval.build_model(
mode='train',
images=images,
labels=labels,
training_method=training_method,
num_classes=FLAGS.num_classes,
depth=FLAGS.resnet_depth,
width=FLAGS.resnet_width)
tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
learning_rate = 0.1
opt = tf.train.MomentumOptimizer(
learning_rate, momentum=FLAGS.momentum, use_nesterov=True)
if training_method in ['threshold']:
# Create a pruning object using the pruning hyperparameters
pruning_obj = pruning.Pruning()
logging.info('starting mask update op')
mask_update_op = pruning_obj.conditional_mask_update_op()
# Create the training op
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = opt.minimize(total_loss, global_step)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
# test that we can train successfully for 1 step
sess.run(init_op)
for _ in range(1):
sess.run(train_op)
if training_method in ['threshold']:
sess.run(mask_update_op)
if __name__ == '__main__':
tf.test.main()
================================================
FILE: rigl/cifar_resnet/resnet_model.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Implementation of the Wide ResNet model.
Adds a masking layer when a pruning method is selected.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from rigl.imagenet_resnet.pruning_layers import sparse_conv2d
from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected
import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
_BN_EPS = 1e-5
_BN_MOMENTUM = 0.9
class WideResNetModel(object):
"""Implements WideResNet model."""
def __init__(self,
is_training,
regularizer=None,
data_format='channels_last',
pruning_method='baseline',
droprate=0.3,
prune_first_layer=True,
prune_last_layer=True):
"""WideResnet as described in https://arxiv.org/pdf/1605.07146.pdf.
Args:
is_training: Boolean, True during model training,
False for evaluation/inference.
regularizer: A regularization function (mapping variables to
regularization losses), or None.
data_format: A string that indicates whether the channels are the second
or last index in the matrix. 'channels_first' or 'channels_last'.
pruning_method: str, 'threshold' or 'baseline'.
droprate: float, dropout rate applied to activations.
prune_first_layer: bool, if True first layer is pruned.
prune_last_layer: bool, if True last layer is pruned.
"""
self._training = is_training
self._regularizer = regularizer
self._data_format = data_format
self._pruning_method = pruning_method
self._droprate = droprate
self._prune_first_layer = prune_first_layer
self._prune_last_layer = prune_last_layer
if data_format == 'channels_last':
self._channel_axis = -1
elif data_format == 'channels_first':
self._channel_axis = 1
def build(self, inputs, depth, width, num_classes, name=None):
"""Model architecture to train the model.
The configuration of the resnet blocks requires depth to be 6n+4, where
n is the desired number of resnet blocks.
Args:
inputs: A 4D float tensor containing the model inputs.
depth: Number of convolutional layers in the network.
width: Multiplier for the number of convolutional filters in the
residual blocks.
num_classes: Positive integer number of possible classes.
name: Optional string, the name of the resulting op in the TF graph.
Returns:
A 2D float logits tensor of shape (batch_size, num_classes).
Raises:
ValueError: if depth is not the minimum amount required to build the
model.
"""
if (depth - 4) % 6 != 0:
raise ValueError('Depth must be 6n+4 for some positive integer n.')
resnet_blocks = (depth - 4) // 6
with tf.variable_scope(name, 'resnet_model'):
first_layer_technique = self._pruning_method
if not self._prune_first_layer:
first_layer_technique = 'baseline'
net = self._conv(
inputs,
'conv_1',
output_size=16,
sparsity_technique=first_layer_technique)
net = self._residual_block(
net, 'conv_2', 16 * width, subsample=False, blocks=resnet_blocks)
net = self._residual_block(
net, 'conv_3', 32 * width, subsample=True, blocks=resnet_blocks)
net = self._residual_block(
net, 'conv_4', 64 * width, subsample=True, blocks=resnet_blocks)
# Put the final BN, relu before the max pooling.
with tf.name_scope('Pooling'):
net = self._batch_norm(net)
net = tf.nn.relu(net)
net = tf.layers.average_pooling2d(
net, pool_size=8, strides=1, data_format=self._data_format)
net = contrib_layers.flatten(net)
last_layer_technique = self._pruning_method
if not self._prune_last_layer:
last_layer_technique = 'baseline'
net = self._dense(
net, num_classes, 'logits', sparsity_technique=last_layer_technique)
return net
def _batch_norm(self, net, name=None):
"""Adds batchnorm to the model.
Fused batch norm is disabled because computing input gradients through it
causes a recursive loop of tf.gradients calls. If a regularizer is
specified, fused batch norm must be set to False (the default setting).
Args:
net: Pre-batch norm tensor activations.
name: Specified name for batch normalization layer.
Returns:
batch norm layer: Activations from the batch normalization layer.
"""
return tf.layers.batch_normalization(
inputs=net,
fused=False,
training=self._training,
axis=self._channel_axis,
momentum=_BN_MOMENTUM,
epsilon=_BN_EPS,
name=name)
def _dense(self, net, num_units, name=None, sparsity_technique='baseline'):
return sparse_fully_connected(
x=net,
units=num_units,
sparsity_technique=sparsity_technique,
kernel_regularizer=self._regularizer,
name=name)
def _conv(self,
net,
name,
output_size,
strides=(1, 1),
padding='SAME',
sparsity_technique='baseline'):
"""Returns a conv layer."""
return sparse_conv2d(
x=net,
units=output_size,
activation=None,
kernel_size=[3, 3],
use_bias=False,
kernel_initializer=None,
kernel_regularizer=self._regularizer,
bias_initializer=None,
biases_regularizer=None,
sparsity_technique=sparsity_technique,
normalizer_fn=None,
strides=strides,
padding=padding,
data_format=self._data_format,
name=name)
def _residual_block(self, net, name, output_size, subsample, blocks):
"""Adds a residual block to the model."""
with tf.name_scope(name):
for n in range(blocks):
with tf.name_scope('res_%d' % n):
# When subsample is True, the first block uses a larger stride.
if subsample and n == 0:
strides = [2, 2]
else:
strides = [1, 1]
# Create the skip connection
skip = net
end_point = 'skip_%s' % name
net = self._batch_norm(net)
net = tf.nn.relu(net)
if net.get_shape()[3].value != output_size:
skip = sparse_conv2d(
x=net,
units=output_size,
activation=None,
kernel_size=[1, 1],
use_bias=False,
kernel_initializer=None,
kernel_regularizer=self._regularizer,
bias_initializer=None,
biases_regularizer=None,
sparsity_technique=self._pruning_method,
normalizer_fn=None,
strides=strides,
padding='VALID',
data_format=self._data_format,
name=end_point)
# Create residual
net = self._conv(
net,
'%s_%d_1' % (name, n),
output_size,
strides,
sparsity_technique=self._pruning_method)
net = self._batch_norm(net)
net = tf.nn.relu(net)
net = tf.keras.layers.Dropout(self._droprate)(net, training=self._training)
net = self._conv(
net,
'%s_%d_2' % (name, n),
output_size,
sparsity_technique=self._pruning_method)
# Combine the residual and the skip connection
net += skip
return net
================================================
FILE: rigl/cifar_resnet/resnet_train_eval.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""This script trains a ResNet model that implements various pruning methods.
Implements pruning methods during training:
Specify the pruning method via FLAGS.training_method.
- To train a model with no pruning, set FLAGS.training_method='baseline'.
Specify the desired end sparsity via FLAGS.end_sparsity.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from rigl import sparse_optimizers
from rigl import sparse_utils
from rigl.cifar_resnet.data_helper import input_fn
from rigl.cifar_resnet.resnet_model import WideResNetModel
from rigl.imagenet_resnet import utils
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import training as contrib_training
from tensorflow.contrib.model_pruning.python import pruning
flags.DEFINE_string('master', 'local',
'BNS name of the TensorFlow runtime to use.')
flags.DEFINE_integer('ps_task', 0,
'Task id of the replica running the training.')
flags.DEFINE_integer('keep_checkpoint_max', 5,
'Number of checkpoints to save, set 0 for all.')
flags.DEFINE_string('pruning_hparams', '',
'Comma separated list of pruning-related hyperparameters')
flags.DEFINE_string('train_dir', '/tmp/cifar10/',
'Directory where to write event logs and checkpoint.')
flags.DEFINE_string(
'load_mask_dir', '',
'Directory of a trained model from which to load only the mask')
flags.DEFINE_string(
'initial_value_checkpoint', '',
'Directory of a model from which to load only the parameters')
flags.DEFINE_integer(
'seed', default=0, help=('Sets the random seed.'))
flags.DEFINE_float('momentum', 0.9, 'The momentum value.')
# 250 Epochs
flags.DEFINE_integer('max_steps', 97656, 'Number of steps to run.')
flags.DEFINE_float('l2', 5e-4, 'Scale factor for L2 weight decay.')
flags.DEFINE_integer('resnet_depth', 16, 'Number of core convolutional layers'
'in the network.')
flags.DEFINE_integer('resnet_width', 4, 'Width of the residual blocks.')
flags.DEFINE_string(
'data_directory', '', 'data directory where cifar10 records are stored')
flags.DEFINE_integer('num_classes', 10, 'Number of classes.')
flags.DEFINE_integer('dataset_size', 50000, 'Size of training dataset.')
flags.DEFINE_integer('batch_size', 128, 'Batch size.')
flags.DEFINE_integer('checkpoint_steps', 5000, 'Specifies step interval for'
'saving model checkpoints.')
flags.DEFINE_integer(
'summaries_steps', 300, 'Specifies interval in steps for'
'saving model summaries.')
flags.DEFINE_bool('per_class_metrics', True, 'Whether to add per-class'
'performance summaries.')
flags.DEFINE_enum('mode', 'train', ('train_and_eval', 'train', 'eval'),
'Whether to run training, evaluation, or both.')
# pruning flags
flags.DEFINE_integer('sparsity_begin_step', 20000, 'Step to begin pruning at.')
flags.DEFINE_integer('sparsity_end_step', 75000, 'Step to end pruning at.')
flags.DEFINE_integer('pruning_frequency', 1000,
'Step interval between pruning steps.')
flags.DEFINE_float('end_sparsity', 0.9,
'Target sparsity desired by end of training.')
flags.DEFINE_enum(
'training_method', 'baseline',
('scratch', 'set', 'baseline', 'momentum', 'rigl', 'static', 'snip',
'prune'),
'Method used for training sparse network. `scratch` means initial mask is '
'kept during training. `set` is for sparse evolutionary training and '
'`baseline` is for dense baseline.')
flags.DEFINE_bool('prune_first_layer', False,
'Whether or not to apply sparsification to the first layer')
flags.DEFINE_bool('prune_last_layer', True,
'Whether or not to apply sparsification to the last layer')
flags.DEFINE_float('drop_fraction', 0.3,
'When changing the mask dynamically, the fraction of '
'connections dropped (and regrown) at each mask update.')
flags.DEFINE_string('drop_fraction_anneal', 'constant',
'If not empty the drop fraction is annealed during sparse'
' training. One of the following: `constant`, `cosine` or '
'`exponential_(\\d*\\.?\\d*)$`. For example: '
'`exponential_3`, `exponential_.3`, `exponential_0.3`. '
'The number after `exponential` defines the exponent.')
flags.DEFINE_string('grow_init', 'zeros',
'Passed to the SparseInitializer, one of: zeros, '
'initial_value, random_normal, random_uniform.')
flags.DEFINE_float('s_momentum', 0.9,
'Momentum values for exponential moving average of '
'gradients. Used when training_method="momentum".')
flags.DEFINE_float('rigl_acc_scale', 0.,
'Used to scale initial accumulated gradients for new '
'connections.')
flags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.')
flags.DEFINE_integer('maskupdate_end_step', 75000, 'Step to end mask updates.')
flags.DEFINE_integer('maskupdate_frequency', 100,
'Step interval between mask updates.')
flags.DEFINE_string(
'mask_init_method',
default='random',
help='If not empty string and mask is not loaded from a checkpoint, '
'indicates the method used for mask initialization. One of the following: '
'`random`, `erdos_renyi`.')
flags.DEFINE_float('training_steps_multiplier', 1.0,
'Training schedule is shortened or extended with the '
'multiplier, if it is not 1.')
FLAGS = flags.FLAGS
PARAM_SUFFIXES = ('gamma', 'beta', 'weights', 'biases')
MASK_SUFFIX = 'mask'
CLASSES = [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
'ship', 'truck'
]
def create_eval_metrics(labels, logits):
"""Creates the evaluation metrics for the model."""
eval_metrics = {}
label_keys = CLASSES
predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
labels=labels, predictions=predictions)
if FLAGS.per_class_metrics:
with tf.name_scope('class_level_summaries') as scope:
for i in range(len(label_keys)):
labels = tf.cast(labels, tf.int64)
name = scope + '/' + label_keys[i]
eval_metrics[('class_level_summaries/precision/' +
label_keys[i])] = tf.metrics.precision_at_k(
labels=labels,
predictions=logits,
class_id=i,
k=1,
name=name)
eval_metrics[('class_level_summaries/recall/' +
label_keys[i])] = tf.metrics.recall_at_k(
labels=labels,
predictions=logits,
class_id=i,
k=1,
name=name)
return eval_metrics
def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
top_5_accuracy):
"""Training script for resnet model.
Args:
training_method: specifies the method used to sparsify networks.
global_step: the current step of training/eval.
total_loss: tensor float32 of the cross entropy + regularization losses.
train_dir: string specifying the directory where summaries are saved.
accuracy: tensor float32 batch classification accuracy.
top_5_accuracy: tensor float32 batch classification accuracy (top_5 classes).
Returns:
hooks: summary tensors to be computed at each training step.
eval_metrics: set to None during training.
train_op: the optimization term.
"""
# Learning rate roughly drops at every 30k steps.
boundaries = [30000, 60000, 90000]
if FLAGS.training_steps_multiplier != 1.0:
multiplier = FLAGS.training_steps_multiplier
boundaries = [int(x * multiplier) for x in boundaries]
tf.logging.info(
'Learning Rate boundaries are updated with multiplier:%.2f', multiplier)
learning_rate = tf.train.piecewise_constant(
global_step,
boundaries,
values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],
name='lr_schedule')
optimizer = tf.train.MomentumOptimizer(
learning_rate, momentum=FLAGS.momentum, use_nesterov=True)
if training_method == 'set':
# We override the train op to also update the mask.
optimizer = sparse_optimizers.SparseSETOptimizer(
optimizer, begin_step=FLAGS.maskupdate_begin_step,
end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
drop_fraction_anneal=FLAGS.drop_fraction_anneal)
elif training_method == 'static':
# We override the train op to also update the mask.
optimizer = sparse_optimizers.SparseStaticOptimizer(
optimizer, begin_step=FLAGS.maskupdate_begin_step,
end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
drop_fraction_anneal=FLAGS.drop_fraction_anneal)
elif training_method == 'momentum':
# We override the train op to also update the mask.
optimizer = sparse_optimizers.SparseMomentumOptimizer(
optimizer, begin_step=FLAGS.maskupdate_begin_step,
end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
grow_init=FLAGS.grow_init,
drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
elif training_method == 'rigl':
# We override the train op to also update the mask.
optimizer = sparse_optimizers.SparseRigLOptimizer(
optimizer, begin_step=FLAGS.maskupdate_begin_step,
end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
frequency=FLAGS.maskupdate_frequency,
drop_fraction=FLAGS.drop_fraction,
drop_fraction_anneal=FLAGS.drop_fraction_anneal,
initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
elif training_method == 'snip':
optimizer = sparse_optimizers.SparseSnipOptimizer(
optimizer, mask_init_method=FLAGS.mask_init_method,
default_sparsity=FLAGS.end_sparsity, use_tpu=False)
elif training_method in ('scratch', 'baseline', 'prune'):
pass
else:
raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)
# Create the training op
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(total_loss, global_step)
if training_method == 'prune':
# construct the necessary hparams string from the FLAGS
hparams_string = ('begin_pruning_step={0},'
'sparsity_function_begin_step={0},'
'end_pruning_step={1},'
'sparsity_function_end_step={1},'
'target_sparsity={2},'
'pruning_frequency={3},'
'threshold_decay=0,'
'use_tpu={4}'.format(
FLAGS.sparsity_begin_step,
FLAGS.sparsity_end_step,
FLAGS.end_sparsity,
FLAGS.pruning_frequency,
False,
))
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
# Create a pruning object using the pruning hyperparameters
pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
tf.logging.info('starting mask update op')
# We override the train op to also update the mask.
with tf.control_dependencies([train_op]):
train_op = pruning_obj.conditional_mask_update_op()
masks = pruning.get_masks()
mask_metrics = utils.mask_summaries(masks)
for name, tensor in mask_metrics.items():
tf.summary.scalar(name, tensor)
tf.summary.scalar('learning_rate', learning_rate)
tf.summary.scalar('accuracy', accuracy)
tf.summary.scalar('total_loss', total_loss)
tf.summary.scalar('top_5_accuracy', top_5_accuracy)
# Logging drop_fraction if dynamic sparse training.
if training_method in ('set', 'momentum', 'rigl', 'static'):
tf.summary.scalar('drop_fraction', optimizer.drop_fraction)
summary_op = tf.summary.merge_all()
summary_hook = tf.train.SummarySaverHook(
save_secs=300, output_dir=train_dir, summary_op=summary_op)
hooks = [summary_hook]
eval_metrics = None
return hooks, eval_metrics, train_op
def build_model(mode,
images,
labels,
training_method='baseline',
num_classes=10,
depth=10,
width=4):
"""Build the wide ResNet model for training or eval.
If regularizer is specified, a regularizer term is added to the loss function.
The regularizer term is computed using either the pre-softmax activation or an
auxiliary network logits layer based upon activations earlier in the network
after the first resnet block.
Args:
mode: String for whether training or evaluation is taking place.
images: A 4D float32 tensor containing the model input images.
labels: A int32 tensor of size (batch size, number of classes)
containing the model labels.
training_method: The method used to sparsify the network weights.
num_classes: The number of distinct labels in the dataset.
depth: Number of core convolutional layers in the network.
width: The width multiplier for the convolutional filters in the resnet
blocks.
Returns:
global_step: The current global training step tensor.
accuracy: A scalar float32 batch classification accuracy tensor.
top_5_accuracy: A scalar float32 top-5 batch accuracy tensor.
logits: A 2D float32 logits tensor of shape (batch_size, num_classes).
Raises:
ValueError: if depth is not the minimum amount required to build the
model.
"""
regularizer_term = tf.constant(FLAGS.l2, tf.float32)
kernel_regularizer = contrib_layers.l2_regularizer(scale=regularizer_term)
# Depth should be 6n+4 where n is the desired number of resnet blocks:
# n=1 gives depth=10, n=3 gives depth=22, n=5 gives depth=34, n=7 gives
# depth=46.
if (depth - 4) % 6 != 0:
raise ValueError('Depth must be 6n+4 for some positive integer n.')
if mode == 'train':
is_training = True
else:
is_training = False
# 'threshold' would create layers with mask.
pruning_method = 'baseline' if training_method == 'baseline' else 'threshold'
model = WideResNetModel(
is_training=is_training,
regularizer=kernel_regularizer,
data_format='channels_last',
pruning_method=pruning_method,
prune_first_layer=FLAGS.prune_first_layer,
prune_last_layer=FLAGS.prune_last_layer)
logits = model.build(
images, depth=depth, width=width, num_classes=num_classes)
global_step = tf.train.get_or_create_global_step()
predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32))
in_top_5 = tf.cast(
tf.nn.in_top_k(predictions=logits, targets=labels, k=5), tf.float32)
top_5_accuracy = tf.cast(tf.reduce_mean(in_top_5), tf.float32)
return global_step, accuracy, top_5_accuracy, logits
def wide_resnet_w_pruning(features, labels, mode, params):
"""The model_fn for ResNet wide with pruning.
Args:
features: A float32 batch of images.
labels: A int32 batch of labels.
mode: Specifies whether training or evaluation.
params: Dictionary of parameters passed to the model.
Returns:
A EstimatorSpec for the model
Raises:
ValueError: if mode is not recognized as train or eval.
"""
if isinstance(features, dict):
features = features['feature']
train_dir = params['train_dir']
training_method = params['training_method']
global_step, accuracy, top_5_accuracy, logits = build_model(
mode=mode,
images=features,
labels=labels,
training_method=training_method,
num_classes=FLAGS.num_classes,
depth=FLAGS.resnet_depth,
width=FLAGS.resnet_width)
if mode == tf_estimator.ModeKeys.PREDICT:
predictions = {
'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
}
return tf_estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={
'classify': tf_estimator.export.PredictOutput(predictions)
})
with tf.name_scope('computing_cross_entropy_loss'):
entropy_loss = tf.losses.sparse_softmax_cross_entropy(
labels=labels, logits=logits)
tf.summary.scalar('cross_entropy_loss', entropy_loss)
with tf.name_scope('computing_total_loss'):
total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
if mode == tf_estimator.ModeKeys.TRAIN:
hooks, eval_metrics, train_op = train_fn(training_method, global_step,
total_loss, train_dir, accuracy,
top_5_accuracy)
elif mode == tf_estimator.ModeKeys.EVAL:
hooks = None
train_op = None
with tf.name_scope('summaries'):
eval_metrics = create_eval_metrics(labels, logits)
else:
raise ValueError('mode not recognized as training or eval.')
# If given load parameter values.
if FLAGS.initial_value_checkpoint:
tf.logging.info('Loading initial values from: %s',
FLAGS.initial_value_checkpoint)
utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,
FLAGS.train_dir, PARAM_SUFFIXES)
# Load or randomly initialize masks.
if (FLAGS.load_mask_dir and
FLAGS.training_method not in ('snip', 'baseline', 'prune')):
# Init masks.
tf.logging.info('Loading masks from %s', FLAGS.load_mask_dir)
utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir, FLAGS.train_dir,
MASK_SUFFIX)
scaffold = tf.train.Scaffold()
elif (FLAGS.mask_init_method and
FLAGS.training_method not in ('snip', 'baseline', 'scratch', 'prune')):
tf.logging.info('Initializing masks using method: %s',
FLAGS.mask_init_method)
all_masks = pruning.get_masks()
assigner = sparse_utils.get_mask_init_fn(
all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, {})
def init_fn(scaffold, session):
"""A callable for restoring variable from a checkpoint."""
del scaffold # Unused.
session.run(assigner)
scaffold = tf.train.Scaffold(init_fn=init_fn)
else:
assert FLAGS.training_method in ('snip', 'baseline', 'prune')
scaffold = None
tf.logging.info('No mask is set, starting dense.')
return tf_estimator.EstimatorSpec(
mode=mode,
training_hooks=hooks,
loss=total_loss,
train_op=train_op,
eval_metric_ops=eval_metrics,
scaffold=scaffold)
def main(argv):
del argv # Unused.
tf.set_random_seed(FLAGS.seed)
if FLAGS.training_steps_multiplier != 1.0:
multiplier = FLAGS.training_steps_multiplier
FLAGS.max_steps = int(FLAGS.max_steps * multiplier)
FLAGS.maskupdate_begin_step = int(FLAGS.maskupdate_begin_step * multiplier)
FLAGS.maskupdate_end_step = int(FLAGS.maskupdate_end_step * multiplier)
FLAGS.sparsity_begin_step = int(FLAGS.sparsity_begin_step * multiplier)
FLAGS.sparsity_end_step = int(FLAGS.sparsity_end_step * multiplier)
tf.logging.info(
'Training schedule is updated with multiplier: %.2f', multiplier)
# configures train directories based upon hyperparameters used.
if FLAGS.training_method == 'prune':
folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity),
str(FLAGS.sparsity_begin_step),
str(FLAGS.sparsity_end_step),
str(FLAGS.pruning_frequency))
elif FLAGS.training_method in ('set', 'momentum', 'rigl', 'static'):
folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity),
str(FLAGS.maskupdate_begin_step),
str(FLAGS.maskupdate_end_step),
str(FLAGS.maskupdate_frequency))
elif FLAGS.training_method in ('baseline', 'snip', 'scratch'):
folder_stub = os.path.join(FLAGS.training_method, str(0.0), str(0.0),
str(0.0), str(0.0))
else:
raise ValueError('Training method is not known %s' % FLAGS.training_method)
train_dir = os.path.join(FLAGS.train_dir, folder_stub)
# we pass the updated eval and train string to the params dictionary.
params = {}
params['train_dir'] = train_dir
params['data_split'] = FLAGS.mode
params['batch_size'] = FLAGS.batch_size
params['data_directory'] = FLAGS.data_directory
params['mode'] = FLAGS.mode
params['training_method'] = FLAGS.training_method
run_config = tf_estimator.RunConfig(
model_dir=train_dir,
keep_checkpoint_max=FLAGS.keep_checkpoint_max,
save_summary_steps=FLAGS.summaries_steps,
save_checkpoints_steps=FLAGS.checkpoint_steps,
log_step_count_steps=100)
classifier = tf_estimator.Estimator(
model_fn=wide_resnet_w_pruning,
model_dir=train_dir,
config=run_config,
params=params)
if FLAGS.mode == 'eval':
eval_steps = 10000 // FLAGS.batch_size
# Run evaluation when there's a new checkpoint
for ckpt in contrib_training.checkpoints_iterator(train_dir):
print('Starting to evaluate.')
try:
classifier.evaluate(
input_fn=input_fn,
steps=eval_steps,
checkpoint_path=ckpt,
name='eval')
# Terminate eval job when final checkpoint is reached
global_step = int(os.path.basename(ckpt).split('-')[1])
if global_step >= FLAGS.max_steps:
print('Evaluation finished after training step %d' % global_step)
break
except tf.errors.NotFoundError:
print('Checkpoint no longer exists, skipping checkpoint.')
else:
print('Starting training...')
if FLAGS.mode == 'train':
classifier.train(input_fn=input_fn, max_steps=FLAGS.max_steps)
if __name__ == '__main__':
tf.app.run(main)
================================================
FILE: rigl/experimental/jax/README.md
================================================
# Weight Symmetry Research Code
This code is mostly written by Yani Ioannou.
## Experiment Summary
There are a number of experiment drivers defined in the base directory:
### Experiment Types {#experiment-types}
random_mask
: Random Variable Sparsity Masks
: This experiment generates random masks of a given type (see
[Mask Types](#mask-types)) within the *given a sparsity range*, and trains
the models, tracking mask statistics and training details. Masks are
generated with a random number of connections and randomly shuffled.
shuffled_mask
: Random Fixed Sparsity Masks
: This experiment generates random masks of a given type (see
[Mask Types](#mask-types)) *of a fixed sparsity*, and trains the models,
tracking mask statistics and training details. Masks are generated with a
fixed number of connections and simply shuffled.
fixed_param
: Train models with an (approximately) fixed number of parameters but
varying depth/width.
: Uses a shuffled mask (as in the shuffled_mask driver) and only the
MNIST_FC model type.
prune
: Simple Pruning/Training Driver
: This experiment trains a dense model, pruning it either iteratively or
one-shot, tracking mask statistics and training details.
train
: Simple Training Driver (Without Masking/Pruning)
: This experiment simply trains a dense model, tracking mask statistics and
training details.
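As a rough illustration of the fixed_param idea, the sketch below (the helper and the concrete widths are illustrative, not taken from this codebase) computes weight counts for two MNIST_FC-style configurations with different depth/width but a similar parameter budget:

```python
def fc_param_count(n_in, hidden_widths, n_out):
    """Total weight count (biases ignored) of a fully-connected network."""
    dims = [n_in] + list(hidden_widths) + [n_out]
    # Sum the products of consecutive layer sizes.
    return sum(a * b for a, b in zip(dims[:-1], dims[1:]))

# Two configurations with roughly the same parameter budget:
shallow = fc_param_count(784, [300], 10)    # 1 hidden layer  -> 238200
deep = fc_param_count(784, [232, 232], 10)  # 2 hidden layers -> 238032
```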
### Mask Types {#mask-types}
symmetric
: Structured Mask.
: The mask is a structured mask.
random
: Unstructured Mask.
: The mask as a whole is a random mask of a given sparsity, with some neurons
having fewer/more connections than others.
per-neuron
: Unstructured Mask.
: Each neuron has the same sparsity (# of masked connections), but is shuffled
randomly.
per-neuron-no-input-ablation
: Unstructured Mask.
: As with per-neuron, each neuron has the same sparsity, but randomly shuffled
connections. Also at least one connection is maintained to each of the input
neurons (i.e. the input neurons are not effectively ablated), although these
connections are also randomly shuffled amongst the neurons of a given layer.
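A minimal NumPy sketch of the per-neuron mask type (illustrative only; the actual mask generators live under `pruning/`): every output neuron keeps the same number of input connections, placed at random positions.

```python
import numpy as np

def per_neuron_mask(n_in, n_out, sparsity, rng=None):
    """Binary (n_in, n_out) mask where each output neuron keeps the same
    number of input connections, shuffled randomly per neuron."""
    rng = np.random.default_rng() if rng is None else rng
    n_keep = int(round(n_in * (1.0 - sparsity)))  # connections kept per neuron
    mask = np.zeros((n_in, n_out), dtype=np.int8)
    for j in range(n_out):
        kept = rng.choice(n_in, size=n_keep, replace=False)
        mask[kept, j] = 1
    return mask

mask = per_neuron_mask(784, 100, sparsity=0.9)
# Every neuron (column) has exactly the same number of connections.
assert (mask.sum(axis=0) == mask[:, 0].sum()).all()
```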
### Model Types {#model-types}
MNIST_FC
: A small fully-connected model, accepting number of neurons and depth as
parameters. No batch normalization, configurable drop-out rate (default: 0).
MNIST_CNN
: A small convolutional model designed for MNIST, accepting number of filters
for each layer and depth as parameters. Uses batch normalization and
configurable drop-out rate (default: 0).
CIFAR10_CNN
: A larger convolutional model designed for CIFAR10, accepting number of
filters for each layer and depth as parameters. No batch normalization,
configurable drop-out rate (default: 0).
### Dataset Types {#dataset-types}
MNIST
: Wrapper of the Tensorflow Datasets (TFDS) MNIST dataset.
CIFAR10
: Wrapper of the Tensorflow Datasets (TFDS) CIFAR10 dataset.
## Running Experiments
### Running on a Workstation
Train:
```shell
python -m weight_symmetry.${EXPERIMENT_TYPE}
```
## Result Processing/Analysis
### Plotting Results from a JSON Summary File
You can convert the results from a JSON summary file to a Pandas dataframe for
plotting/analysis using the example colab in `analysis/plot_summary_json.ipynb`.
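The conversion the colab performs can be sketched in a few lines. The summary layout assumed here (top-level keys are run ids, values are flat metric dicts) follows the `pd.DataFrame.from_dict(..., orient='index')` call in the notebook; the metric names and file path below are illustrative stand-ins:

```python
import json
import tempfile

import pandas as pd

# A tiny stand-in for an experiment summary file: top-level keys are run ids,
# values are flat metric dicts (layout assumed from the analysis notebook).
summary = {
    'run0': {'mask/sparsity': 0.5, 'best_test_acc/test_accuracy': 0.97},
    'run1': {'mask/sparsity': 0.9, 'best_test_acc/test_accuracy': 0.93},
}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
  json.dump(summary, f)
  path = f.name

# One row per run, one column per metric.
with open(path) as f:
  df = pd.DataFrame.from_dict(json.load(f), orient='index')

print(df.shape)  # (2, 2)
```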
================================================
FILE: rigl/experimental/jax/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains code for weight symmetry experiments."""
name = 'weight_symmetry'
================================================
FILE: rigl/experimental/jax/analysis/plot_summary_json.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "6iEEw5OwSlnz"
},
"source": [
"# Plot Results from an Experiment Summary JSON File",
"Licensed under the Apache License, Version 2.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Eg6FmoCaTCHM"
},
"source": [
"## Parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ML0hUJMzYF0W"
},
"outputs": [],
"source": [
"from google.colab import files\n",
"\n",
"# Experiment summary filenames (one per experiment)\n",
"SUMMARY_FILES = files.upload()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "MHubbscQSLGm"
},
"outputs": [],
"source": [
"# Labels to use for each of the summaries listed above (in the same order!)\n",
"XID_LABELS=['structured', 'unstructured'] #@param"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "x0jDBWKdU_2A"
},
"source": [
"## Loading of JSON Summary/Conversion to Pandas Dataframe"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Lz-HwS1tU-ie"
},
"outputs": [],
"source": [
"import json\n",
"import pandas as pd\n",
"import os\n",
"\n",
"dfs = []\n",
"for i, summary_file in enumerate(SUMMARY_FILES):\n",
" with open(summary_file) as summary_file:\n",
" data = json.load(summary_file)\n",
" dataframe = pd.DataFrame.from_dict(data, orient='index')\n",
" dataframe['experiment_label'] = XID_LABELS[i]\n",
" dfs.append(dataframe)\n",
"\n",
"df = pd.concat(dfs)\n",
"\n",
"print('Loaded {} rows across all experiments'.format(len(df)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DhO6oT1nVpTV"
},
"source": [
"## Measurements and Labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XFRR3XrXVopB"
},
"outputs": [],
"source": [
"DATA_LABELS={\n",
" 'best_train_loss/test_accuracy': 'Test Accuracy (of best train loss)',\n",
" 'best_train_loss/train_accuracy': 'Train Accuracy (of best train loss)',\n",
" 'best_train_loss/test_avg_loss': 'Test Loss (of best train loss)',\n",
" 'best_train_loss/train_avg_loss': 'Train Loss (of best train loss)',\n",
" 'best_train_loss/step': 'Training Iterations (of best train loss)',\n",
" 'best_train_loss/cumulative_gradient_norm': 'Cumulative Gradient Norm. (of best train loss)',\n",
" 'best_train_loss/vector_difference_norm': 'Vector Difference Norm. (of best train loss)',\n",
" 'best_train_loss/cosine_distance': 'Cosine Similarity (of best train loss)',\n",
" 'best_test_acc/test_accuracy': 'Test Accuracy (of best test acc.)',\n",
" 'best_test_acc/train_accuracy': 'Train Accuracy (of best test acc.)',\n",
" 'best_test_acc/test_avg_loss': 'Test Loss (of best test acc.)',\n",
" 'best_test_acc/train_avg_loss': 'Train Loss (of best test acc.)',\n",
" 'best_test_acc/step': 'Training Iterations (of best test acc.)',\n",
" 'best_test_acc/cumulative_gradient_norm': 'Cumulative Gradient Norm. (of best Test Acc.)',\n",
" 'best_test_acc/cosine_distance': 'Cosine Similarity (of best Test Acc.)',\n",
" 'best_test_acc/vector_difference_norm': 'Vector Difference Norm. (of best Test Acc.)',\n",
" 'mask/sparsity': 'Sparsity',\n",
" 'mask/unique_neurons': '# Unique Neurons',\n",
" 'mask/zeroed_neurons': '# Zeroed Neurons',\n",
" 'mask/permutation_log10': 'log10(1 + Permutations)',\n",
" 'mask/permutation_num_digits': 'Permutation # of Digits',\n",
" 'mask/permutations': 'Permutation',\n",
" 'mask/total_neurons': 'Total # of Neurons',\n",
" 'propagated_mask/sparsity': 'Mask Sparsity',\n",
" 'propagated_mask/unique_neurons': '# Unique Neurons (prop.)',\n",
" 'propagated_mask/zeroed_neurons': '# Zeroed Neurons (prop.)',\n",
" 'propagated_mask/permutation_log10': 'log10(1 + Permutations) (prop.)',\n",
" 'propagated_mask/permutation_num_digits': 'Permutation # of Digits (prop.)',\n",
" 'propagated_mask/permutations': 'Mask Permutations',\n",
" 'propagated_mask/total_neurons': 'Total # of Neurons (prop.)',\n",
" 'training/train_avg_loss': 'Train Loss',\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "HAVkz8ZzV0Hd"
},
"source": [
"# Seaborn Plot Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "eoxoJH4gWHbb"
},
"outputs": [],
"source": [
"# Choose the X/Y/Z labels from the parameter list above.\n",
"X_LABEL='propagated_mask/sparsity' #@param {type:\"string\"}\n",
"Y_LABEL='best_train_loss/cumulative_gradient_norm' #@param {type:\"string\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "pudAXLl1VzFl"
},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Seaborn style - remove outer plot ticks, white plot background.\n",
"np.set_printoptions(linewidth=128, precision=3, edgeitems=5)\n",
"sns.set_style(\"whitegrid\")\n",
"sns.color_palette(\"muted\")\n",
"sns.set_context(\"paper\", font_scale=1, rc={\n",
" \"lines.linewidth\": 1.2,\n",
" \"xtick.major.size\": 0,\n",
" \"xtick.minor.size\": 0,\n",
" \"ytick.major.size\": 0,\n",
" \"ytick.minor.size\": 0\n",
"})\n",
"\n",
"# Higher resolution plots\n",
"%config InlineBackend.figure_format = 'retina'"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lYUK9xi_aym3"
},
"source": [
"### Plot Raw Data Points"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "uWcT76L6Wbv6"
},
"outputs": [],
"source": [
"\n",
"plt.figure(figsize=(16,8))\n",
"axis = sns.scatterplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', s=50, alpha=.5)\n",
"axis.set_ylabel(DATA_LABELS[Y_LABEL])\n",
"axis.set_xlabel(DATA_LABELS[X_LABEL])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Kws6tjfTa7h0"
},
"source": [
"### Plot Mean/StdDev"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jR04tmMnaxjG"
},
"outputs": [],
"source": [
"plt.figure(figsize=(16,8))\n",
"axis = sns.lineplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', alpha=.5, ci=\"sd\", markers=True)\n",
"axis.set_ylabel(DATA_LABELS[Y_LABEL])\n",
"axis.set_xlabel(DATA_LABELS[X_LABEL])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jyNFtKQfajiq"
},
"outputs": [],
"source": [
"# Code to save output files for publication.\n",
"PARAM_STR=X_LABEL.replace('/', '-')+'_'+Y_LABEL.replace('/', '-')\n",
"\n",
"OUT_FILE_PDF=f'/tmp/{PARAM_STR}.pdf'\n",
"OUT_FILE_SVG=f'/tmp/{PARAM_STR}.svg'\n",
"OUT_FILE_PNG=f'/tmp/{PARAM_STR}.png'\n",
"\n",
"plt.savefig(OUT_FILE_PDF, dpi=600)\n",
"files.download(OUT_FILE_PDF)\n",
"\n",
"plt.savefig(OUT_FILE_SVG)\n",
"files.download(OUT_FILE_SVG)\n",
"\n",
"plt.savefig(OUT_FILE_PNG)\n",
"files.download(OUT_FILE_PNG)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "plot_summary_json",
"provenance": [
{
"file_id": "1g2aTwv76XMrLfEwryfj_tGzNnvZWjIVl",
"timestamp": 1600990155741
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: rigl/experimental/jax/datasets/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/datasets/cifar10.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CIFAR10 Dataset.
Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
In this case, the CIFAR10 dataset.
"""
from typing import MutableMapping, Sequence
from rigl.experimental.jax.datasets import dataset_base
import tensorflow.compat.v2 as tf
class CIFAR10Dataset(dataset_base.ImageDataset):
"""CIFAR10 dataset.
Attributes:
NAME: The Tensorflow Dataset's dataset name.
"""
NAME: str = 'cifar10'
# Computed from the training set by taking the per-channel mean/std-dev
# over sample, height and width axes of all training samples.
MEAN_RGB: Sequence[float] = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255]
STDDEV_RGB: Sequence[float] = [0.2470 * 255, 0.2435 * 255, 0.2616 * 255]
def __init__(self,
batch_size,
batch_size_test,
shuffle_buffer_size = 1024,
seed = 42):
"""CIFAR10 dataset.
Args:
batch_size: The batch size to use for the training datasets.
batch_size_test: The batch size used for the test dataset.
shuffle_buffer_size: The buffer size to use for dataset shuffling.
seed: The random seed used to shuffle.
Returns:
Dataset: A dataset object.
Raises:
ValueError: If the test dataset is not evenly divisible by the
test batch size.
"""
super().__init__(CIFAR10Dataset.NAME, batch_size, batch_size_test,
shuffle_buffer_size, seed)
if self.get_test_len() % batch_size_test != 0:
raise ValueError(
'Test data not evenly divisible by batch size: {} % {} != 0.'.format(
self.get_test_len(), batch_size_test))
def preprocess(
self, data):
"""Normalizes CIFAR10 images: `uint8` -> `float32`.
Args:
data: Data sample.
Returns:
Data after being augmented/normalized/transformed.
"""
data = super().preprocess(data)
mean_rgb = tf.constant(self.MEAN_RGB, shape=[1, 1, 3], dtype=tf.float32)
std_rgb = tf.constant(self.STDDEV_RGB, shape=[1, 1, 3], dtype=tf.float32)
data['image'] = (tf.cast(data['image'], tf.float32) - mean_rgb) / std_rgb
return data
================================================
FILE: rigl/experimental/jax/datasets/cifar10_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.cifar10."""
from absl.testing import absltest
import numpy as np
from rigl.experimental.jax.datasets import cifar10
class CIFAR10DatasetTest(absltest.TestCase):
"""Test cases for CIFAR10 Dataset."""
def setUp(self):
"""Common setup routines/variables for test cases."""
super().setUp()
self._batch_size = 16
self._batch_size_test = 10
self._shuffle_buffer_size = 8
self._dataset = cifar10.CIFAR10Dataset(
self._batch_size,
batch_size_test=self._batch_size_test,
shuffle_buffer_size=self._shuffle_buffer_size)
def test_create_dataset(self):
"""Tests creation of dataset."""
self.assertIsInstance(self._dataset, cifar10.CIFAR10Dataset)
def test_train_image_dims_content(self):
"""Tests dimensions and contents of train data."""
iterator = self._dataset.get_train()
sample = next(iterator)
image, label = sample['image'], sample['label']
with self.subTest(name='DataShape'):
self.assertTupleEqual(image.shape, (self._batch_size, 32, 32, 3))
with self.subTest(name='DataType'):
self.assertTrue(np.issubdtype(image.dtype, float))
with self.subTest(name='DataValues'):
# Normalized by stddev., expect nothing to fall outside 3 stddev.
self.assertTrue((image >= -3.).all() and (image <= 3.).all())
with self.subTest(name='LabelShape'):
self.assertLen(label, self._batch_size)
with self.subTest(name='LabelType'):
self.assertTrue(np.issubdtype(label.dtype, int))
with self.subTest(name='LabelValues'):
self.assertTrue((label >= 0).all() and
(label <= self._dataset.num_classes).all())
def test_test_image_dims_content(self):
"""Tests dimensions and contents of test data."""
iterator = self._dataset.get_test()
sample = next(iterator)
image, label = sample['image'], sample['label']
with self.subTest(name='DataShape'):
self.assertTupleEqual(image.shape, (self._batch_size_test, 32, 32, 3))
with self.subTest(name='DataType'):
self.assertTrue(np.issubdtype(image.dtype, float))
with self.subTest(name='DataValues'):
# Normalized by stddev., expect nothing to fall outside 3 stddev.
self.assertTrue((image >= -3.).all() and (image <= 3.).all())
with self.subTest(name='LabelShape'):
self.assertLen(label, self._batch_size_test)
with self.subTest(name='LabelType'):
self.assertTrue(np.issubdtype(label.dtype, int))
with self.subTest(name='LabelValues'):
self.assertTrue((label >= 0).all() and
(label <= self._dataset.num_classes).all())
def test_train_data_length(self):
"""Tests length of training dataset."""
total_count = 0
for batch in self._dataset.get_train():
total_count += len(batch['label'])
self.assertEqual(total_count, self._dataset.get_train_len())
def test_test_data_length(self):
"""Tests length of test dataset."""
total_count = 0
for batch in self._dataset.get_test():
total_count += len(batch['label'])
self.assertEqual(total_count, self._dataset.get_test_len())
def test_dataset_nonevenly_divisible_batch_size(self):
"""Tests non-evenly divisible test batch size."""
with self.assertRaisesRegex(
ValueError, 'Test data not evenly divisible by batch size: .*'):
self._dataset = cifar10.CIFAR10Dataset(
self._batch_size, batch_size_test=101)
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/datasets/dataset_base.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset Classes.
Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
"""
import abc
from typing import MutableMapping, Optional
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
class Dataset(metaclass=abc.ABCMeta):
"""Base class for datasets.
Attributes:
DATAKEY: The key used for the data component of a Tensorflow Dataset
(TFDS) sample, e.g. 'image' for image datasets.
LABELKEY: The key used for the label component of a Tensorflow Dataset
sample, i.e. 'label'.
name: The TFDS name of the dataset.
batch_size: The batch size to use for the training dataset.
batch_size_test: The batch size to use for the test dataset.
num_classes: the number of supervised classes in the dataset.
shape: the shape of an input data array.
"""
DATAKEY: Optional[str] = None
LABELKEY: str = 'label'
def __init__(self,
name,
batch_size,
batch_size_test,
shuffle_buffer_size,
prefetch_size = 1,
seed = None): # pytype: disable=annotation-type-mismatch
"""Base class for datasets.
Args:
name: The TFDS name of the dataset.
batch_size: The batch size to use for the training dataset.
batch_size_test: The batch size to use for the test dataset.
shuffle_buffer_size: The buffer size to use for dataset shuffling.
prefetch_size: The number of mini-batches to prefetch.
seed: The random seed used to shuffle.
Returns:
A Dataset object.
"""
super().__init__()
self.name = name
self.batch_size = batch_size
self.batch_size_test = batch_size_test
self._shuffle_buffer_size = shuffle_buffer_size
self._prefetch_size = prefetch_size
self._train_ds, self._train_info = tfds.load(
self.name,
split=tfds.Split.TRAIN,
data_dir=self._dataset_dir(),
with_info=True)
self._train_ds = self._train_ds.shuffle(
self._shuffle_buffer_size,
seed).map(self.preprocess).cache().map(self.augment).batch(
self.batch_size, drop_remainder=True).prefetch(self._prefetch_size)
self._test_ds, self._test_info = tfds.load(
self.name,
split=tfds.Split.TEST,
data_dir=self._dataset_dir(),
with_info=True)
self._test_ds = self._test_ds.map(self.preprocess).cache().batch(
self.batch_size_test).prefetch(self._prefetch_size)
self.num_classes = self._train_info.features['label'].num_classes
self.shape = self._train_info.features['image'].shape
def _dataset_dir(self):
"""Returns the dataset path for the TFDS data."""
return None
def get_train(self):
"""Returns the training dataset."""
return iter(tfds.as_numpy(self._train_ds))
def get_train_len(self):
"""Returns the length of the training dataset."""
return self._train_info.splits['train'].num_examples
def get_test(self):
"""Returns the test dataset."""
return iter(tfds.as_numpy(self._test_ds))
def get_test_len(self):
"""Returns the length of the test dataset."""
return self._test_info.splits['test'].num_examples
def preprocess(
self, data):
"""Preprocessing fn used by TFDS map for normalization.
This function is for transformations that can be cached, e.g.
normalization/whitening.
Args:
data: Data sample.
Returns:
Data after being normalized/transformed.
"""
return data
def augment(
self, data):
"""Preprocessing fn used by TFDS map for augmentation at training time.
This function is for transformations that should not be cached, e.g. random
augmentation that should change for every sample, and are only applied at
training time.
Args:
data: Data sample.
Returns:
Data after being augmented/transformed.
"""
return data
class ImageDataset(Dataset):
"""Base class for image datasets."""
DATAKEY = 'image'
def preprocess(
self, data):
"""Preprocessing function used by TFDS map for normalization.
This function is for transformations that can be cached, e.g.
normalization/whitening.
Args:
data: Data sample.
Returns:
Data after being normalized/transformed.
"""
data = super().preprocess(data)
# Ensure we only provide the image and label, stripping out other keys.
return dict((key, val)
for key, val in data.items()
if key in [self.LABELKEY, self.DATAKEY])
================================================
FILE: rigl/experimental/jax/datasets/dataset_base_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.dataset_base."""
from absl.testing import absltest
from rigl.experimental.jax.datasets import dataset_base
class DummyDataset(dataset_base.ImageDataset):
"""A dummy implementation of the abstract dataset class.
Attributes:
NAME: The Tensorflow Dataset's dataset name.
"""
NAME: str = 'mnist'
def __init__(self,
batch_size,
batch_size_test,
shuffle_buffer_size = 1024,
seed = 42):
"""Dummy MNIST dataset.
Args:
batch_size: The batch size to use for the training datasets.
batch_size_test: The batch size to use for the test dataset.
shuffle_buffer_size: The buffer size to use for dataset shuffling.
seed: The random seed used to shuffle.
Returns:
Dataset: A dataset object.
"""
super().__init__(DummyDataset.NAME, batch_size, batch_size_test,
shuffle_buffer_size, seed)
class DummyDatasetTest(absltest.TestCase):
"""Test cases for dummy dataset."""
def setUp(self):
"""Common setup routines/variables for test cases."""
super().setUp()
self._batch_size = 16
self._batch_size_test = 10
self._shuffle_buffer_size = 8
self._dataset = DummyDataset(
self._batch_size,
batch_size_test=self._batch_size_test,
shuffle_buffer_size=self._shuffle_buffer_size)
def test_create_dataset(self):
"""Tests creation of dataset."""
self.assertIsInstance(self._dataset, DummyDataset)
def test_train_image_dims_content(self):
"""Tests dimensions and contents of train data."""
iterator = iter(self._dataset.get_train())
sample = next(iterator)
image, label = sample['image'], sample['label']
with self.subTest(name='data_shape'):
self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1))
with self.subTest(name='data_values'):
self.assertTrue((image >= 0).all() and (image <= 255).all())
with self.subTest(name='label_shape'):
self.assertLen(label, self._batch_size)
with self.subTest(name='label_values'):
self.assertTrue((label >= 0).all() and
(label <= self._dataset.num_classes).all())
def test_test_image_dims_content(self):
"""Tests dimensions and contents of test data."""
iterator = iter(self._dataset.get_test())
sample = next(iterator)
image, label = sample['image'], sample['label']
with self.subTest(name='data_shape'):
self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1))
with self.subTest(name='data_values'):
self.assertTrue((image >= 0).all() and (image <= 255).all())
with self.subTest(name='label_shape'):
self.assertLen(label, self._batch_size_test)
with self.subTest(name='label_values'):
self.assertTrue((label >= 0).all() and
(label <= self._dataset.num_classes).all())
def test_train_data_length(self):
"""Tests length of training dataset."""
total_count = 0
for batch in self._dataset.get_train():
total_count += len(batch['label'])
self.assertEqual(total_count, self._dataset.get_train_len())
def test_test_data_length(self):
"""Tests length of test dataset."""
total_count = 0
for batch in self._dataset.get_test():
total_count += len(batch['label'])
# Check image size/content.
self.assertEqual(total_count, self._dataset.get_test_len())
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/datasets/dataset_factory.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset Factory.
Dataset factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
Attributes:
DATASETS: A list of the datasets that can be created.
"""
from typing import Any, Mapping, Type
from rigl.experimental.jax.datasets import cifar10
from rigl.experimental.jax.datasets import dataset_base
from rigl.experimental.jax.datasets import mnist
import tensorflow.compat.v2 as tf
DATASETS: Mapping[str, Type[dataset_base.Dataset]] = {
'MNIST': mnist.MNISTDataset,
'CIFAR10': cifar10.CIFAR10Dataset,
}
def create_dataset(name, *args, **kwargs):
"""Creates a Tensorflow datasets (TFDS) dataset.
Args:
name: The TFDS name of the dataset.
*args: Dataset arguments.
**kwargs: Dataset keyword arguments.
Returns:
Dataset: An abstracted dataset object.
Raises:
ValueError: If no dataset with the given name exists.
"""
if name not in DATASETS:
raise ValueError(f'No such dataset: {name}')
return DATASETS[name](*args, **kwargs)
================================================
FILE: rigl/experimental/jax/datasets/dataset_factory_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.dataset_common."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from rigl.experimental.jax.datasets import dataset_base
from rigl.experimental.jax.datasets import dataset_factory
class DatasetCommonTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self._batch_size = 32
self._batch_size_test = 10
self._shuffle_buffer_size = 128
def _create_dataset(self, dataset_name):
"""Helper function for creating a dataset."""
return dataset_factory.create_dataset(
dataset_name,
self._batch_size,
self._batch_size_test,
shuffle_buffer_size=self._shuffle_buffer_size)
def test_dataset_supported(self):
"""Tests supported datasets."""
for dataset_name in dataset_factory.DATASETS:
dataset = self._create_dataset(dataset_name)
self.assertIsInstance(dataset, dataset_base.Dataset)
@parameterized.parameters(*dataset_factory.DATASETS.keys())
def test_dataset_train_iterators(self, dataset_name):
"""Tests dataset's train iterator."""
dataset = self._create_dataset(dataset_name)
sample = next(dataset.get_train())
with self.subTest(name='{}_sample'.format(dataset_name)):
self.assertNotEmpty(sample)
with self.subTest(name='{}_label_type'.format(dataset_name)):
self.assertIsInstance(sample['label'], np.ndarray)
with self.subTest(name='{}_label_batch_size'.format(dataset_name)):
self.assertLen(sample['label'], self._batch_size)
with self.subTest(name='{}_image_type'.format(dataset_name)):
self.assertIsInstance(sample['image'], np.ndarray)
with self.subTest(name='{}_image_shape'.format(dataset_name)):
self.assertLen(sample['image'].shape, 4)
with self.subTest(name='{}_image_batch_size'.format(dataset_name)):
self.assertEqual(sample['image'].shape[0], self._batch_size)
with self.subTest(
name='{}_non_zero_image_dimensions'.format(dataset_name)):
self.assertGreater(sample['image'].shape[1], 1)
@parameterized.parameters(*dataset_factory.DATASETS.keys())
def test_dataset_test_iterators(self, dataset_name):
"""Tests dataset's test iterator."""
dataset = self._create_dataset(dataset_name)
sample = next(dataset.get_test())
with self.subTest(name='{}_sample'.format(dataset_name)):
self.assertNotEmpty(sample)
with self.subTest(name='{}_label_type'.format(dataset_name)):
self.assertIsInstance(sample['label'], np.ndarray)
with self.subTest(name='{}_label_batch_size'.format(dataset_name)):
self.assertLen(sample['label'], self._batch_size_test)
with self.subTest(name='{}_image_type'.format(dataset_name)):
self.assertIsInstance(sample['image'], np.ndarray)
with self.subTest(name='{}_image_shape'.format(dataset_name)):
self.assertLen(sample['image'].shape, 4)
with self.subTest(name='{}_image_batch_size'.format(dataset_name)):
self.assertEqual(sample['image'].shape[0], self._batch_size_test)
with self.subTest(
name='{}_non_zero_image_dimensions'.format(dataset_name)):
self.assertGreater(sample['image'].shape[1], 1)
def test_dataset_unsupported(self):
"""Tests unsupported datasets."""
with self.assertRaisesRegex(ValueError, 'No such dataset: unsupported'):
self._create_dataset('unsupported')
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/datasets/mnist.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST Dataset.
Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)
with JAX/FLAX, by defining a bunch of wrappers, including preprocessing.
In this case, the MNIST dataset.
"""
from typing import MutableMapping
from rigl.experimental.jax.datasets import dataset_base
import tensorflow.compat.v2 as tf
class MNISTDataset(dataset_base.ImageDataset):
"""MNIST dataset.
Attributes:
NAME: The Tensorflow Dataset's dataset name.
"""
NAME: str = 'mnist'
def __init__(self,
batch_size,
batch_size_test,
shuffle_buffer_size = 1024,
seed = 42):
"""MNIST dataset.
Args:
batch_size: The batch size to use for the training datasets.
batch_size_test: The batch size to use for the test dataset.
shuffle_buffer_size: The buffer size to use for dataset shuffling.
seed: The random seed used to shuffle.
Returns:
Dataset: A dataset object.
"""
super().__init__(MNISTDataset.NAME, batch_size, batch_size_test,
shuffle_buffer_size, seed)
def preprocess(
self, data):
"""Normalizes MNIST images: `uint8` -> `float32`.
Args:
data: Data sample.
Returns:
Data after being augmented/normalized/transformed.
"""
data = super().preprocess(data)
data['image'] = (tf.cast(data['image'], tf.float32) / 255.) - 0.5
return data
================================================
FILE: rigl/experimental/jax/datasets/mnist_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.datasets.mnist."""
from absl.testing import absltest
import numpy as np
from rigl.experimental.jax.datasets import mnist
class MNISTDatasetTest(absltest.TestCase):
"""Test cases for MNIST Dataset."""
def setUp(self):
"""Common setup routines/variables for test cases."""
super().setUp()
self._batch_size = 16
self._batch_size_test = 10
self._shuffle_buffer_size = 8
self._dataset = mnist.MNISTDataset(
self._batch_size,
batch_size_test=self._batch_size_test,
shuffle_buffer_size=self._shuffle_buffer_size)
def test_create_dataset(self):
"""Tests creation of dataset."""
self.assertIsInstance(self._dataset, mnist.MNISTDataset)
def test_train_image_dims_content(self):
"""Tests dimensions and contents of test data."""
iterator = self._dataset.get_train()
sample = next(iterator)
image, label = sample['image'], sample['label']
with self.subTest(name='data_shape'):
self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1))
with self.subTest(name='data_values'):
self.assertTrue((image >= -1.).all() and (image <= 1.).all())
with self.subTest(name='data_type'):
self.assertTrue(np.issubdtype(image.dtype, float))
with self.subTest(name='label_shape'):
self.assertLen(label, self._batch_size)
with self.subTest(name='label_type'):
self.assertTrue(np.issubdtype(label.dtype, int))
with self.subTest(name='label_values'):
self.assertTrue((label >= 0).all() and
(label < self._dataset.num_classes).all())
def test_test_image_dims_content(self):
"""Tests dimensions and contents of train data."""
iterator = self._dataset.get_test()
sample = next(iterator)
image, label = sample['image'], sample['label']
with self.subTest(name='data_shape'):
self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1))
with self.subTest(name='data_type'):
self.assertTrue(np.issubdtype(image.dtype, float))
# TODO: Find a better approach to testing with JAX arrays.
with self.subTest(name='data_values'):
self.assertTrue((image >= -1.).all() and (image <= 1.).all())
with self.subTest(name='label_shape'):
self.assertLen(label, self._batch_size_test)
with self.subTest(name='label_type'):
self.assertTrue(np.issubdtype(label.dtype, int))
with self.subTest(name='label_values'):
self.assertTrue((label >= 0).all() and
(label < self._dataset.num_classes).all())
def test_train_data_length(self):
"""Tests length of training dataset."""
total_count = 0
for batch in self._dataset.get_train():
total_count += len(batch['label'])
self.assertEqual(total_count, self._dataset.get_train_len())
def test_test_data_length(self):
"""Tests length of test dataset."""
total_count = 0
for batch in self._dataset.get_test():
total_count += len(batch['label'])
self.assertEqual(total_count, self._dataset.get_test_len())
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/fixed_param.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Symmetry: Train models with fixed param, but diff. depth and width."""
import ast
import functools
import operator
from os import path
from typing import List, Sequence
import uuid
from absl import app
from absl import flags
from absl import logging
import flax
from flax.metrics import tensorboard
from flax.training import lr_schedule
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import mnist_fc
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.pruning import symmetry
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils
experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
logging.info('Saving experimental results to %s', experiment_dir)
host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
jax.device_count(), host_count, local_device_count)
if jax.host_id() == 0:
summary_writer = tensorboard.SummaryWriter(experiment_dir)
dataset = dataset_factory.create_dataset(
FLAGS.dataset,
FLAGS.batch_size,
FLAGS.batch_size_test,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)
logging.info('Training %s on the %s dataset...', MODEL, FLAGS.dataset)
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
input_len = functools.reduce(operator.mul, dataset.shape)
features = mnist_fc.feature_dim_for_param(
input_len,
FLAGS.param_count,
FLAGS.depth)
logging.info('Model Configuration: %s', str(features))
base_model, _ = model_factory.create_model(
MODEL,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
features=features)
model_param_count = utils.count_param(base_model, ('kernel',))
logging.info(
'Model Config: param.: %d, depth: %d. max width: %d, min width: %d',
model_param_count, len(features), max(features), min(features))
logging.info('Generating random mask based on model')
# Re-initialize the RNG to maintain same training pattern (as in prune code).
mask_rng = jax.random.PRNGKey(FLAGS.random_seed)
mask = masked.shuffled_mask(
base_model,
rng=mask_rng,
sparsity=FLAGS.mask_sparsity)
if jax.host_id() == 0:
mask_stats = symmetry.get_mask_stats(mask)
logging.info('Mask stats: %s', str(mask_stats))
for label, value in mask_stats.items():
try:
summary_writer.scalar(f'mask/{label}', value, 0)
# This is needed because permutations (long int) can't be cast to float32.
except (OverflowError, ValueError):
summary_writer.text(f'mask/{label}', str(value), 0)
logging.error('Could not write mask/%s to tensorflow summary as float32'
', writing as string instead.', label)
if FLAGS.dump_json:
mask_stats['permutations'] = str(mask_stats['permutations'])
utils.dump_dict_json(
mask_stats, path.join(experiment_dir, 'mask_stats.json'))
model_stats = {
'depth': len(features),
'max_width': max(features),
'min_width': min(features),
}
model_stats.update(
{'feature_{}'.format(i): value for i, value in enumerate(features)})
if FLAGS.dump_json:
utils.dump_dict_json(model_stats,
path.join(experiment_dir, 'model_stats.json'))
model, initial_state = model_factory.create_model(
'MNIST_FC',
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
features=features, masks=mask)
if FLAGS.opt == 'Adam':
optimizer = flax.optim.Adam(
learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
elif FLAGS.opt == 'Momentum':
optimizer = flax.optim.Momentum(
learning_rate=FLAGS.lr,
beta=FLAGS.momentum,
weight_decay=FLAGS.weight_decay,
nesterov=False)
else:
raise ValueError('Unknown Optimizer: {}'.format(FLAGS.opt))
steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
if FLAGS.lr_schedule == 'constant':
lr_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.lr, steps_per_epoch)
elif FLAGS.lr_schedule == 'stepped':
lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, lr_schedule_steps)
elif FLAGS.lr_schedule == 'cosine':
lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, FLAGS.epochs)
else:
raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))
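The stepped branch above parses `FLAGS.lr_schedule_steps` (via `ast.literal_eval`) into a list of per-epoch boundaries. As a hedged sketch of the semantics only (the `stepped_lr` name and exact pair format here are assumptions; this is not the `flax.training.lr_schedule` implementation):

```python
def stepped_lr(base_lr, steps_per_epoch, schedule):
    """Returns lr_fn(step): base_lr scaled by the last boundary passed.

    `schedule` is an ascending list of (start_epoch, factor) pairs; this
    mirrors, but is not guaranteed identical to, the flax helper's input.
    """
    def lr_fn(step):
        lr = base_lr
        for start_epoch, factor in schedule:
            if step >= start_epoch * steps_per_epoch:
                lr = base_lr * factor
        return lr
    return lr_fn
```

For example, `stepped_lr(0.1, 100, [(1, 0.1), (2, 0.01)])` yields 0.1 during the first epoch, 0.01 during the second, and 0.001 afterwards.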
if jax.host_id() == 0:
trainer = training.Trainer(
optimizer,
model,
initial_state,
dataset,
rng,
summary_writer=summary_writer,
)
else:
trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)
_, best_metrics = trainer.train(
FLAGS.epochs,
lr_fn=lr_fn,
update_iter=FLAGS.update_iterations,
update_epoch=FLAGS.update_epoch,
)
logging.info('Best metrics: %s', str(best_metrics))
if jax.host_id() == 0:
if FLAGS.dump_json:
utils.dump_dict_json(best_metrics,
path.join(experiment_dir, 'best_metrics.json'))
for label, value in best_metrics.items():
summary_writer.scalar('best/{}'.format(label), value,
FLAGS.epochs * steps_per_epoch)
summary_writer.close()
def main(argv: List[str]):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
run_training()
if __name__ == '__main__':
app.run(main)
================================================
FILE: rigl/experimental/jax/fixed_param_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.fixed_param."""
import glob
from os import path
import tempfile
from absl.testing import absltest
from absl.testing import flagsaver
from rigl.experimental.jax import fixed_param
class FixedParamTest(absltest.TestCase):
def test_run(self):
"""Tests if the driver for shuffled training runs correctly."""
experiment_dir = tempfile.mkdtemp()
eval_flags = dict(
epochs=1,
experiment_dir=experiment_dir,
)
with flagsaver.flagsaver(**eval_flags):
fixed_param.main([])
with self.subTest(name='tf_summary_file_exists'):
outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
files = glob.glob(outfile)
self.assertTrue(len(files) == 1 and path.exists(files[0]))
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/models/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/models/cifar10_cnn.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CIFAR10 CNN.
A small CNN for the CIFAR10 dataset, consists of a number of convolutional
layers (determined by length of filters parameter), followed by a
fully-connected layer.
"""
from typing import Callable, Mapping, Optional, Sequence
from absl import logging
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
class CIFAR10CNN(flax.deprecated.nn.Module):
"""Small CIFAR10 CNN."""
def apply(self,
inputs,
num_classes,
filter_shape = (3, 3),
filters = (32, 32, 64, 64, 128, 128),
init_fn=flax.deprecated.nn.initializers.kaiming_normal,
train=True,
activation_fn = flax.deprecated.nn.relu,
masks = None,
masked_layer_indices = None):
"""Applies a convolution to the inputs.
Args:
inputs: Input data with dimensions (batch, spatial_dims..., features).
num_classes: Number of classes in the dataset.
filter_shape: Shape of the convolutional filters.
filters: Number of filters in each convolutional layer, and number of conv
layers (given by length of sequence).
init_fn: Initialization function used for convolutional layers.
train: If model is being evaluated in training mode or not.
activation_fn: Activation function to be used for convolutional layers.
masks: Masks of the layers in this model, in the same form as
module params, or None.
masked_layer_indices: The layer indices of layers in model to be masked.
Returns:
A tensor of shape (batch, num_classes), containing the logit output.
Raises:
ValueError if the number of pooling layers is too many for the given input
size, or if the provided mask is not of the correct depth for the model.
"""
# Note: First dim is batch, last dim is channels, other dims are "spatial".
if not all(dim >= 2**(len(filters) // 2) for dim in inputs.shape[1:-1]):
raise ValueError(
'Input spatial size, {}, does not allow {} pooling layers.'.format(
str(inputs.shape[1:-1]), len(filters) // 2)
)
depth = 1 + len(filters)
masks = masked.generate_model_masks(depth, masks,
masked_layer_indices)
batch_norm = flax.deprecated.nn.BatchNorm.partial(
use_running_average=not train, momentum=0.99, epsilon=1e-5)
for i, filter_num in enumerate(filters):
if f'MaskedModule_{i}' in masks:
logging.info('Layer %d is masked in model', i)
mask = masks[f'MaskedModule_{i}']
inputs = masked.masked(flax.deprecated.nn.Conv, mask)(
inputs,
features=filter_num,
kernel_size=filter_shape,
kernel_init=init.sparse_init(
init_fn(), mask['kernel'] if mask is not None else None))
else:
inputs = flax.deprecated.nn.Conv(
inputs,
features=filter_num,
kernel_size=filter_shape,
kernel_init=init_fn())
inputs = batch_norm(inputs, name='bn_conv_{}'.format(i))
inputs = activation_fn(inputs)
if i % 2 == 1:
inputs = flax.deprecated.nn.max_pool(
inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID')
# Global average pooling if we have spatial dimensions left.
inputs = flax.deprecated.nn.avg_pool(
inputs, window_shape=inputs.shape[1:-1], padding='VALID')
inputs = inputs.reshape((inputs.shape[0], -1))
# This is effectively a Dense layer, but we cast it as a convolution layer
# to allow us to easily propagate masks, avoiding b/156135283.
inputs = flax.deprecated.nn.Conv(
inputs,
features=num_classes,
kernel_size=inputs.shape[1:-1],
kernel_init=flax.deprecated.nn.initializers.xavier_normal())
inputs = batch_norm(inputs, name='bn_dense_1')
inputs = jnp.squeeze(inputs)
return flax.deprecated.nn.log_softmax(inputs)
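The spatial-size check at the top of `apply` above exists because a 2x2, stride-2 max-pool follows every second conv layer, each halving the spatial extent. A small illustrative helper (the `min_spatial_extent` name is hypothetical, not part of the module) computing the minimum admissible input size:

```python
def min_spatial_extent(num_conv_layers):
    """Minimum input height/width for CIFAR10CNN-style pooling.

    A 2x2/stride-2 max-pool is applied after every second conv layer, so
    `num_conv_layers // 2` pools run, each halving the spatial size.
    """
    return 2 ** (num_conv_layers // 2)

# The default filters=(32, 32, 64, 64, 128, 128) config pools 3 times,
# so 32x32 CIFAR-10 images comfortably satisfy the bound (32 >= 8).
```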
================================================
FILE: rigl/experimental/jax/models/cifar10_cnn_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.cifar10_cnn."""
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import cifar10_cnn
class CIFAR10CNNTest(absltest.TestCase):
"""Tests the CIFAR10CNN model."""
def setUp(self):
super().setUp()
self._rng = jax.random.PRNGKey(42)
self._num_classes = 10
self._batch_size = 2
self._input_shape = ((self._batch_size, 32, 32, 3), jnp.float32)
self._input = jnp.zeros(*self._input_shape)
def test_output_shapes(self):
"""Tests the output shapes of the model."""
with flax.deprecated.nn.stateful() as initial_state:
_, initial_params = cifar10_cnn.CIFAR10CNN.init_by_shape(
self._rng, (self._input_shape,), num_classes=self._num_classes)
model = flax.deprecated.nn.Model(cifar10_cnn.CIFAR10CNN, initial_params)
with flax.deprecated.nn.stateful(initial_state, mutable=False):
logits = model(self._input, num_classes=self._num_classes, train=False)
self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes))
def test_invalid_spatial_dimensions(self):
"""Tests model with an invalid spatial dimension parameters."""
with self.assertRaisesRegex(ValueError, 'Input spatial size, '):
cifar10_cnn.CIFAR10CNN.init_by_shape(
self._rng, (self._input_shape,),
num_classes=self._num_classes,
filters=20 * (32,))
def test_invalid_masks_depth(self):
"""Tests model mask with the incorrect depth for the given model."""
invalid_masks = {
'MaskedModule_0': {
'kernel':
jnp.zeros((self._batch_size, 3, 3, 32))
}
}
with self.assertRaisesRegex(
ValueError, 'Mask is invalid for model.'):
cifar10_cnn.CIFAR10CNN.init_by_shape(
self._rng, (self._input_shape,),
num_classes=self._num_classes,
masks=invalid_masks)
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/models/mnist_cnn.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST CNN.
A small CNN for the MNIST dataset, consists of a number of convolutional layers
(determined by length of filters parameter), followed by a fully-connected
layer.
"""
from typing import Callable, Mapping, Optional, Sequence
from absl import logging
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
class MNISTCNN(flax.deprecated.nn.Module):
"""Small MNIST CNN."""
def apply(self,
inputs,
num_classes,
filter_shape = (5, 5),
filters = (16, 32),
dense_size = 64,
train=True,
init_fn = flax.deprecated.nn.initializers.kaiming_normal,
activation_fn = flax.deprecated.nn.relu,
masks = None,
masked_layer_indices = None):
"""Applies a convolution to the inputs.
Args:
inputs: Input data with dimensions (batch, spatial_dims..., features).
num_classes: Number of classes in the dataset.
filter_shape: Shape of the convolutional filters.
filters: Number of filters in each convolutional layer, and number of conv
layers (given by length of sequence).
dense_size: Number of neurons in the fully-connected layer that follows
the convolutional layers.
train: If model is being evaluated in training mode or not.
init_fn: Initialization function used for convolutional layers.
activation_fn: Activation function to be used for convolutional layers.
masks: Masks of the layers in this model, in the same form as
module params, or None.
masked_layer_indices: The layer indices of layers in model to be masked.
Returns:
A tensor of shape (batch, num_classes), containing the logit output.
Raises:
ValueError if the number of pooling layers is too many for the given input
size.
"""
# Note: First dim is batch, last dim is channels, other dims are "spatial".
if not all(dim >= 2**(len(filters) - 1) for dim in inputs.shape[1:-1]):
raise ValueError(
'Input spatial size, {}, does not allow {} pooling layers.'.format(
str(inputs.shape[1:-1]), len(filters) - 1)
)
depth = 2 + len(filters)
masks = masked.generate_model_masks(depth, masks,
masked_layer_indices)
batch_norm = flax.deprecated.nn.BatchNorm.partial(
use_running_average=not train, momentum=0.99, epsilon=1e-5)
for i, filter_num in enumerate(filters):
if f'MaskedModule_{i}' in masks:
logging.info('Layer %d is masked in model', i)
mask = masks[f'MaskedModule_{i}']
inputs = masked.masked(flax.deprecated.nn.Conv, mask)(
inputs,
features=filter_num,
kernel_size=filter_shape,
kernel_init=init.sparse_init(
init_fn(), mask['kernel'] if mask is not None else None))
else:
inputs = flax.deprecated.nn.Conv(
inputs,
features=filter_num,
kernel_size=filter_shape,
kernel_init=init_fn())
inputs = batch_norm(inputs, name='bn_conv_{}'.format(i))
inputs = activation_fn(inputs)
if i < len(filters) - 1:
inputs = flax.deprecated.nn.max_pool(
inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID')
# Global average pool at end of convolutional layers.
inputs = flax.deprecated.nn.avg_pool(
inputs, window_shape=inputs.shape[1:-1], padding='VALID')
# This is effectively a Dense layer, but we cast it as a convolution layer
# to allow us to easily propagate masks, avoiding b/156135283.
if f'MaskedModule_{depth - 2}' in masks:
mask_dense_1 = masks[f'MaskedModule_{depth - 2}']
inputs = masked.masked(flax.deprecated.nn.Conv, mask_dense_1)(
inputs,
features=dense_size,
kernel_size=inputs.shape[1:-1],
kernel_init=init.sparse_init(
init_fn(),
mask_dense_1['kernel'] if mask_dense_1 is not None else None))
else:
inputs = flax.deprecated.nn.Conv(
inputs,
features=dense_size,
kernel_size=inputs.shape[1:-1],
kernel_init=init_fn())
inputs = batch_norm(inputs, name='bn_dense_1')
inputs = activation_fn(inputs)
inputs = flax.deprecated.nn.Dense(
inputs,
features=num_classes,
kernel_init=flax.deprecated.nn.initializers.xavier_normal())
inputs = batch_norm(inputs, name='bn_dense_2')
inputs = jnp.squeeze(inputs)
return flax.deprecated.nn.log_softmax(inputs)
================================================
FILE: rigl/experimental/jax/models/mnist_cnn_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.mnist_cnn."""
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import mnist_cnn
class MNISTCNNTest(absltest.TestCase):
"""Tests the MNISTCNN model."""
def setUp(self):
super().setUp()
self._rng = jax.random.PRNGKey(42)
self._num_classes = 10
self._batch_size = 2
self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
self._input = jnp.zeros(*self._input_shape)
def test_output_shapes(self):
"""Tests the output shapes of the model."""
with flax.deprecated.nn.stateful() as initial_state:
_, initial_params = mnist_cnn.MNISTCNN.init_by_shape(
self._rng, (self._input_shape,), num_classes=self._num_classes)
model = flax.deprecated.nn.Model(mnist_cnn.MNISTCNN, initial_params)
with flax.deprecated.nn.stateful(initial_state, mutable=False):
logits = model(self._input, num_classes=self._num_classes, train=False)
self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes))
def test_invalid_depth(self):
"""Tests model mask with the incorrect depth for the given model."""
with self.assertRaisesRegex(ValueError, 'Input spatial size, '):
mnist_cnn.MNISTCNN.init_by_shape(
self._rng, (self._input_shape,),
num_classes=self._num_classes,
filters=10 * (32,))
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/models/mnist_fc.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST Fully-Connected Neural Network.
A fully-connected model for the MNIST dataset, consists of a number of
dense layers (determined by length of features parameter).
"""
import math
from typing import Callable, Mapping, Optional, Sequence, Tuple
from absl import logging
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
def feature_dim_for_param(input_len,
param_count,
depth,
depth_mult = 2.):
"""Calculates feature dimensions for a fixed parameter count and depth.
This is calculated for the specific case of a fully-connected neural
network, where each layer consists of l * a**i neurons, where a is a
multiplier for each layer.
Assume,
x is the input size,
a is the depth multiplier,
l is the initial layer width,
d is the depth.
The total number of parameters, n, is then given by
$$n = x l + l^2 \sum_{i=2}^{d} a^{2i-3}$$.
Args:
input_len: Input size.
param_count: Number of parameters model should maintain.
depth: Depth of the model.
depth_mult: The layer width multiplier w.r.t. depth.
Returns:
The feature specification for a fully-connected model, as a tuple of layer
widths.
Raises:
ValueError: If the given number of parameters is too low for the given
depth and input size.
"""
# Calculate the initial width for the first layer.
if depth == 1:
initial_width = param_count / input_len
else:
# l = ((x^2 + 4cn)^{1/2} - x)/(2c) where c = sum_{i=2}^d a^{2i-3}.
depth_sum = sum(depth_mult**(2 * i - 3) for i in range(2, depth + 1))
initial_width = (math.sqrt(input_len**2 + 4 * depth_sum * param_count) -
input_len) / (2 * depth_sum)
if initial_width < 1:
raise ValueError(
'Expected parameter count too low for given depth and input size.')
return tuple(int(int(initial_width) * depth_mult**i) for i in range(depth))
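The closed form above can be sanity-checked numerically: for the returned widths, the kernel parameter count x*w_0 + sum_i w_{i-1}*w_i should land near the target (biases and the final classifier layer excluded). A hedged standalone sketch that mirrors the function rather than importing it (`sketch_feature_dims` is a hypothetical name):

```python
import math

def sketch_feature_dims(input_len, param_count, depth, depth_mult=2.0):
    """Re-derives widths from n = x*l + l^2 * sum_{i=2}^{d} a^{2i-3}."""
    if depth == 1:
        initial_width = param_count / input_len
    else:
        # c = sum_{i=2}^{d} a^{2i-3}, so l = (sqrt(x^2 + 4cn) - x) / (2c).
        c = sum(depth_mult ** (2 * i - 3) for i in range(2, depth + 1))
        initial_width = (math.sqrt(input_len ** 2 + 4 * c * param_count)
                         - input_len) / (2 * c)
    return tuple(int(int(initial_width) * depth_mult ** i) for i in range(depth))

widths = sketch_feature_dims(784, 100_000, 2)  # (101, 202)
approx_params = 784 * widths[0] + widths[0] * widths[1]  # within ~5% of 100k
```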
class MNISTFC(flax.deprecated.nn.Module):
"""MNIST Fully-Connected Neural Network."""
def apply(self,
inputs,
num_classes,
features = (32, 32),
train=True,
init_fn = flax.deprecated.nn.initializers.kaiming_normal,
activation_fn = flax.deprecated.nn.relu,
masks = None,
masked_layer_indices = None,
dropout_rate = 0.):
"""Applies fully-connected neural network to the inputs.
Args:
inputs: Input data with dimensions (batch, features), if features has more
than one dimension, it is flattened.
num_classes: Number of classes in the dataset.
features: Number of neurons in each hidden layer; the sequence length gives
the number of hidden layers, plus one final layer for the softmax.
train: If model is being evaluated in training mode or not.
init_fn: Initialization function used for dense layers.
activation_fn: Activation function to be used for dense layers.
masks: Masks of the layers in this model, in the same form as module
params, or None.
masked_layer_indices: The layer indices of layers in model to be masked.
dropout_rate: Dropout rate, if 0 then dropout is not used (default).
Returns:
A tensor of shape (batch, num_classes), containing the logit output.
"""
batch_norm = flax.deprecated.nn.BatchNorm.partial(
use_running_average=not train, momentum=0.99, epsilon=1e-5)
depth = 1 + len(features)
masks = masked.generate_model_masks(depth, masks,
masked_layer_indices)
# If inputs are in image dimensions, flatten image.
inputs = inputs.reshape(inputs.shape[0], -1)
for i, feature_num in enumerate(features):
if f'MaskedModule_{i}' in masks:
logging.info('Layer %d is masked in model', i)
mask = masks[f'MaskedModule_{i}']
inputs = masked.masked(flax.deprecated.nn.Dense, mask)(
inputs,
features=feature_num,
kernel_init=init.sparse_init(
init_fn(), mask['kernel'] if mask is not None else None))
else:
inputs = flax.deprecated.nn.Dense(
inputs, features=feature_num, kernel_init=init_fn())
inputs = batch_norm(inputs, name=f'bn_conv_{i}')
inputs = activation_fn(inputs)
if dropout_rate > 0.0:
inputs = flax.deprecated.nn.dropout(
inputs, dropout_rate, deterministic=not train)
inputs = flax.deprecated.nn.Dense(
inputs,
features=num_classes,
kernel_init=flax.deprecated.nn.initializers.xavier_normal())
return flax.deprecated.nn.log_softmax(inputs)
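The `masks` argument used throughout these models is keyed by `'MaskedModule_{i}'`, with per-parameter arrays shaped like the corresponding layer kernels. A minimal illustrative mask (the 784x32 shape assumes the default 784-pixel input and `features=(32, 32)` of MNISTFC; this dict is hand-built for illustration, not produced by the library):

```python
import numpy as np

# An all-ones mask leaves the first Dense layer fully dense; zero entries
# would prune the corresponding weights. Keys follow the 'MaskedModule_{i}'
# naming used in apply() above.
example_masks = {
    'MaskedModule_0': {
        'kernel': np.ones((784, 32), dtype=np.float32),
    },
}
```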
================================================
FILE: rigl/experimental/jax/models/mnist_fc_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.mnist_fc."""
from typing import Sequence
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import mnist_fc
from rigl.experimental.jax.utils import utils
PARAM_COUNT_PARAM: Sequence[str] = ('kernel',)
class MNISTFCTest(parameterized.TestCase):
"""Tests the MNISTFC model."""
def setUp(self):
super().setUp()
self._rng = jax.random.PRNGKey(42)
self._num_classes = 10
self._batch_size = 2
self._input_len = 28*28*1
self._input_shape = ((self._batch_size, self._input_len), jnp.float32)
self._input = jnp.zeros((self._batch_size, self._input_len), jnp.float32)
self._param_count = 1e7
def test_output_shapes(self):
"""Tests the output shape from the model."""
with flax.deprecated.nn.stateful() as initial_state:
_, initial_params = mnist_fc.MNISTFC.init_by_shape(
self._rng, (self._input_shape,), num_classes=self._num_classes)
model = flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params)
with flax.deprecated.nn.stateful(initial_state, mutable=False):
logits = model(self._input, num_classes=self._num_classes, train=False)
self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes))
def test_invalid_masks_depth(self):
"""Tests a model with an invalid mask."""
invalid_masks = {
'MaskedModule_0': {
'kernel':
jnp.zeros((self._batch_size, 5 * 5 * 16))
}
}
with self.assertRaisesRegex(
ValueError, 'Mask is invalid for model.'):
mnist_fc.MNISTFC.init_by_shape(
self._rng,
(self._input_shape,),
num_classes=self._num_classes,
masks=invalid_masks)
def _create_model(self, features):
"""Convenience fn to create a FLAX model ."""
_, initial_params = mnist_fc.MNISTFC.init_by_shape(
self._rng,
(self._input_shape,),
num_classes=self._num_classes,
features=features)
return flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params)
@parameterized.parameters(*range(1, 6))
def test_feature_dim_for_param_depth(self, depth):
"""Tests feature_dim_for_param with multiple depths."""
features = mnist_fc.feature_dim_for_param(self._input_len,
self._param_count, depth)
model = self._create_model(features)
total_size = utils.count_param(model, PARAM_COUNT_PARAM)
with self.subTest(name='FeatureDimLen'):
self.assertLen(features, depth)
with self.subTest(name='FeatureDimParamCount'):
self.assertBetween(total_size, self._param_count * 0.95,
self._param_count * 1.05)
if __name__ == '__main__':
absltest.main()
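The parameter-count check above accepts any feature widths whose kernel parameter count lands within 5% of the target. For a stack of dense layers, the kernel count is just the sum of consecutive width products; the sketch below illustrates that arithmetic (it ignores biases and assumes, hypothetically, a final classifier layer of width `num_classes`, which may differ from the actual `MNISTFC` architecture):

```python
def kernel_param_count(input_len, features, num_classes):
  # Kernels only (no biases): each dense layer contributes in_dim * out_dim.
  widths = [input_len] + list(features) + [num_classes]
  return sum(a * b for a, b in zip(widths[:-1], widths[1:]))


# 784*512 + 512*256 + 256*10 = 535040 kernel parameters.
count = kernel_param_count(784, [512, 256], 10)
```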
================================================
FILE: rigl/experimental/jax/models/model_factory.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for neural network models.
Attributes:
  MODELS: A mapping of model names to the model classes that can be created.
"""
from typing import Any, Callable, Mapping, Sequence, Tuple, Type
import flax
import jax.numpy as jnp
from rigl.experimental.jax.models import cifar10_cnn
from rigl.experimental.jax.models import mnist_cnn
from rigl.experimental.jax.models import mnist_fc
MODELS: Mapping[str, Type[flax.deprecated.nn.Model]] = {
'MNIST_CNN': mnist_cnn.MNISTCNN,
'MNIST_FC': mnist_fc.MNISTFC,
'CIFAR10_CNN': cifar10_cnn.CIFAR10CNN,
}
def create_model(
name, rng,
input_specs, **kwargs
):
"""Creates a Model.
Args:
name: the name of the model to instantiate.
    rng: the random number generator to use for init.
    input_specs: an iterable of (shape, dtype) pairs specifying the inputs.
    **kwargs: model-specific keyword arguments.
Returns:
A tuple of FLAX model (flax.deprecated.nn.Model), and initial model state.
Raises:
    ValueError: If a model with the given name does not exist.
"""
if name not in MODELS:
raise ValueError('No such model: {}'.format(name))
with flax.deprecated.nn.stateful() as init_state:
with flax.deprecated.nn.stochastic(rng):
model_class = MODELS[name].partial(**kwargs)
_, params = model_class.init_by_shape(rng, input_specs)
return flax.deprecated.nn.Model(model_class, params), init_state
def update_model(model,
**kwargs):
"""Updates a model to use different model arguments, but same parameters.
Args:
model: The model to update.
    **kwargs: Model-specific keyword arguments.
Returns:
A FLAX model.
"""
return flax.deprecated.nn.Model(model.module.partial(**kwargs), model.params)
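The registry-and-dispatch pattern used by `create_model` can be sketched without FLAX. The stub class and `create_model_sketch` below are hypothetical stand-ins that mirror only the lookup and error handling, not the FLAX init machinery:

```python
from typing import Any, Mapping, Type


class ModelStub:
  """Hypothetical stand-in for a FLAX model class."""

  def __init__(self, **kwargs: Any):
    self.kwargs = kwargs


REGISTRY: Mapping[str, Type[ModelStub]] = {'MNIST_FC': ModelStub}


def create_model_sketch(name: str, **kwargs: Any) -> ModelStub:
  # Mirrors create_model: look up the class by name, fail loudly otherwise.
  if name not in REGISTRY:
    raise ValueError('No such model: {}'.format(name))
  return REGISTRY[name](**kwargs)


model = create_model_sketch('MNIST_FC', num_classes=10)
```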
================================================
FILE: rigl/experimental/jax/models/model_factory_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.model_factory."""
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.models import model_factory
class ModelCommonTest(parameterized.TestCase):
"""Tests the model factory."""
def setUp(self):
super().setUp()
self._rng = jax.random.PRNGKey(42)
self._input_shape = ((1, 28, 28, 1), jnp.float32)
self._num_classes = 10
def _create_model(self, model_name):
return model_factory.create_model(
model_name,
self._rng, (self._input_shape,),
num_classes=self._num_classes)
@parameterized.parameters(*model_factory.MODELS.keys())
def test_model_supported(self, model_name):
"""Tests supported models."""
model, state = self._create_model(model_name)
with self.subTest(name='test_model_supported_model_instance'):
self.assertIsInstance(model, flax.deprecated.nn.Model)
with self.subTest(name='test_model_supported_collection_instance'):
self.assertIsInstance(state, flax.deprecated.nn.Collection)
def test_model_unsupported(self):
"""Tests unsupported models."""
with self.assertRaisesRegex(ValueError, 'No such model: unsupported'):
self._create_model('unsupported')
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/prune.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Weight Symmetry: Iteratively Prune Model during Training.
Command for training and pruning an MNIST fully-connected model for 10 epochs
with a fixed pruning rate of 0.95:
prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10
--pruning_rate=0.95
Command for training and pruning an MNIST fully-connected model for 10
epochs, with pruning rates 0.3, 0.6 and 0.95 at epochs 2, 5, and 8 respectively
for all layers:
prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10
--pruning_schedule='[(2, 0.3), (5, 0.6), (8, 0.95)]'
Command for doing the same, but performing pruning only on the second layer:
prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10
--pruning_schedule="{'1': [(2, 0.3), (5, 0.6), (8, 0.95)]}"
"""
import ast
from collections import abc
import functools
from os import path
from typing import List
import uuid
from absl import app
from absl import flags
from absl import logging
import flax
from flax.metrics import tensorboard
from flax.training import lr_schedule
import jax
import jax.numpy as jnp
from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.training import training
from rigl.experimental.jax.utils import utils
experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))
logging.info('Saving experimental results to %s', experiment_dir)
host_count = jax.host_count()
local_device_count = jax.local_device_count()
logging.info('Device count: %d, host count: %d, local device count: %d',
jax.device_count(), host_count, local_device_count)
if jax.host_id() == 0:
summary_writer = tensorboard.SummaryWriter(experiment_dir)
dataset = dataset_factory.create_dataset(
FLAGS.dataset,
FLAGS.batch_size,
FLAGS.batch_size_test,
shuffle_buffer_size=FLAGS.shuffle_buffer_size)
logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)
rng = jax.random.PRNGKey(FLAGS.random_seed)
input_shape = (1,) + dataset.shape
base_model, _ = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes)
initial_model, initial_state = model_factory.create_model(
FLAGS.model,
rng, ((input_shape, jnp.float32),),
num_classes=dataset.num_classes,
masked_layer_indices=FLAGS.masked_layer_indices)
if FLAGS.optimizer == 'Adam':
optimizer = flax.optim.Adam(
learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)
elif FLAGS.optimizer == 'Momentum':
optimizer = flax.optim.Momentum(
learning_rate=FLAGS.lr,
beta=FLAGS.momentum,
weight_decay=FLAGS.weight_decay,
nesterov=False)
steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size
if FLAGS.lr_schedule == LR_SCHEDULE_CONSTANT:
lr_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.lr, steps_per_epoch)
elif FLAGS.lr_schedule == LR_SCHEDULE_STEPPED:
lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)
lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, lr_schedule_steps)
elif FLAGS.lr_schedule == LR_SCHEDULE_COSINE:
lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
FLAGS.lr, steps_per_epoch, FLAGS.epochs)
else:
raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}')
# Reuses the FLAX learning rate schedule framework for pruning rate schedule.
pruning_fn_p = functools.partial(
lr_schedule.create_stepped_learning_rate_schedule, FLAGS.pruning_rate,
steps_per_epoch)
if FLAGS.pruning_schedule:
pruning_schedule = ast.literal_eval(FLAGS.pruning_schedule)
if isinstance(pruning_schedule, abc.Mapping):
pruning_rate_fn = {
f'MaskedModule_{layer_num}': pruning_fn_p(schedule)
for layer_num, schedule in pruning_schedule.items()
}
else:
pruning_rate_fn = pruning_fn_p(pruning_schedule)
else:
pruning_rate_fn = lr_schedule.create_constant_learning_rate_schedule(
FLAGS.pruning_rate, steps_per_epoch)
if jax.host_id() == 0:
trainer = training.Trainer(
optimizer,
initial_model,
initial_state,
dataset,
rng,
summary_writer=summary_writer,
)
else:
trainer = training.Trainer(
optimizer, initial_model, initial_state, dataset, rng)
_, best_metrics = trainer.train(
FLAGS.epochs,
lr_fn=lr_fn,
pruning_rate_fn=pruning_rate_fn,
update_iter=FLAGS.update_iterations,
update_epoch=FLAGS.update_epoch,
)
logging.info('Best metrics: %s', str(best_metrics))
if jax.host_id() == 0:
if FLAGS.dump_json:
utils.dump_dict_json(best_metrics,
path.join(experiment_dir, 'best_metrics.json'))
for label, value in best_metrics.items():
summary_writer.scalar(f'best/{label}', value,
FLAGS.epochs * steps_per_epoch)
summary_writer.close()
def main(argv: List[str]):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
run_training()
if __name__ == '__main__':
app.run(main)
================================================
FILE: rigl/experimental/jax/prune_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.prune."""
import glob
from os import path
from absl.testing import absltest
from absl.testing import flagsaver
from rigl.experimental.jax import prune
class PruneTest(absltest.TestCase):
def test_prune_fixed_schedule(self):
"""Tests training/pruning driver with a fixed global sparsity."""
experiment_dir = self.create_tempdir().full_path
eval_flags = dict(
epochs=1,
pruning_rate=0.95,
experiment_dir=experiment_dir,
)
with flagsaver.flagsaver(**eval_flags):
prune.main([])
outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
files = glob.glob(outfile)
self.assertTrue(len(files) == 1 and path.exists(files[0]))
def test_prune_global_pruning_schedule(self):
"""Tests training/pruning driver with a global sparsity schedule."""
experiment_dir = self.create_tempdir().full_path
eval_flags = dict(
epochs=10,
pruning_schedule='[(5, 0.33), (7, 0.66), (9, 0.95)]',
experiment_dir=experiment_dir,
)
with flagsaver.flagsaver(**eval_flags):
prune.main([])
outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
files = glob.glob(outfile)
self.assertTrue(len(files) == 1 and path.exists(files[0]))
def test_prune_local_pruning_schedule(self):
"""Tests training/pruning driver with a single layer sparsity schedule."""
experiment_dir = self.create_tempdir().full_path
eval_flags = dict(
epochs=10,
pruning_schedule='{1:[(5, 0.33), (7, 0.66), (9, 0.95)]}',
experiment_dir=experiment_dir,
)
with flagsaver.flagsaver(**eval_flags):
prune.main([])
outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')
files = glob.glob(outfile)
self.assertTrue(len(files) == 1 and path.exists(files[0]))
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/__init__.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: rigl/experimental/jax/pruning/init.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for initialization of masked models."""
import functools
from typing import Callable, Sequence, Optional
import flax
import jax
import jax.numpy as jnp
def sparse_init(
base_init,
mask,
dtype=jnp.float32):
"""Weight initializer with correct fan in/fan out for a masked model.
The weight initializer uses any dense initializer to correctly initialize a
masked weight matrix by calling the given initialization method with the
correct fan in/fan out for every neuron in the layer. If the mask is None, it
reverts to the original initialization method.
Args:
base_init: The base (dense) initialization method to use.
mask: The layer's mask, or None.
dtype: The weight array jnp.dtype.
Returns:
An initialization method that is mask aware for the given layer and mask.
"""
def init(rng, shape, dtype=dtype):
if mask is None:
return base_init(rng, shape, dtype)
# Find the ablated neurons in the mask, to determine correct fan_out.
neuron_weight_count = jnp.sum(
jnp.reshape(mask, (-1, mask.shape[-1])), axis=0)
non_zero_neurons = jnp.sum(neuron_weight_count != 0)
# Special case of completely ablated weight matrix/layer.
    if non_zero_neurons == 0:
print('Empty weight mask!')
return jnp.zeros(shape, dtype)
# Neurons have different fan_in w/mask, build up initialization per-unit.
init_cols = []
rng, *split_rngs = jax.random.split(rng, mask.shape[-1] + 1)
for i in range(mask.shape[-1]):
# Special case of ablated neuron.
if neuron_weight_count[i] == 0:
init_cols.append(jnp.zeros(shape[:-1] + (1,), dtype))
continue
# Fake shape of weight matrix with correct fan_in, and fan_out.
sparse_shape = (int(neuron_weight_count[i]), int(non_zero_neurons))
      # Use only the first column from the initializer, since fan_out is faked.
      col_init = base_init(split_rngs[i], sparse_shape, dtype)[Ellipsis, 0]
      # Expand out to full sparse array.
      expanded_init = jnp.zeros(
          mask[Ellipsis, i].shape,
          dtype).flatten().at[jnp.where(mask[Ellipsis, i].flatten() == 1)].set(col_init)
expanded_init = jnp.reshape(expanded_init, mask[Ellipsis, i].shape)
init_cols.append(expanded_init[Ellipsis, jnp.newaxis])
return jnp.concatenate(init_cols, axis=-1)
return init
xavier_sparse_normal = glorot_sparse_normal = functools.partial(
sparse_init, flax.deprecated.nn.initializers.xavier_normal())
kaiming_sparse_normal = he_sparse_normal = functools.partial(
sparse_init, flax.deprecated.nn.initializers.kaiming_normal())
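The per-neuron fan-in bookkeeping that `sparse_init` performs can be checked in isolation with plain NumPy. This sketch reproduces only the first few lines of `init` (the fan-in/fan-out counts), not the full initializer:

```python
import numpy as np

# A 4x3 mask over a dense kernel: column (output neuron) 2 is fully ablated.
mask = np.array([[1, 0, 0],
                 [1, 1, 0],
                 [0, 1, 0],
                 [1, 0, 0]])

# Per-neuron fan-in: number of surviving weights feeding each output unit.
neuron_weight_count = np.sum(mask.reshape(-1, mask.shape[-1]), axis=0)
# Effective fan-out: neurons with at least one surviving weight.
non_zero_neurons = int(np.sum(neuron_weight_count != 0))

print(neuron_weight_count.tolist(), non_zero_neurons)  # [3, 2, 0] 2
```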
================================================
FILE: rigl/experimental/jax/pruning/init_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.pruning.init."""
from typing import Any, Mapping, Optional
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked
class MaskedDense(flax.deprecated.nn.Module):
"""Single-layer Dense Masked Network."""
NUM_FEATURES: int = 32
def apply(self,
inputs,
mask = None):
inputs = inputs.reshape(inputs.shape[0], -1)
layer_mask = mask['MaskedModule_0'] if mask else None
return masked.MaskedModule(
inputs,
features=self.NUM_FEATURES,
wrapped_module=flax.deprecated.nn.Dense,
mask=layer_mask,
kernel_init=flax.deprecated.nn.initializers.kaiming_normal())
class MaskedDenseSparseInit(flax.deprecated.nn.Module):
"""Single-layer Dense Masked Network."""
NUM_FEATURES: int = 32
def apply(self,
inputs,
*args,
mask = None,
**kwargs):
inputs = inputs.reshape(inputs.shape[0], -1)
layer_mask = mask['MaskedModule_0'] if mask else None
return masked.MaskedModule(
inputs,
features=self.NUM_FEATURES,
wrapped_module=flax.deprecated.nn.Dense,
mask=layer_mask,
kernel_init=init.kaiming_sparse_normal(
layer_mask['kernel'] if layer_mask is not None else None),
**kwargs)
class MaskedCNN(flax.deprecated.nn.Module):
"""Single-layer CNN Masked Network."""
NUM_FEATURES: int = 32
def apply(self,
inputs,
mask = None):
layer_mask = mask['MaskedModule_0'] if mask else None
return masked.MaskedModule(
inputs,
features=self.NUM_FEATURES,
wrapped_module=flax.deprecated.nn.Conv,
kernel_size=(3, 3),
mask=layer_mask,
kernel_init=flax.deprecated.nn.initializers.kaiming_normal())
class MaskedCNNSparseInit(flax.deprecated.nn.Module):
"""Single-layer CNN Masked Network."""
NUM_FEATURES: int = 32
def apply(self,
inputs,
*args,
mask = None,
**kwargs):
layer_mask = mask['MaskedModule_0'] if mask else None
return masked.MaskedModule(
inputs,
features=self.NUM_FEATURES,
wrapped_module=flax.deprecated.nn.Conv,
kernel_size=(3, 3),
mask=layer_mask,
kernel_init=init.kaiming_sparse_normal(
layer_mask['kernel'] if layer_mask is not None else None),
**kwargs)
class InitTest(absltest.TestCase):
def setUp(self):
super().setUp()
self._rng = jax.random.PRNGKey(42)
self._batch_size = 2
self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
self._input = jnp.ones(*self._input_shape)
def test_init_kaiming_sparse_normal_output(self):
"""Tests the output shape/type of kaiming normal sparse initialization."""
input_array = jnp.ones((64, 16), jnp.float32)
mask = jax.random.bernoulli(self._rng, shape=(64, 16))
base_init = flax.deprecated.nn.initializers.kaiming_normal()(
self._rng, input_array.shape, input_array.dtype)
sparse_init = init.kaiming_sparse_normal(mask)(self._rng, input_array.shape,
input_array.dtype)
with self.subTest(name='test_sparse_init_output_shape'):
self.assertSequenceEqual(sparse_init.shape, base_init.shape)
with self.subTest(name='test_sparse_init_output_dtype'):
self.assertEqual(sparse_init.dtype, base_init.dtype)
with self.subTest(name='test_sparse_init_output_notallzero'):
self.assertTrue((sparse_init != 0).any())
def test_dense_no_mask(self):
"""Checks that in the special case of no mask, init is same as base_init."""
_, initial_params = MaskedDense.init_by_shape(self._rng,
(self._input_shape,))
self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)
_, initial_params = MaskedDenseSparseInit.init_by_shape(
jax.random.PRNGKey(42), (self._input_shape,), mask=None)
self._masked_model_sparse_init = flax.deprecated.nn.Model(
MaskedDenseSparseInit, initial_params)
self.assertTrue(
jnp.isclose(
self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
['kernel'], self._unmasked_model.params['MaskedModule_0']
['unmasked']['kernel']).all())
def test_dense_sparse_init_kaiming(self):
"""Checks kaiming normal sparse initialization for dense layer."""
_, initial_params = MaskedDense.init_by_shape(self._rng,
(self._input_shape,))
self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)
mask = masked.simple_mask(self._unmasked_model, jnp.ones,
masked.WEIGHT_PARAM_NAMES)
_, initial_params = MaskedDenseSparseInit.init_by_shape(
jax.random.PRNGKey(42), (self._input_shape,), mask=mask)
self._masked_model_sparse_init = flax.deprecated.nn.Model(
MaskedDenseSparseInit, initial_params)
mean_init = jnp.mean(
self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])
stddev_init = jnp.std(
self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])
mean_sparse_init = jnp.mean(
self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
['kernel'])
stddev_sparse_init = jnp.std(
self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
['kernel'])
with self.subTest(name='test_cnn_sparse_init_mean'):
self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,
mean_init + 2 * stddev_init)
with self.subTest(name='test_cnn_sparse_init_stddev'):
self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,
1.5 * stddev_init)
def test_cnn_sparse_init_kaiming(self):
"""Checks kaiming normal sparse initialization for convolutional layer."""
_, initial_params = MaskedCNN.init_by_shape(self._rng, (self._input_shape,))
self._unmasked_model = flax.deprecated.nn.Model(MaskedCNN, initial_params)
mask = masked.simple_mask(self._unmasked_model, jnp.ones,
masked.WEIGHT_PARAM_NAMES)
_, initial_params = MaskedCNNSparseInit.init_by_shape(
jax.random.PRNGKey(42), (self._input_shape,), mask=mask)
self._masked_model_sparse_init = flax.deprecated.nn.Model(
MaskedCNNSparseInit, initial_params)
mean_init = jnp.mean(
self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])
stddev_init = jnp.std(
self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])
mean_sparse_init = jnp.mean(
self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
['kernel'])
stddev_sparse_init = jnp.std(
self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']
['kernel'])
with self.subTest(name='test_cnn_sparse_init_mean'):
self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,
mean_init + 2 * stddev_init)
with self.subTest(name='test_cnn_sparse_init_stddev'):
self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,
1.5 * stddev_init)
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/mask_factory.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pruning mask factory.
Attributes:
MaskFnType: A type alias for functions to create sparse masks.
  MASK_TYPES: Mask types that can be created.
"""
from typing import Any, Callable, Mapping
import flax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import masked
# A function to create a mask, takes as arguments: a flax model, JAX PRNG Key,
# sparsity level as a float in [0, 1].
MaskFnType = Callable[
[flax.deprecated.nn.Model, Callable[[int],
jnp.array], float], masked.MaskType]
MASK_TYPES: Mapping[str, MaskFnType] = {
'random':
masked.shuffled_mask,
'per_neuron':
masked.shuffled_neuron_mask,
'per_neuron_no_input_ablation':
masked.shuffled_neuron_no_input_ablation_mask,
'symmetric':
masked.symmetric_mask,
}
def create_mask(mask_type, base_model,
rng, sparsity,
**kwargs):
"""Creates a Mask of the given type.
Args:
mask_type: the name of the type of mask to instantiate.
base_model: the model to create a mask for.
    rng: the random number generator to use for init.
    sparsity: the mask sparsity.
    **kwargs: mask-specific keyword arguments.
Returns:
A mask for a FLAX model.
Raises:
    ValueError: If a mask type with the given name does not exist.
"""
if mask_type not in MASK_TYPES:
raise ValueError(f'Unknown mask type: {mask_type}')
return MASK_TYPES[mask_type](base_model, rng, sparsity, **kwargs)
================================================
FILE: rigl/experimental/jax/pruning/mask_factory_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for weight_symmetry.models.model_factory."""
from typing import Mapping, Optional
from absl.testing import absltest
from absl.testing import parameterized
import flax
import jax
import jax.numpy as jnp
from rigl.experimental.jax.pruning import mask_factory
from rigl.experimental.jax.pruning import masked
class MaskedDense(flax.deprecated.nn.Module):
"""Single-layer Dense Masked Network."""
NUM_FEATURES: int = 32
def apply(self,
inputs,
mask = None):
inputs = inputs.reshape(inputs.shape[0], -1)
return masked.MaskedModule(
inputs,
features=self.NUM_FEATURES,
wrapped_module=flax.deprecated.nn.Dense,
mask=mask['MaskedModule_0'] if mask else None)
class MaskFactoryTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self._rng = jax.random.PRNGKey(42)
self._input_shape = ((1, 28, 28, 1), jnp.float32)
self._num_classes = 10
self._sparsity = 0.9
_, initial_params = MaskedDense.init_by_shape(self._rng,
(self._input_shape,))
# Use the same initialization for both masked/unmasked models.
self._model = flax.deprecated.nn.Model(MaskedDense, initial_params)
def _create_mask(self, mask_type):
return mask_factory.create_mask(
mask_type, self._model,
self._rng, self._sparsity)
@parameterized.parameters(*mask_factory.MASK_TYPES.keys())
def test_mask_supported(self, mask_type):
"""Tests supported mask types."""
mask = self._create_mask(mask_type)
with self.subTest(name='test_mask_type'):
self.assertIsInstance(mask, dict)
def test_mask_unsupported(self):
"""Tests unsupported mask types."""
with self.assertRaisesRegex(ValueError,
'Unknown mask type: unsupported'):
self._create_mask('unsupported')
if __name__ == '__main__':
absltest.main()
================================================
FILE: rigl/experimental/jax/pruning/masked.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Masked wrapped for FLAX modules.
Attributes:
WEIGHT_PARAM_NAMES: The name of the weight parameters to use.
MaskType: Model mask type for static type checking.
MaskLayerType: Mask layer type for static type checking.
MutableMaskType: Mutable model mask type for static type checking.
MutableMaskLayerType: Mutable mask layer type for static type checking.
"""
import functools
import operator
from typing import Any, Callable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple, Type
from absl import logging
import flax
import jax
import jax.numpy as jnp
import jax.ops
# Model weight param names, e.g. 'kernel' (as opposed to batch norm params).
WEIGHT_PARAM_NAMES = ('kernel',) # Note: Bias is not typically masked.
# Mask layer type for static type checking.
MaskLayerType = Mapping[str, Optional[jnp.array]]
# Model mask type for static type checking.
MaskType = Mapping[str, Optional[MaskLayerType]]
# Mask layer type for static type checking.
MutableMaskLayerType = MutableMapping[str, Optional[jnp.array]]
# Model mask type for static type checking.
MutableMaskType = MutableMapping[str, MutableMaskLayerType]
class MaskedModule(flax.deprecated.nn.Module):
"""Generic FLAX Masking Module.
Masks a FLAX module, given a mask for params of each layer.
Attributes:
UNMASKED: The key to use for the unmasked parameter dictionary.
"""
UNMASKED = 'unmasked'
def apply(self,
*args,
wrapped_module,
mask = None,
**kwargs):
"""Apply the wrapped module, while applying the given masks to its params.
Args:
*args: The positional arguments for the wrapped module.
wrapped_module: The module class to be wrapped.
mask: The mask nested dictionary containing masks for the wrapped module's
params, in the same format/with the same keys as the module param dict
(or None if not to mask).
**kwargs: The keyword arguments for the wrapped module.
Returns:
      The output of the wrapped module, called with the masked parameters.
Raises:
ValueError: If the given mask is not valid for the wrapped module, i.e. the
pytrees do not match.
"""
# Explicitly create the parameters of the wrapped module.
def init_fn(rng, input_shape):
del input_shape # Unused.
# Call init to get the params of the wrapped module.
_, params = wrapped_module.init(rng, *args, **kwargs)
return params
unmasked_params = self.param(self.UNMASKED, None, init_fn)
if mask is not None:
try:
masked_params = jax.tree_util.tree_map(
lambda x, *xs: x
if xs[0] is None else x * xs[0], unmasked_params, mask)
except ValueError as err:
raise ValueError('Mask is invalid for model.') from err
# Call the wrapped module with the masked params.
return wrapped_module.call(masked_params, *args, **kwargs)
else:
logging.warning('Using masked module without mask!')
# Call the wrapped module with the unmasked params.
return wrapped_module.call(unmasked_params, *args, **kwargs)
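The `tree_map` in `apply` multiplies each parameter leaf by its mask leaf and passes parameters through wherever the mask leaf is None. A minimal NumPy sketch of that masking rule over a flat parameter dict (the parameter names here are hypothetical):

```python
import numpy as np

params = {'kernel': np.arange(6.0).reshape(2, 3), 'bias': np.zeros(3)}
mask = {'kernel': np.array([[1, 0, 1], [0, 1, 0]]), 'bias': None}

# Multiply where a mask leaf exists; pass through where the leaf is None.
masked_params = {
    name: value if mask[name] is None else value * mask[name]
    for name, value in params.items()
}

print(masked_params['kernel'])  # Zeros wherever the mask is 0.
```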
def masked(module, mask):
"""Convenience function for masking a FLAX module with MaskedModule."""
return MaskedModule.partial(wrapped_module=module, mask=mask)
def generate_model_masks(
depth,
mask = None,
masked_layer_indices = None):
"""Creates empty masks for this model, or initializes with existing mask.
Args:
depth: Number of layers in the model.
mask: Existing model mask for layers in this model, if not given, all
module masks are initialized to None.
masked_layer_indices: The layer indices of layers in model to be masked, or
all if None.
Returns:
A model mask, with None where no mask is given for a model layer, or that
specific layer is indicated as not to be masked by the masked_layer_indices
parameter.
"""
if depth <= 0:
raise ValueError(f'Invalid model depth: {depth}')
if mask is None:
mask = {f'MaskedModule_{i}': None for i in range(depth)}
# Have to explicitly check for None to differentiate from empty array.
if masked_layer_indices is not None:
# Check none of the indices are outside of model's layer bounds.
if any(i < 0 or i >= depth for i in masked_layer_indices):
raise ValueError(
f'Invalid indices for given depth ({depth}): {masked_layer_indices}')
mask = {
f'MaskedModule_{i}': mask[f'MaskedModule_{i}']
for i in masked_layer_indices
}
return mask
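A runnable sketch of the dictionary construction above (`generate_masks` is an illustrative stand-in mirroring the logic, not the library function):

```python
def generate_masks(depth, mask=None, masked_layer_indices=None):
    # Mirrors generate_model_masks: one entry per MaskedModule layer,
    # optionally restricted to the given layer indices.
    if depth <= 0:
        raise ValueError(f'Invalid model depth: {depth}')
    if mask is None:
        mask = {f'MaskedModule_{i}': None for i in range(depth)}
    if masked_layer_indices is not None:
        mask = {f'MaskedModule_{i}': mask[f'MaskedModule_{i}']
                for i in masked_layer_indices}
    return mask

print(generate_masks(3))
# {'MaskedModule_0': None, 'MaskedModule_1': None, 'MaskedModule_2': None}
print(generate_masks(3, masked_layer_indices=[1]))
# {'MaskedModule_1': None}
```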
def _filter_param(param_names,
invert = False):
"""Convenience function for filtering maskable parameters from paths.
Args:
param_names: Names of parameters we are looking for.
invert: Inverts filter to exclude, rather than include, given parameters.
Returns:
A function to use with flax.optim.ModelParamTraversal for filtering.
"""
def filter_fn(path, value):
del value # Unused.
parameter_found = any([
'{}/{}'.format(MaskedModule.UNMASKED, param_name) in path
for param_name in param_names
])
return not parameter_found if invert else parameter_found
return filter_fn
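A sketch of the path matching this filter performs, assuming the wrapped parameter layout `.../unmasked/<param>` ('kernel' stands in for the actual entries of WEIGHT_PARAM_NAMES):

```python
def make_filter(param_names, invert=False):
    # Matches traversal paths that contain 'unmasked/<name>'.
    def filter_fn(path):
        found = any('unmasked/{}'.format(name) in path
                    for name in param_names)
        return not found if invert else found
    return filter_fn

keep_kernels = make_filter(['kernel'])
print(keep_kernels('/MaskedModule_0/unmasked/kernel'))  # True
print(keep_kernels('/MaskedModule_0/unmasked/bias'))    # False
```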
def mask_map(model,
fn):
"""Convenience function to create a mask for a model.
Args:
model: The Flax model, with at least one MaskedModule layer.
fn: The function to call on each masked parameter, to create the mask for
that parameter, takes the parameter name, and parameter value as arguments
and returns the new parameter value.
Returns:
A model parameter dictionary, with all masked parameters set by the given
function, and all other parameters set to None.
Raises:
ValueError: If the given model does not support masking, i.e. none of the
layers are wrapped by a MaskedModule.
"""
maskable = False
for layer_key, layer in model.params.items():
if MaskedModule.UNMASKED not in layer:
logging.warning(
'Layer \'%s\' does not support masking, i.e. it is not '
'wrapped by a MaskedModule', layer_key)
else:
maskable = True
if not maskable:
raise ValueError('Model does not support masking, i.e. no layers are '
'wrapped by a MaskedModule.')
# First set all non-masked params to None in copy of model pytree.
filter_non_masked = _filter_param(WEIGHT_PARAM_NAMES, invert=True)
nonmasked_traversal = flax.optim.ModelParamTraversal(filter_non_masked) # pytype: disable=module-attr
mask_model = nonmasked_traversal.update(lambda _: None, model)
# Then find params to mask, and set to array.
for param_name in WEIGHT_PARAM_NAMES:
filter_masked = _filter_param(WEIGHT_PARAM_NAMES)
mask_traversal = flax.optim.ModelParamTraversal(filter_masked) # pytype: disable=module-attr
mask_model = mask_traversal.update(
functools.partial(fn, param_name), mask_model)
mask = mask_model.params
# Remove unneeded unmasked param for mask.
for layer_key, layer in mask.items():
if MaskedModule.UNMASKED in layer:
mask[layer_key] = layer[MaskedModule.UNMASKED]
return mask
def iterate_mask(
mask,
param_names = None
):
"""Iterate over the parameters in as mask.
Args:
mask: The model mask.
param_names: The parameter names to iterate over in each layer, if None
iterates over all parameters of all layers.
Yields:
An iterator of tuples containing the parameter path and parameter value
in sorted order of layer parameters matching the names in param_names (or
all parameters if None).
"""
flat_mask = flax.traverse_util.flatten_dict(mask)
for key, value in flat_mask.items():
if param_names is None or key in param_names:
path = '/' + '/'.join(key)
yield path, value
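`iterate_mask` yields '/'-joined paths over a flattened mask pytree; a sketch of the flattening with plain nested dicts (layer/param names illustrative):

```python
mask = {'layer0': {'kernel': [1, 0], 'bias': None}}

# Flatten the nested dict to (layer, param) tuple keys, as
# flax.traverse_util.flatten_dict does for pytrees.
flat = {(layer, param): value
        for layer, params in mask.items()
        for param, value in params.items()}
paths = ['/' + '/'.join(key) for key in flat]
print(paths)  # ['/layer0/kernel', '/layer0/bias']
```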
def shuffled_mask(model, rng,
sparsity):
"""Returns a randomly shuffled mask with a given sparsity for all layers.
Returns a random weight mask for a model param array, by randomly shuffling a
mask with a fixed number of non-zero/zero entries, given by the sparsity.
Args:
model: Flax model that contains masked modules.
rng: Random number generator, i.e. jax.random.PRNGKey.
sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
mask all weights, while 0 will mask none.
Returns:
A randomly shuffled weight mask, in the same form as flax.Module.params.
Raises:
ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are
maskable, i.e. none are wrapped by a MaskedModule.
"""
if sparsity > 1 or sparsity < 0:
raise ValueError(
'Given sparsity, {}, is not in range [0, 1]'.format(sparsity))
def create_shuffled_mask(param_name, param):
del param_name # Unused.
mask = jnp.arange(param.size)
mask = jnp.where(mask >= sparsity * param.size, jnp.ones_like(mask),
jnp.zeros_like(mask))
mask = jax.random.permutation(rng, mask)
return mask.reshape(param.shape)
return mask_map(model, create_shuffled_mask)
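A NumPy stand-in for the jnp logic above: build a 0/1 vector with an exact count of ones, then permute it, so the layer sparsity is exact rather than merely expected in distribution:

```python
import numpy as np

size, sparsity = 10, 0.3
# Entries with index >= sparsity * size become ones: exactly 7 here.
mask = (np.arange(size) >= sparsity * size).astype(np.int32)
mask = np.random.default_rng(0).permutation(mask)
print(int(mask.sum()))  # 7 ones: the count is exact, only positions vary
```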
def random_mask(model,
rng,
mean_sparsity = 0.5):
"""Returns a random weight mask for a masked model.
Args:
model: Flax model that contains masked modules.
rng: Random number generator, i.e. jax.random.PRNGKey.
mean_sparsity: The mean fraction of zeros in the mask, i.e. the Bernoulli
distribution is sampled with mean (1 - mean_sparsity).
Returns:
A random weight mask, in the same form as flax.Module.params
Raises:
ValueError: If the sparsity is beyond the bounds [0, 1], or if a layer to
mask is not maskable, i.e. is not wrapped by MaskedModule.
"""
if mean_sparsity > 1 or mean_sparsity < 0:
raise ValueError(
'Given sparsity, {}, is not in range [0, 1]'.format(mean_sparsity))
# Invert mean_sparsity to get mean for Bernoulli distribution.
mean = 1. - mean_sparsity
def create_random_mask(param_name, param):
del param_name # Unused.
return jax.random.bernoulli(
rng, p=mean,
shape=param.shape).astype(jnp.int32) # TPU doesn't support uint8.
return mask_map(model, create_random_mask)
def simple_mask(model,
init_fn,
masked_param):
"""Creates a mask given a model and numpy initialization function.
Args:
model: The model to create a mask for.
init_fn: The numpy initialization function, e.g. numpy.ones.
masked_param: The list of parameters to mask.
Returns:
A mask for the model.
"""
def create_init_fn_mask(param_name, param):
if param_name in masked_param:
return init_fn(param.shape)
return None
return mask_map(model, create_init_fn_mask)
def symmetric_mask(model,
rng,
sparsity = 0.5):
"""Generates a random weight mask that's symmetric, i.e. structurally pruned.
Args:
model: Flax model that contains masked modules.
rng: Random number generator, i.e. jax.random.PRNGKey.
sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), in the
range [0, 1]: 1.0 will mask all weights, while 0 will mask none.
Returns:
A symmetric random weight mask, in the same form as flax.Module.params.
Raises:
ValueError: If the sparsity is outside the range [0, 1].
"""
if sparsity > 1 or sparsity < 0:
raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')
def create_neuron_symmetric_mask(param_name, param):
del param_name # Unused.
neuron_length = functools.reduce(operator.mul, param.shape[:-1])
neuron_mask = jnp.arange(neuron_length)
neuron_mask = jnp.where(neuron_mask >= sparsity * neuron_mask.size,
jnp.ones_like(neuron_mask),
jnp.zeros_like(neuron_mask))
neuron_mask = jax.random.shuffle(rng, neuron_mask)
mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1)
return mask.reshape(param.shape)
return mask_map(model, create_neuron_symmetric_mask)
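The structural (neuron-wise) pruning above amounts to one mask bit per input position, repeated across the output axis, so whole rows are kept or dropped together. A NumPy sketch with illustrative shapes:

```python
import numpy as np

neuron_mask = np.array([1, 0, 1])  # one bit per input position
# Repeat along the last (output) axis: shape (3, 2).
mask = np.repeat(neuron_mask[:, np.newaxis], 2, axis=1)
print(mask.tolist())  # [[1, 1], [0, 0], [1, 1]]
```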
class _PerNeuronShuffle:
"""This class is needed to get around the fact that JAX RNG is stateless."""
def __init__(self, init_rng, sparsity):
"""Creates the per-neuron shuffle class, with initial RNG state.
Args:
init_rng: The initial random number generator state to use.
sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
mask all weights, while 0 will mask none.
"""
self._rng = init_rng
self._sparsity = sparsity
def __call__(self, param_name, param):
"""Shuffles the weight matrix/mask for a given parameter, per-neuron.
This is to be used with mask_map, and accepts the standard mask_map
function parameters.
Args:
param_name: The parameter's name.
param: The parameter's weight or mask matrix.
Returns:
A shuffled weight/mask matrix, with each neuron shuffled independently.
"""
del param_name # Unused.
neuron_length = functools.reduce(operator.mul, param.shape[:-1])
neuron_mask = jnp.arange(neuron_length)
neuron_mask = jnp.where(neuron_mask >= self._sparsity * neuron_mask.size,
jnp.ones_like(neuron_mask),
jnp.zeros_like(neuron_mask))
mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1)
self._rng, rng_input = jax.random.split(self._rng)
mask = jax.random.shuffle(rng_input, mask, axis=0)
return mask.reshape(param.shape)
def shuffled_neuron_mask(model,
rng,
sparsity):
"""Returns a shuffled mask with a given fixed sparsity for all neurons/layers.
Returns a randomly shuffled weight mask for a model param array, by setting a
fixed sparsity (i.e. number of ones/zeros) for every neuron's weight vector
in the model, and then randomly shuffling each neuron's weight mask with a
fixed number of non-zero/zero entries, given by the sparsity. This ensures no
neuron is fully ablated for a sparsity of less than one.
Note: This is much more complicated for convolutional layers due to the
receptive field being different for every pixel! We only take into account
channel-wise masks and not spatial ablations in propagation in that case.
Args:
model: Flax model that contains masked modules.
rng: Random number generator, i.e. jax.random.PRNGKey.
sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
mask all weights, while 0 will mask none.
Returns:
A randomly shuffled weight mask, in the same form as flax.Module.params.
Raises:
ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are
maskable, i.e. none are wrapped by a MaskedModule.
"""
if sparsity > 1 or sparsity < 0:
raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')
return mask_map(model, _PerNeuronShuffle(rng, sparsity))
def _fill_diagonal_wrap(shape,
value,
dtype = jnp.uint8):
"""Fills the diagonal of a 2D array, while also wrapping tall arrays.
For a matrix of dimensions (N x M):
if N <= M, i.e. the array is wide rectangular, the array's diagonal is
filled, for example:
_fill_diagonal_wrap((2, 3), 1, dtype=jnp.uint8)
> [[1, 0, 0],
[0, 1, 0]]
if N > M, i.e. the array is tall rectangular, the array's diagonal and
offset diagonals are filled. This differs from
numpy.fill_diagonal(..., wrap=True) in that it does not include a single
row gap between the diagonals, and it is not in-place but returns a new
array. For example,
_fill_diagonal_wrap((3, 2), 1, dtype=jnp.uint8)
> [[1, 0],
[0, 1],
[1, 0]]
Args:
shape: The shape of the 2D array to return with the diagonal filled.
value: The value to fill in for the diagonal, and offset diagonals.
dtype: The datatype of the jax numpy array to return.
Returns:
A new array of the given shape with the main diagonal filled, and offset
diagonals filled if the array is tall.
"""
if len(shape) != 2:
raise ValueError(
f'Expected a 2D array, however the given shape is: {shape}')
array = jnp.zeros(shape, dtype=dtype)
rows, cols = shape
def diagonal_indices(offset):  # Returns a numpy-style index expression.
"""Returns slice of the nth diagonal of an array, where n is offset."""
# This is a numpy-style advanced slice of the form [start:end:step] that
# gives the (vertically) offset diagonal of an array. For the main
# diagonal of a flattened n x n square matrix it would be 0:n**2:n+1,
# i.e. start at 0, take every (n + 1)-th element, and stop at the end of
# the array. Vertically-offset diagonals are handled via offset.
return jnp.index_exp[cols * offset:cols * (offset + cols):cols + 1]
# Fills (square) matrix diagonals with the given value, tiling over tall
# rectangular arrays by offsetting the filled diagonals by multiples of the
# height of the square arrays.
diagonals = [
array.ravel().at[diagonal_indices(offset)].set(value).reshape(array.shape)
for offset in range(0, rows, cols)
]
return functools.reduce(jnp.add, diagonals)
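The wrap-fill behavior can be reproduced with NumPy under the same slicing scheme (`fill_diagonal_wrap` below is an illustrative stand-in, not the library function):

```python
import numpy as np

def fill_diagonal_wrap(shape, value):
    """NumPy sketch: fill the main diagonal, wrapping into offset
    diagonals (with no row gap) when the array is tall (N > M)."""
    rows, cols = shape
    flat = np.zeros(shape, dtype=np.uint8).ravel()
    for offset in range(0, rows, cols):
        flat[cols * offset:cols * (offset + cols):cols + 1] = value
    return flat.reshape(shape)

print(fill_diagonal_wrap((2, 3), 1).tolist())  # [[1, 0, 0], [0, 1, 0]]
print(fill_diagonal_wrap((3, 2), 1).tolist())  # [[1, 0], [0, 1], [1, 0]]
```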
def _random_neuron_mask(neuron_length,
unmasked_count,
rng,
dtype = jnp.uint32):
"""Generates a random mask for a neuron.
Args:
neuron_length: The length of the neuron's weight vector.
unmasked_count: The number of elements that should be unmasked.
rng: A jax.random.PRNGKey random seed.
dtype: Type of array to create.
Returns:
A random neuron weight vector mask.
"""
if unmasked_count > neuron_length:
raise ValueError('unmasked_count cannot be greater than neuron_length: '
f'{unmasked_count} > {neuron_length}')
neuron_mask = jnp.concatenate(
(jnp.ones(unmasked_count), jnp.zeros(neuron_length - unmasked_count)),
axis=0)
neuron_mask = jax.random.shuffle(rng, neuron_mask)
return neuron_mask.astype(dtype)
class _PerNeuronNoInputAblationShuffle:
"""This class is needed to get around the fact that JAX RNG is stateless."""
def __init__(self, init_rng, sparsity):
"""Creates the per-neuron shuffle class, with initial RNG state.
Args:
init_rng: The initial random number generator state to use.
sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
mask all weights, while 0 will mask none.
"""
self._rng = init_rng
self._sparsity = sparsity
def _get_rng(self):
"""Creates a new JAX RNG, while updating RNG state."""
self._rng, rng_input = jax.random.split(self._rng)
return rng_input
def __call__(self, param_name, param):
"""Shuffles the weight matrix/mask for a given parameter, per-neuron.
This is to be used with mask_map, and accepts the standard mask_map
function parameters.
Args:
param_name: The parameter's name.
param: The parameter's weight or mask matrix.
Returns:
A shuffled weight/mask matrix, with each neuron shuffled independently.
"""
del param_name # Unused.
incoming_connections = jnp.prod(jnp.array(param.shape[:-1]))
num_neurons = param.shape[-1]
# Ensure each input neuron has at least one connection unmasked.
mask = _fill_diagonal_wrap((incoming_connections, num_neurons), 1,
dtype=jnp.uint8)
# Randomly shuffle which of the neurons have these connections.
mask = jax.random.shuffle(self._get_rng(), mask, axis=0)
# Add extra required random connections to mask to satisfy sparsity.
mask_cols = []
for col in range(mask.shape[-1]):
neuron_mask = mask[:, col]
off_diagonal_count = max(
round((1 - self._sparsity) * incoming_connections)
- jnp.count_nonzero(neuron_mask), 0)
zero_indices = jnp.flatnonzero(neuron_mask == 0)
random_entries = _random_neuron_mask(
len(zero_indices), off_diagonal_count, self._get_rng())
neuron_mask = neuron_mask.at[zero_indices].set(random_entries)
mask_cols.append(neuron_mask)
return jnp.column_stack(mask_cols).reshape(param.shape)
def shuffled_neuron_no_input_ablation_mask(model,
rng,
sparsity):
"""Returns a shuffled mask with a given fixed sparsity for all neurons/layers.
Returns a randomly shuffled weight mask for a model param array, by setting a
fixed sparsity (i.e. number of ones/zeros) for every neuron's weight vector
in the model, and then randomly shuffling each neuron's weight mask with a
fixed number of non-zero/zero entries, given by the sparsity. This ensures no
neuron is fully ablated for a sparsity of less than one.
This function also ensures that no neurons in the previous layer are
effectively ablated, by ensuring that each neuron has at least one connection.
Note: This is much more complicated for convolutional layers due to the
receptive field being different for every pixel! We only take into account
channel-wise masks and not spatial ablations in propagation in that case.
Args:
model: Flax model that contains masked modules.
rng: Random number generator, i.e. jax.random.PRNGKey.
sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
mask all weights, except for the minimum number required to maintain
connectivity with the input layer, while 0 will mask none.
Returns:
A randomly shuffled weight mask, in the same form as flax.Module.params.
Raises:
ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are
maskable, i.e. none are wrapped by a MaskedModule.
"""
if sparsity > 1.0 or sparsity < 0.0:
raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')
# First, generate a random permutation matrix, and ensure our mask has at
# least N connections, where there are N neurons in the previous layer.
return mask_map(model, _PerNeuronNoInputAblationShuffle(rng, sparsity))
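A NumPy sketch of the no-input-ablation guarantee (illustrative shapes, not the library implementation): seed the mask with a wrapped diagonal so every input row gets at least one connection, then permute rows:

```python
import numpy as np

rows, cols = 5, 3  # incoming connections x neurons
mask = np.zeros((rows, cols), dtype=np.uint8)
flat = mask.ravel()
for offset in range(0, rows, cols):
    # Wrapped-diagonal fill: exactly one connection per input row.
    flat[cols * offset:cols * (offset + cols):cols + 1] = 1
# Shuffle which neurons own those guaranteed connections (rows only).
mask = np.random.default_rng(0).permutation(mask, axis=0)
print(mask.sum(axis=1).tolist())  # [1, 1, 1, 1, 1]: no input row is empty
```

The library code then tops up each column with random extra connections until the target sparsity is met.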
def propagate_masks(
mask,
param_names = WEIGHT_PARAM_NAMES
):
"""Accounts for implicitly pruned neurons in a model's weight masks.
When neurons are randomly ablated in one layer, they can effectively ablate
neurons in the next layer if in effect all incoming weights of a neuron are
zero. This method accounts for this by propagating forward mask information
through the entire model.
Args:
mask: Model masks to check, in same pytree structure as Model.params.
param_names: List of param keys in mask to count.
Returns:
A refined model mask with weights that are effectively ablated in the
original mask set to zero.
"""
flat_mask = flax.traverse_util.flatten_dict(mask)
mask_layer_list = list(flat_mask.values())
mask_layer_keys = list(flat_mask.keys())
mask_layer_param_names = [layer_param[-1] for layer_param in mask_layer_keys]
for param_name in param_names:
# Find which of the param arrays correspond to leaf nodes with this name.
param_indices = [
i for i, names in enumerate(mask_layer_param_names)
if param_name in names
]
for i in range(1, len(param_indices)):
last_weight_mask = mask_layer_list[param_indices[i - 1]]
weight_mask = mask_layer_list[param_indices[i]]
if last_weight_mask is None or weight_mask is None:
continue
last_weight_mask_reshaped = jnp.reshape(last_weight_mask,
(-1, last_weight_mask.shape[-1]))
# Neurons with any outgoing weights from previous layer.
alive_incoming = jnp.sum(last_weight_mask_reshaped, axis=0) != 0
# Combine effective mask of previous layer with neuron's current mask.
if len(weight_mask.shape) > 2:
# Convolutional layer, only consider channel-wise masks, if any spatial
# weight is non-zero that channel is considered non-masked.
spatial_dim = len(weight_mask.shape) - 2
new_weight_mask = alive_incoming[:, jnp.newaxis] * jnp.amax(
weight_mask, axis=tuple(range(spatial_dim)))
new_weight_mask = jnp.tile(new_weight_mask,
weight_mask.shape[:-2] + (1, 1))
else:
# Check for case of dense following convolution, i.e. spatial input into
# dense, to prevent b/156135283. Must use convolution for these layers.
if len(last_weight_mask.shape) > 2:
raise ValueError(
'propagate_masks requires knowledge of the spatial '
'dimensions of the previous layer. Use a functionally equivalent '
'conv. layer in place of a dense layer in a model with a mixed '
'conv/dense setting.')
new_weight_mask = alive_incoming[:, jnp.newaxis] * weight_mask
mask_layer_list[param_indices[i]] = jnp.reshape(
new_weight_mask, mask_layer_list[param_indices[i]].shape)
return flax.traverse_util.unflatten_dict(
dict(zip(mask_layer_keys, mask_layer_list)))
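For dense layers the propagation step reduces to a row product; a NumPy sketch with `[inputs, outputs]` shapes as in the code above:

```python
import numpy as np

m1 = np.array([[1, 0],
               [1, 0]])          # layer-1 mask: neuron 1 has no weights left
m2 = np.ones((2, 3), dtype=int)  # layer-2 mask, all connections on
# A layer-1 neuron with no surviving outgoing weights is effectively
# ablated, so zero out its row of incoming weights in layer 2.
alive_incoming = m1.sum(axis=0) != 0  # [True, False]
m2_effective = alive_incoming[:, np.newaxis] * m2
print(m2_effective.tolist())  # [[1, 1, 1], [0, 0, 0]]
```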
def mask_layer_sparsity(mask_layer):
"""Calculates the sparsity of a single layer's mask.
Args:
mask_layer: mask layer to calculate the sparsity of.
Returns:
The sparsity of the mask.
"""
parameter_count = 0
masked_count = 0
for key in mask_layer:
if mask_layer[key] is not None and key in WEIGHT_PARAM_NAMES:
parameter_count += mask_layer[key].size
masked_count += jnp.sum(mask_layer[key])
if parameter_count == 0:
return 0.
return 1. - masked_count/parameter_count
def mask_sparsity(
mask,
param_names = None):
"""Calculates the sparsity of the given parameters over a model mask.
Args:
mask: Model mask to calculate sparsity over.
param_names: List of param keys in mask to count.
Returns:
The overall sparsity of the mask.
"""
if param_names is None:
param_names = WEIGHT_PARAM_NAMES
parameter_count = 0
masked_count = 0
for path, value in iterate_mask(mask):
if value is not None and any(
param_name in path for param_name in param_names):
parameter_count += value.size
masked_count += jnp.sum(value.flatten())
if parameter_count == 0:
return 0.
return 1.0 - float(masked_count / parameter_count)
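Sparsity here is simply the fraction of zeros among the counted mask entries; a minimal NumPy check with illustrative values:

```python
import numpy as np

mask = np.array([[1, 0, 0, 1],
                 [0, 0, 1, 1]])
# 4 ones out of 8 entries -> sparsity 0.5.
sparsity = 1.0 - float(mask.sum() / mask.size)
print(sparsity)  # 0.5
```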
================================================
FILE: rigl/experimental/jax/pruning/masked_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
SYMBOL INDEX (698 symbols across 81 files)
FILE: rigl/cifar_resnet/data_helper.py
function pad_input (line 29) | def pad_input(x, crop_dim=4):
function preprocess_train (line 47) | def preprocess_train(x, width, height):
function input_fn (line 64) | def input_fn(params):
FILE: rigl/cifar_resnet/data_helper_test.py
class DataHelperTest (line 39) | class DataHelperTest(tf.test.TestCase, parameterized.TestCase):
method get_next (line 41) | def get_next(self):
method testInputPipeline (line 55) | def testInputPipeline(self):
method testTrainingStep (line 78) | def testTrainingStep(self, training_method):
FILE: rigl/cifar_resnet/resnet_model.py
class WideResNetModel (line 33) | class WideResNetModel(object):
method __init__ (line 36) | def __init__(self,
method build (line 70) | def build(self, inputs, depth, width, num_classes, name=None):
method _batch_norm (line 127) | def _batch_norm(self, net, name=None):
method _dense (line 150) | def _dense(self, net, num_units, name=None, sparsity_technique='baseli...
method _conv (line 158) | def _conv(self,
method _residual_block (line 183) | def _residual_block(self, net, name, output_size, subsample, blocks):
FILE: rigl/cifar_resnet/resnet_train_eval.py
function create_eval_metrics (line 141) | def create_eval_metrics(labels, logits):
function train_fn (line 171) | def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
function build_model (line 299) | def build_model(mode,
function wide_resnet_w_pruning (line 370) | def wide_resnet_w_pruning(features, labels, mode, params):
function main (line 474) | def main(argv):
FILE: rigl/experimental/jax/datasets/cifar10.py
class CIFAR10Dataset (line 27) | class CIFAR10Dataset(dataset_base.ImageDataset):
method __init__ (line 39) | def __init__(self,
method preprocess (line 66) | def preprocess(
FILE: rigl/experimental/jax/datasets/cifar10_test.py
class CIFAR10DatasetTest (line 23) | class CIFAR10DatasetTest(absltest.TestCase):
method setUp (line 26) | def setUp(self):
method test_create_dataset (line 38) | def test_create_dataset(self):
method test_train_image_dims_content (line 42) | def test_train_image_dims_content(self):
method test_test_image_dims_content (line 68) | def test_test_image_dims_content(self):
method test_train_data_length (line 94) | def test_train_data_length(self):
method test_test_data_length (line 102) | def test_test_data_length(self):
method test_dataset_nonevenly_divisible_batch_size (line 110) | def test_dataset_nonevenly_divisible_batch_size(self):
FILE: rigl/experimental/jax/datasets/dataset_base.py
class Dataset (line 29) | class Dataset(metaclass=abc.ABCMeta):
method __init__ (line 47) | def __init__(self,
method _dataset_dir (line 95) | def _dataset_dir(self):
method get_train (line 99) | def get_train(self):
method get_train_len (line 103) | def get_train_len(self):
method get_test (line 107) | def get_test(self):
method get_test_len (line 111) | def get_test_len(self):
method preprocess (line 115) | def preprocess(
method augment (line 130) | def augment(
class ImageDataset (line 147) | class ImageDataset(Dataset):
method preprocess (line 152) | def preprocess(
FILE: rigl/experimental/jax/datasets/dataset_base_test.py
class DummyDataset (line 22) | class DummyDataset(dataset_base.ImageDataset):
method __init__ (line 30) | def __init__(self,
class DummyDatasetTest (line 50) | class DummyDatasetTest(absltest.TestCase):
method setUp (line 53) | def setUp(self):
method test_create_dataset (line 64) | def test_create_dataset(self):
method test_train_image_dims_content (line 68) | def test_train_image_dims_content(self):
method test_test_image_dims_content (line 86) | def test_test_image_dims_content(self):
method test_train_data_length (line 104) | def test_train_data_length(self):
method test_test_data_length (line 112) | def test_test_data_length(self):
FILE: rigl/experimental/jax/datasets/dataset_factory.py
function create_dataset (line 38) | def create_dataset(name, *args, **kwargs):
FILE: rigl/experimental/jax/datasets/dataset_factory_test.py
class DatasetCommonTest (line 24) | class DatasetCommonTest(parameterized.TestCase):
method setUp (line 26) | def setUp(self):
method _create_dataset (line 32) | def _create_dataset(self, dataset_name):
method test_dataset_supported (line 40) | def test_dataset_supported(self):
method test_dataset_train_iterators (line 48) | def test_dataset_train_iterators(self, dataset_name):
method test_dataset_test_iterators (line 76) | def test_dataset_test_iterators(self, dataset_name):
method test_dataset_unsupported (line 103) | def test_dataset_unsupported(self):
FILE: rigl/experimental/jax/datasets/mnist.py
class MNISTDataset (line 27) | class MNISTDataset(dataset_base.ImageDataset):
method __init__ (line 35) | def __init__(self,
method preprocess (line 54) | def preprocess(
FILE: rigl/experimental/jax/datasets/mnist_test.py
class MNISTDatasetTest (line 23) | class MNISTDatasetTest(absltest.TestCase):
method setUp (line 26) | def setUp(self):
method test_create_dataset (line 38) | def test_create_dataset(self):
method test_train_image_dims_content (line 42) | def test_train_image_dims_content(self):
method test_test_image_dims_content (line 67) | def test_test_image_dims_content(self):
method test_train_data_length (line 93) | def test_train_data_length(self):
method test_test_data_length (line 101) | def test_test_data_length(self):
FILE: rigl/experimental/jax/fixed_param.py
function main (line 195) | def main(argv: List[str]):
FILE: rigl/experimental/jax/fixed_param_test.py
class FixedParamTest (line 27) | class FixedParamTest(absltest.TestCase):
method test_run (line 29) | def test_run(self):
FILE: rigl/experimental/jax/models/cifar10_cnn.py
class CIFAR10CNN (line 32) | class CIFAR10CNN(flax.deprecated.nn.Module):
method apply (line 35) | def apply(self,
FILE: rigl/experimental/jax/models/cifar10_cnn_test.py
class CIFAR10CNNTest (line 25) | class CIFAR10CNNTest(absltest.TestCase):
method setUp (line 28) | def setUp(self):
method test_output_shapes (line 36) | def test_output_shapes(self):
method test_invalid_spatial_dimensions (line 48) | def test_invalid_spatial_dimensions(self):
method test_invalid_masks_depth (line 56) | def test_invalid_masks_depth(self):
FILE: rigl/experimental/jax/models/mnist_cnn.py
class MNISTCNN (line 32) | class MNISTCNN(flax.deprecated.nn.Module):
method apply (line 35) | def apply(self,
FILE: rigl/experimental/jax/models/mnist_cnn_test.py
class MNISTCNNTest (line 25) | class MNISTCNNTest(absltest.TestCase):
method setUp (line 28) | def setUp(self):
method test_output_shapes (line 36) | def test_output_shapes(self):
method test_invalid_depth (line 48) | def test_invalid_depth(self):
FILE: rigl/experimental/jax/models/mnist_fc.py
function feature_dim_for_param (line 32) | def feature_dim_for_param(input_len,
class MNISTFC (line 81) | class MNISTFC(flax.deprecated.nn.Module):
method apply (line 84) | def apply(self,
FILE: rigl/experimental/jax/models/mnist_fc_test.py
class MNISTFCTest (line 31) | class MNISTFCTest(parameterized.TestCase):
method setUp (line 34) | def setUp(self):
method test_output_shapes (line 44) | def test_output_shapes(self):
method test_invalid_masks_depth (line 56) | def test_invalid_masks_depth(self):
method _create_model (line 73) | def _create_model(self, features):
method test_feature_dim_for_param_depth (line 83) | def test_feature_dim_for_param_depth(self, depth):
FILE: rigl/experimental/jax/models/model_factory.py
function create_model (line 37) | def create_model(
function update_model (line 66) | def update_model(model,
FILE: rigl/experimental/jax/models/model_factory_test.py
class ModelCommonTest (line 26) | class ModelCommonTest(parameterized.TestCase):
method setUp (line 29) | def setUp(self):
method _create_model (line 35) | def _create_model(self, model_name):
method test_model_supported (line 42) | def test_model_supported(self, model_name):
method test_model_unsupported (line 52) | def test_model_unsupported(self):
FILE: rigl/experimental/jax/prune.py
function main (line 166) | def main(argv: List[str]):
FILE: rigl/experimental/jax/prune_test.py
class PruneTest (line 26) | class PruneTest(absltest.TestCase):
method test_prune_fixed_schedule (line 28) | def test_prune_fixed_schedule(self):
method test_prune_global_pruning_schedule (line 45) | def test_prune_global_pruning_schedule(self):
method test_prune_local_pruning_schedule (line 62) | def test_prune_local_pruning_schedule(self):
FILE: rigl/experimental/jax/pruning/init.py
function sparse_init (line 25) | def sparse_init(
FILE: rigl/experimental/jax/pruning/init_test.py
class MaskedDense (line 28) | class MaskedDense(flax.deprecated.nn.Module):
method apply (line 33) | def apply(self,
class MaskedDenseSparseInit (line 47) | class MaskedDenseSparseInit(flax.deprecated.nn.Module):
method apply (line 52) | def apply(self,
class MaskedCNN (line 70) | class MaskedCNN(flax.deprecated.nn.Module):
method apply (line 75) | def apply(self,
class MaskedCNNSparseInit (line 89) | class MaskedCNNSparseInit(flax.deprecated.nn.Module):
method apply (line 94) | def apply(self,
class InitTest (line 112) | class InitTest(absltest.TestCase):
method setUp (line 114) | def setUp(self):
method test_init_kaiming_sparse_normal_output (line 121) | def test_init_kaiming_sparse_normal_output(self):
method test_dense_no_mask (line 140) | def test_dense_no_mask(self):
method test_dense_sparse_init_kaiming (line 157) | def test_dense_sparse_init_kaiming(self):
method test_cnn_sparse_init_kaiming (line 193) | def test_cnn_sparse_init_kaiming(self):
FILE: rigl/experimental/jax/pruning/mask_factory.py
function create_mask (line 47) | def create_mask(mask_type, base_model,
FILE: rigl/experimental/jax/pruning/mask_factory_test.py
class MaskedDense (line 29) | class MaskedDense(flax.deprecated.nn.Module):
method apply (line 34) | def apply(self,
class MaskFactoryTest (line 46) | class MaskFactoryTest(parameterized.TestCase):
method setUp (line 48) | def setUp(self):
method _create_mask (line 60) | def _create_mask(self, mask_type):
method test_mask_supported (line 66) | def test_mask_supported(self, mask_type):
method test_mask_unsupported (line 73) | def test_mask_unsupported(self):
FILE: rigl/experimental/jax/pruning/masked.py
class MaskedModule (line 55) | class MaskedModule(flax.deprecated.nn.Module):
method apply (line 66) | def apply(self,
function masked (line 115) | def masked(module, mask):
function generate_model_masks (line 120) | def generate_model_masks(
function _filter_param (line 158) | def _filter_param(param_names,
function mask_map (line 182) | def mask_map(model,
function iterate_mask (line 234) | def iterate_mask(
function shuffled_mask (line 257) | def shuffled_mask(model, rng,
function random_mask (line 292) | def random_mask(model,
function simple_mask (line 326) | def simple_mask(model,
function symmetric_mask (line 348) | def symmetric_mask(model,
class _PerNeuronShuffle (line 379) | class _PerNeuronShuffle:
method __init__ (line 382) | def __init__(self, init_rng, sparsity):
method __call__ (line 393) | def __call__(self, param_name, param):
function shuffled_neuron_mask (line 418) | def shuffled_neuron_mask(model,
function _fill_diagonal_wrap (line 452) | def _fill_diagonal_wrap(shape,
function _random_neuron_mask (line 511) | def _random_neuron_mask(neuron_length,
class _PerNeuronNoInputAblationShuffle (line 535) | class _PerNeuronNoInputAblationShuffle:
method __init__ (line 538) | def __init__(self, init_rng, sparsity):
method _get_rng (line 549) | def _get_rng(self):
method __call__ (line 554) | def __call__(self, param_name, param):
function shuffled_neuron_no_input_ablation_mask (line 597) | def shuffled_neuron_no_input_ablation_mask(model,
function propagate_masks (line 637) | def propagate_masks(
function mask_layer_sparsity (line 710) | def mask_layer_sparsity(mask_layer):
function mask_sparsity (line 733) | def mask_sparsity(
FILE: rigl/experimental/jax/pruning/masked_test.py
class Dense (line 29) | class Dense(flax.deprecated.nn.Module):
method apply (line 34) | def apply(self, inputs):
class MaskedDense (line 39) | class MaskedDense(flax.deprecated.nn.Module):
method apply (line 44) | def apply(self,
class DenseTwoLayer (line 56) | class DenseTwoLayer(flax.deprecated.nn.Module):
method apply (line 61) | def apply(self, inputs):
class MaskedTwoLayerDense (line 67) | class MaskedTwoLayerDense(flax.deprecated.nn.Module):
method apply (line 72) | def apply(self,
class MaskedConv (line 89) | class MaskedConv(flax.deprecated.nn.Module):
method apply (line 94) | def apply(self,
class MaskedTwoLayerConv (line 105) | class MaskedTwoLayerConv(flax.deprecated.nn.Module):
method apply (line 110) | def apply(self,
class MaskedThreeLayerConvDense (line 127) | class MaskedThreeLayerConvDense(flax.deprecated.nn.Module):
method apply (line 132) | def apply(self,
class MaskedTwoLayerMixedConvDense (line 155) | class MaskedTwoLayerMixedConvDense(flax.deprecated.nn.Module):
method apply (line 160) | def apply(self,
class MaskedTest (line 176) | class MaskedTest(parameterized.TestCase):
method setUp (line 179) | def setUp(self):
method test_fully_masked_layer (line 241) | def test_fully_masked_layer(self):
method test_no_mask_masked_layer (line 253) | def test_no_mask_masked_layer(self):
method test_empty_mask_masked_layer (line 263) | def test_empty_mask_masked_layer(self):
method test_invalid_mask (line 275) | def test_invalid_mask(self):
method test_shuffled_mask_invalid_model (line 287) | def test_shuffled_mask_invalid_model(self):
method test_shuffled_mask_invalid_sparsity (line 294) | def test_shuffled_mask_invalid_sparsity(self):
method test_shuffled_mask_sparsity_full (line 307) | def test_shuffled_mask_sparsity_full(self):
method test_shuffled_mask_sparsity_empty (line 328) | def test_shuffled_mask_sparsity_empty(self):
method test_shuffled_mask_sparsity_half_full (line 349) | def test_shuffled_mask_sparsity_half_full(self):
method test_shuffled_mask_sparsity_full_twolayer (line 359) | def test_shuffled_mask_sparsity_full_twolayer(self):
method test_shuffled_mask_sparsity_empty_twolayer (line 390) | def test_shuffled_mask_sparsity_empty_twolayer(self):
method test_random_invalid_model (line 416) | def test_random_invalid_model(self):
method test_random_invalid_sparsity (line 423) | def test_random_invalid_sparsity(self):
method test_random_mask_sparsity_full (line 436) | def test_random_mask_sparsity_full(self):
method test_random_mask_sparsity_empty (line 451) | def test_random_mask_sparsity_empty(self):
method test_random_mask_sparsity_half_full (line 468) | def test_random_mask_sparsity_half_full(self):
method test_simple_mask_one_layer (line 480) | def test_simple_mask_one_layer(self):
method test_simple_mask_two_layer (line 500) | def test_simple_mask_two_layer(self):
method test_shuffled_mask_neuron_mask_sparsity_empty (line 528) | def test_shuffled_mask_neuron_mask_sparsity_empty(self):
method test_shuffled_mask_neuron_mask_sparsity_half_full (line 549) | def test_shuffled_mask_neuron_mask_sparsity_half_full(self):
method test_symmetric_mask_sparsity_empty (line 566) | def test_symmetric_mask_sparsity_empty(self):
method test_symmetric_mask_sparsity_half_full (line 587) | def test_symmetric_mask_sparsity_half_full(self):
method test_propagate_masks_ablated_neurons_one_layer (line 604) | def test_propagate_masks_ablated_neurons_one_layer(self):
method test_propagate_masks_ablated_neurons_two_layers (line 625) | def test_propagate_masks_ablated_neurons_two_layers(self):
method test_propagate_masks_ablated_neurons_two_layers_nonmasked (line 653) | def test_propagate_masks_ablated_neurons_two_layers_nonmasked(self):
method test_propagate_masks_ablated_neurons_one_conv_layer (line 683) | def test_propagate_masks_ablated_neurons_one_conv_layer(self):
method test_propagate_masks_ablated_neurons_two_conv_layers (line 704) | def test_propagate_masks_ablated_neurons_two_conv_layers(self):
method test_propagate_masks_ablated_neurons_three_conv_fc_layers (line 734) | def test_propagate_masks_ablated_neurons_three_conv_fc_layers(self):
method test_propagate_masks_ablated_neurons_mixed_conv_dense_layers (line 774) | def test_propagate_masks_ablated_neurons_mixed_conv_dense_layers(self):
method test_mask_layer_sparsity_zero_mask (line 802) | def test_mask_layer_sparsity_zero_mask(self):
method test_mask_layer_sparsity_half_mask (line 809) | def test_mask_layer_sparsity_half_mask(self):
method test_mask_layer_sparsity_ones_mask (line 816) | def test_mask_layer_sparsity_ones_mask(self):
method test_mask_sparsity_zero_mask (line 823) | def test_mask_sparsity_zero_mask(self):
method test_mask_sparsity_ones_mask (line 829) | def test_mask_sparsity_ones_mask(self):
method test_mask_sparsity_mixed_mask (line 835) | def test_mask_sparsity_mixed_mask(self):
method test_generate_model_masks_depth_only (line 873) | def test_generate_model_masks_depth_only(self, depth):
method test_generate_model_masks_indices (line 890) | def test_generate_model_masks_indices(self, depth, indices):
method test_generate_model_masks_existing_mask (line 909) | def test_generate_model_masks_existing_mask(self, depth, existing_mask,
method test_generate_model_masks_invalid_depth_zero (line 931) | def test_generate_model_masks_invalid_depth_zero(self):
method test_generate_model_masks_invalid_index_toohigh (line 936) | def test_generate_model_masks_invalid_index_toohigh(self):
method test_generate_model_masks_invalid_index_negative (line 941) | def test_generate_model_masks_invalid_index_negative(self):
method test_shuffled_neuron_no_input_ablation_mask_invalid_model (line 946) | def test_shuffled_neuron_no_input_ablation_mask_invalid_model(self):
method test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity (line 954) | def test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity(self):
method test_shuffled_neuron_no_input_ablation_mask_sparsity_full (line 969) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):
method test_shuffled_neuron_no_input_ablation_mask_sparsity_empty (line 994) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty(self):
method test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full (line 1016) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self):
method test_shuffled_neuron_no_input_ablation_mask_sparsity_quarter_full (line 1033) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_quarter_full(self):
method test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer (line 1050) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer(self):
method test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer (line 1092) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer(self):
FILE: rigl/experimental/jax/pruning/pruning.py
function weight_magnitude (line 26) | def weight_magnitude(weights):
function prune (line 31) | def prune(
FILE: rigl/experimental/jax/pruning/pruning_test.py
class MaskedDense (line 28) | class MaskedDense(flax.deprecated.nn.Module):
method apply (line 33) | def apply(self,
class MaskedTwoLayerDense (line 45) | class MaskedTwoLayerDense(flax.deprecated.nn.Module):
method apply (line 50) | def apply(self,
class MaskedConv (line 67) | class MaskedConv(flax.deprecated.nn.Module):
method apply (line 72) | def apply(self,
class MaskedTwoLayerConv (line 83) | class MaskedTwoLayerConv(flax.deprecated.nn.Module):
method apply (line 88) | def apply(self,
class PruningTest (line 105) | class PruningTest(absltest.TestCase):
method setUp (line 108) | def setUp(self):
method test_prune_single_layer_dense_no_mask (line 134) | def test_prune_single_layer_dense_no_mask(self):
method test_prune_single_layer_local_pruning (line 145) | def test_prune_single_layer_local_pruning(self):
method test_prune_single_layer_dense_with_mask (line 158) | def test_prune_single_layer_dense_with_mask(self):
method test_prune_two_layers_dense_no_mask (line 172) | def test_prune_two_layers_dense_no_mask(self):
method test_prune_two_layer_local_pruning_rate (line 186) | def test_prune_two_layer_local_pruning_rate(self):
method test_prune_one_layer_conv_no_mask (line 206) | def test_prune_one_layer_conv_no_mask(self):
method test_prune_one_layer_conv_with_mask (line 217) | def test_prune_one_layer_conv_with_mask(self):
method test_prune_two_layer_conv_no_mask (line 231) | def test_prune_two_layer_conv_no_mask(self):
FILE: rigl/experimental/jax/pruning/symmetry.py
function count_permutations_mask_layer (line 30) | def count_permutations_mask_layer(
function count_permutations_mask (line 125) | def count_permutations_mask(mask):
function get_mask_stats (line 161) | def get_mask_stats(mask):
FILE: rigl/experimental/jax/pruning/symmetry_test.py
class MaskedDense (line 33) | class MaskedDense(flax.deprecated.nn.Module):
method apply (line 42) | def apply(self,
class MaskedConv (line 53) | class MaskedConv(flax.deprecated.nn.Module):
method apply (line 62) | def apply(self,
class MaskedTwoLayerDense (line 73) | class MaskedTwoLayerDense(flax.deprecated.nn.Module):
method apply (line 82) | def apply(self,
class SymmetryTest (line 99) | class SymmetryTest(parameterized.TestCase):
method setUp (line 102) | def setUp(self):
method test_count_permutations_layer_mask_full (line 123) | def test_count_permutations_layer_mask_full(self):
method test_count_permutations_layer_mask_empty (line 146) | def test_count_permutations_layer_mask_empty(self):
method test_count_permutations_conv_layer_mask_full (line 168) | def test_count_permutations_conv_layer_mask_full(self):
method test_count_permutations_conv_layer_mask_empty (line 191) | def test_count_permutations_conv_layer_mask_empty(self):
method test_count_permutations_layer_mask_known_perm (line 213) | def test_count_permutations_layer_mask_known_perm(self):
method test_count_permutations_layer_mask_known_perm_zeros (line 247) | def test_count_permutations_layer_mask_known_perm_zeros(self):
method test_count_permutations_shuffled_full_mask (line 279) | def test_count_permutations_shuffled_full_mask(self):
method test_count_permutations_shuffled_empty_mask (line 297) | def test_count_permutations_shuffled_empty_mask(self):
method test_count_permutations_mask_layer_twolayer_known_symmetric (line 316) | def test_count_permutations_mask_layer_twolayer_known_symmetric(self):
method test_count_permutations_mask_layer_twolayer (line 396) | def test_count_permutations_mask_layer_twolayer(self, mask, unique,
method test_count_permutations_mask_full (line 414) | def test_count_permutations_mask_full(self):
method test_count_permutations_mask_bn_layer_full (line 433) | def test_count_permutations_mask_bn_layer_full(self):
method test_count_permutations_mask_empty (line 452) | def test_count_permutations_mask_empty(self):
method test_count_permutations_mask_twolayer_full (line 470) | def test_count_permutations_mask_twolayer_full(self):
method test_count_permutations_mask_twolayers_empty (line 494) | def test_count_permutations_mask_twolayers_empty(self):
method test_count_permutations_mask_twolayer_known_symmetric (line 515) | def test_count_permutations_mask_twolayer_known_symmetric(self):
method test_count_permutations_mask_twolayer_known_non_symmetric (line 542) | def test_count_permutations_mask_twolayer_known_non_symmetric(self):
method test_get_mask_stats_keys_values (line 569) | def test_get_mask_stats_keys_values(self):
FILE: rigl/experimental/jax/random_mask.py
function main (line 177) | def main(argv: List[str]):
FILE: rigl/experimental/jax/random_mask_test.py
class RandomMaskTest (line 26) | class RandomMaskTest(absltest.TestCase):
method test_run_fc (line 28) | def test_run_fc(self):
method test_run_conv (line 46) | def test_run_conv(self):
method test_run_random (line 64) | def test_run_random(self):
method test_run_per_neuron (line 82) | def test_run_per_neuron(self):
method test_run_symmetric (line 100) | def test_run_symmetric(self):
FILE: rigl/experimental/jax/shuffled_mask.py
function main (line 178) | def main(argv: List[str]):
FILE: rigl/experimental/jax/shuffled_mask_test.py
class ShuffledMaskTest (line 26) | class ShuffledMaskTest(absltest.TestCase):
method test_run_fc (line 28) | def test_run_fc(self):
method test_run_conv (line 45) | def test_run_conv(self):
method test_run_random (line 62) | def test_run_random(self):
method test_run_per_neuron (line 79) | def test_run_per_neuron(self):
method test_run_symmetric (line 96) | def test_run_symmetric(self):
FILE: rigl/experimental/jax/train.py
function run_training (line 86) | def run_training():
function main (line 170) | def main(argv):
FILE: rigl/experimental/jax/train_test.py
class TrainTest (line 26) | class TrainTest(absltest.TestCase):
method test_train_driver_run (line 28) | def test_train_driver_run(self):
FILE: rigl/experimental/jax/training/training.py
function _shard_batch (line 51) | def _shard_batch(xs):
function train_step (line 61) | def train_step(
class Trainer (line 110) | class Trainer:
method __init__ (line 118) | def __init__(
FILE: rigl/experimental/jax/training/training_test.py
class TrainingTest (line 33) | class TrainingTest(absltest.TestCase):
method setUp (line 36) | def setUp(self):
method test_train_one_step (line 76) | def test_train_one_step(self):
method test_train_one_epoch (line 104) | def test_train_one_epoch(self):
method test_train_one_epoch_tensorboard (line 140) | def test_train_one_epoch_tensorboard(self):
method test_train_one_epoch_pruning_global_schedule (line 180) | def test_train_one_epoch_pruning_global_schedule(self):
method test_train_one_epoch_pruning_local_schedule (line 217) | def test_train_one_epoch_pruning_local_schedule(self):
method test_eval_batch (line 254) | def test_eval_batch(self):
method test_eval (line 275) | def test_eval(self):
FILE: rigl/experimental/jax/utils/utils.py
function cross_entropy_loss (line 34) | def cross_entropy_loss(log_softmax_logits,
function compute_metrics (line 48) | def compute_metrics(logits,
function _np_converter (line 76) | def _np_converter(obj):
function dump_dict_json (line 86) | def dump_dict_json(data_dict, path):
function count_param (line 100) | def count_param(model,
function cosine_similarity (line 120) | def cosine_similarity(a, b):
function param_as_array (line 127) | def param_as_array(params):
function cosine_similarity_model (line 133) | def cosine_similarity_model(initial_model,
function vector_difference_norm_model (line 142) | def vector_difference_norm_model(initial_model,
function pairwise_longest (line 154) | def pairwise_longest(iterable):
FILE: rigl/experimental/jax/utils/utils_test.py
class TwoLayerDense (line 34) | class TwoLayerDense(flax.deprecated.nn.Module):
method apply (line 39) | def apply(self, inputs):
class UtilsTest (line 47) | class UtilsTest(parameterized.TestCase):
method setUp (line 50) | def setUp(self):
method _create_logits_labels (line 68) | def _create_logits_labels(self, correct):
method test_compute_metrics_correct (line 93) | def test_compute_metrics_correct(self):
method test_compute_metrics_incorrect (line 122) | def test_compute_metrics_incorrect(self):
method test_compute_metrics_equal_logits (line 151) | def test_compute_metrics_equal_logits(self):
method test_dump_dict_json (line 180) | def test_dump_dict_json(self):
method test_count_param_two_layer_dense (line 200) | def test_count_param_two_layer_dense(self):
method test_count_invalid_param (line 209) | def test_count_invalid_param(self):
method test_model_param_as_array (line 215) | def test_model_param_as_array(self):
method test_cosine_similarity_random (line 228) | def test_cosine_similarity_random(self):
method test_cosine_similarity_same (line 238) | def test_cosine_similarity_same(self):
method test_cosine_similarity_same_model (line 247) | def test_cosine_similarity_same_model(self):
method test_vector_difference_norm_diff_model (line 253) | def test_vector_difference_norm_diff_model(self):
method test_vector_difference_norm_same_model (line 260) | def test_vector_difference_norm_same_model(self):
method test_pairwise_longest_list_iterator (line 277) | def test_pairwise_longest_list_iterator(
FILE: rigl/imagenet_resnet/imagenet_train_eval.py
function set_lr_schedule (line 280) | def set_lr_schedule():
function set_custom_sparsity_map (line 308) | def set_custom_sparsity_map():
function lr_schedule (line 317) | def lr_schedule(current_epoch):
function train_function (line 333) | def train_function(training_method, loss, cross_loss, reg_loss, output_dir,
function resnet_model_fn_w_pruning (line 478) | def resnet_model_fn_w_pruning(features, labels, mode, params):
class ExportModelHook (line 668) | class ExportModelHook(tf.train.SessionRunHook):
method __init__ (line 671) | def __init__(self, classifier, export_dir):
method begin (line 683) | def begin(self):
method after_run (line 686) | def after_run(self, run_context, run_values):
function main (line 703) | def main(argv):
FILE: rigl/imagenet_resnet/mobilenetv1_model.py
function _make_divisible (line 33) | def _make_divisible(v, divisor=8, min_value=None):
function depthwise_conv2d_fixed_padding (line 43) | def depthwise_conv2d_fixed_padding(inputs,
function conv2d_fixed_padding (line 95) | def conv2d_fixed_padding(inputs,
function mbv1_block_ (line 156) | def mbv1_block_(inputs,
function mobilenet_v1_generator (line 223) | def mobilenet_v1_generator(num_classes=1000,
function mobilenet_v1 (line 345) | def mobilenet_v1(num_classes,
FILE: rigl/imagenet_resnet/mobilenetv2_model.py
function _make_divisible (line 33) | def _make_divisible(v, divisor=8, min_value=None):
function depthwise_conv2d_fixed_padding (line 43) | def depthwise_conv2d_fixed_padding(inputs,
function conv2d_fixed_padding (line 95) | def conv2d_fixed_padding(inputs,
function inverted_res_block_ (line 156) | def inverted_res_block_(inputs,
function mobilenet_v2_generator (line 255) | def mobilenet_v2_generator(num_classes=1000,
function mobilenet_v2 (line 401) | def mobilenet_v2(num_classes,
FILE: rigl/imagenet_resnet/pruning_layers.py
function get_model_variables (line 29) | def get_model_variables(getter,
function variable_getter (line 62) | def variable_getter(rename=None):
function sparse_conv2d (line 72) | def sparse_conv2d(x,
function sparse_fully_connected (line 175) | def sparse_fully_connected(x,
FILE: rigl/imagenet_resnet/resnet_model.py
function batch_norm_relu (line 41) | def batch_norm_relu(inputs, is_training, relu=True, init_zero=False,
function fixed_padding (line 83) | def fixed_padding(inputs, kernel_size, data_format='channels_first'):
class RandomSparseInitializer (line 111) | class RandomSparseInitializer(init_ops.Initializer):
method __init__ (line 114) | def __init__(self, sparsity, seed=None, dtype=tf.float32):
method __call__ (line 123) | def __call__(self, *args, **kwargs):
method get_config (line 131) | def get_config(self):
class SparseConvVarianceScalingInitializer (line 139) | class SparseConvVarianceScalingInitializer(init_ops.Initializer):
method __init__ (line 142) | def __init__(self, sparsity, seed=None, dtype=tf.float32):
method __call__ (line 149) | def __call__(self, shape, dtype=None, partition_info=None):
method get_config (line 168) | def get_config(self):
class SparseFCVarianceScalingInitializer (line 175) | class SparseFCVarianceScalingInitializer(init_ops.Initializer):
method __init__ (line 178) | def __init__(self, sparsity, seed=None, dtype=tf.float32):
method __call__ (line 185) | def __call__(self, shape, dtype=None, partition_info=None):
method get_config (line 207) | def get_config(self):
function _pick_initializer (line 214) | def _pick_initializer(kernel_initializer, init_method, pruning_method,
function conv2d_fixed_padding (line 234) | def conv2d_fixed_padding(inputs,
function residual_block_ (line 306) | def residual_block_(inputs,
function bottleneck_block_ (line 396) | def bottleneck_block_(inputs,
function block_group (line 504) | def block_group(inputs,
function resnet_v1_generator (line 577) | def resnet_v1_generator(block_fn,
function resnet_v1_ (line 734) | def resnet_v1_(resnet_depth,
FILE: rigl/imagenet_resnet/train_test.py
class DataInputTest (line 36) | class DataInputTest(tf.test.TestCase, parameterized.TestCase):
method _retrieve_data (line 38) | def _retrieve_data(self, is_training, data_dir):
method testTrainingPipeline (line 50) | def testTrainingPipeline(self, training_method):
FILE: rigl/imagenet_resnet/utils.py
function format_tensors (line 28) | def format_tensors(*dicts):
function host_call_fn (line 59) | def host_call_fn(model_dir, **kwargs):
function mask_summaries (line 83) | def mask_summaries(masks, with_img=False):
function initialize_parameters_from_ckpt (line 93) | def initialize_parameters_from_ckpt(ckpt_path, model_dir, param_suffixes):
FILE: rigl/imagenet_resnet/vgg.py
function vgg_net (line 64) | def vgg_net(inputs,
function vgg (line 203) | def vgg(vgg_type,
FILE: rigl/mnist/mnist_train_eval.py
function mnist_network_fc (line 112) | def mnist_network_fc(input_batch, reuse=False, model_pruning=False):
function get_compressed_fc (line 165) | def get_compressed_fc(masks):
function main (line 192) | def main(unused_args):
FILE: rigl/mnist/visualize_mask_records.py
function main (line 62) | def main(unused_args):
FILE: rigl/rigl_tf2/init_utils.py
function unit_scaled_init (line 23) | def unit_scaled_init(mask, method='fanavg_uniform', scale=1.0):
function layer_scaled_init (line 70) | def layer_scaled_init(mask, method='fanavg_uniform', scale=1.0):
function unit_scaled_init_tf1 (line 81) | def unit_scaled_init_tf1(mask,
FILE: rigl/rigl_tf2/interpolate.py
function test_model (line 61) | def test_model(model, d_test, batch_size=1000):
function interpolate (line 80) | def interpolate(model_start, model_end, model_inter, d_set,
function main (line 97) | def main(unused_argv):
FILE: rigl/rigl_tf2/mask_updaters.py
function get_all_layers (line 22) | def get_all_layers(model, filter_fn=lambda _: True):
function is_pruned (line 33) | def is_pruned(layer):
class MaskUpdater (line 37) | class MaskUpdater(object):
method __init__ (line 49) | def __init__(self, model, optimizer, use_stateless=True,
method prune_masks (line 58) | def prune_masks(self, prune_fraction):
method update_masks (line 67) | def update_masks(self, drop_fraction):
method get_all_pruning_layers (line 76) | def get_all_pruning_layers(self):
method get_vars_and_masks (line 83) | def get_vars_and_masks(self):
method get_drop_scores (line 93) | def get_drop_scores(self, all_vars, all_masks):
method get_grow_scores (line 96) | def get_grow_scores(self, all_vars, all_masks):
method generic_mask_update (line 99) | def generic_mask_update(self, mask, var, score_drop, score_grow,
method reset_momentum (line 156) | def reset_momentum(self, var, new_connections):
method _random_uniform (line 164) | def _random_uniform(self, *args, **kwargs):
method _random_normal (line 173) | def _random_normal(self, *args, **kwargs):
method set_validation_data (line 182) | def set_validation_data(self, val_x, val_y):
method _get_gradients (line 185) | def _get_gradients(self, all_vars):
class SET (line 195) | class SET(MaskUpdater):
method get_drop_scores (line 204) | def get_drop_scores(self, all_vars, all_masks, noise_std=0):
method get_grow_scores (line 214) | def get_grow_scores(self, all_vars, all_masks):
class RigL (line 219) | class RigL(MaskUpdater):
method get_drop_scores (line 225) | def get_drop_scores(self, all_vars, all_masks, noise_std=0):
method get_grow_scores (line 235) | def get_grow_scores(self, all_vars, all_masks):
class RigLInverted (line 239) | class RigLInverted(RigL):
method get_grow_scores (line 245) | def get_grow_scores(self, all_vars, all_masks):
class UpdateSchedule (line 251) | class UpdateSchedule(object):
method __init__ (line 260) | def __init__(self, mask_updater, init_drop_fraction, update_freq,
method get_drop_fraction (line 268) | def get_drop_fraction(self, step):
method is_update_iter (line 271) | def is_update_iter(self, step):
method update (line 286) | def update(self, step, check_update_iter=True):
method prune (line 296) | def prune(self, prune_fraction):
method set_validation_data (line 300) | def set_validation_data(self, val_x, val_y):
class ConstantUpdateSchedule (line 304) | class ConstantUpdateSchedule(UpdateSchedule):
method get_drop_fraction (line 307) | def get_drop_fraction(self, step):
class CosineUpdateSchedule (line 311) | class CosineUpdateSchedule(UpdateSchedule):
method __init__ (line 314) | def __init__(self, *args, **kwargs):
method get_drop_fraction (line 322) | def get_drop_fraction(self, step):
class ScaledLRUpdateSchedule (line 326) | class ScaledLRUpdateSchedule(UpdateSchedule):
method __init__ (line 329) | def __init__(self, mask_updater, init_drop_fraction, update_freq,
method _get_lr (line 336) | def _get_lr(self, step):
method get_drop_fraction (line 342) | def get_drop_fraction(self, step):
function get_mask_updater (line 359) | def get_mask_updater(
FILE: rigl/rigl_tf2/metainit.py
class ScaleSGD (line 23) | class ScaleSGD(tf1.train.Optimizer):
method __init__ (line 29) | def __init__(self, learning_rate=0.1, momentum=0.9, mindim=3,
method _prepare (line 40) | def _prepare(self):
method _create_slots (line 44) | def _create_slots(self, var_list):
method _resource_apply_dense (line 53) | def _resource_apply_dense(self, grad, handle):
method _apply_dense (line 71) | def _apply_dense(self, grad, var):
method _apply_sparse (line 74) | def _apply_sparse(self, grad, var):
function meta_init (line 78) | def meta_init(model, loss, x_shape, y_shape, n_params, learning_rate=0.001,
FILE: rigl/rigl_tf2/networks.py
function lenet5 (line 25) | def lenet5(input_shape,
function mlp (line 58) | def mlp(input_shape,
FILE: rigl/rigl_tf2/train.py
function get_rows (line 59) | def get_rows(model, variables, masks, ind_l, indices, x_batch, y_batch,
function sparse_hessian_calculator (line 89) | def sparse_hessian_calculator(model,
function hessian (line 170) | def hessian(model,
function update_prune_step (line 195) | def update_prune_step(model, step):
function log_sparsities (line 202) | def log_sparsities(model):
function cosine_distance (line 212) | def cosine_distance(x, y):
function flatten_list_of_vars (line 219) | def flatten_list_of_vars(var_list):
function var_to_img (line 224) | def var_to_img(tensor):
function mask_gradients (line 235) | def mask_gradients(model, gradients, variables):
function train_model (line 248) | def train_model(model,
function test_model (line 445) | def test_model(model, d_test, batch_size=1000):
function main (line 461) | def main(unused_argv):
FILE: rigl/rigl_tf2/utils.py
function get_dataset (line 37) | def get_dataset():
function get_pruning_params (line 51) | def get_pruning_params(mode='prune',
function maybe_prune_layer (line 75) | def maybe_prune_layer(layer, params, filter_fn):
function get_network (line 82) | def get_network(
function get_optimizer (line 182) | def get_optimizer(total_steps,
FILE: rigl/rl/dqn_agents.py
function flatten_list_of_vars (line 36) | def flatten_list_of_vars(var_list):
function _get_bn_layer_name (line 41) | def _get_bn_layer_name(block_id, i):
function _get_conv_layer_name (line 45) | def _get_conv_layer_name(block_id, i):
class _Stack (line 49) | class _Stack(tf.keras.Model):
method __init__ (line 53) | def __init__(self,
method call (line 80) | def call(self, conv_out, training=False):
class ImpalaNetwork (line 103) | class ImpalaNetwork(tf.keras.Model):
method __init__ (line 120) | def __init__(self,
method get_features (line 190) | def get_features(self, state, training=True):
method call (line 205) | def call(self, state, training=True):
class NatureDQNNetwork (line 211) | class NatureDQNNetwork(tf.keras.Model):
method __init__ (line 214) | def __init__(self, num_actions, width=1, mode='dense', name=None):
method call (line 284) | def call(self, state):
class SparseDQNAgent (line 309) | class SparseDQNAgent(dqn_agent.DQNAgent):
method __init__ (line 312) | def __init__(self,
method _create_network (line 337) | def _create_network(self, name):
method _set_additional_ops (line 344) | def _set_additional_ops(self):
method _build_train_op (line 370) | def _build_train_op(self):
method _create_summary_ops (line 406) | def _create_summary_ops(self, grads_and_vars):
method update_prune_step (line 430) | def update_prune_step(self):
method maybe_update_and_apply_masks (line 433) | def maybe_update_and_apply_masks(self):
method maybe_init_masks (line 436) | def maybe_init_masks(self):
method _train_step (line 440) | def _train_step(self):
method _build_sync_op (line 459) | def _build_sync_op(self):
method _build_networks (line 474) | def _build_networks(self):
FILE: rigl/rl/run_experiment.py
function create_sparse_agent (line 33) | def create_sparse_agent(sess, num_actions, agent=None, summary_writer=No...
class SparseTrainRunner (line 54) | class SparseTrainRunner(run_experiment.Runner):
method __init__ (line 57) | def __init__(self,
method _run_one_phase_fix_episodes (line 127) | def _run_one_phase_fix_episodes(self, max_episodes, statistics):
method _run_eval_phase (line 165) | def _run_eval_phase(self, statistics):
method _run_one_step (line 177) | def _run_one_step(self, action):
method run_experiment (line 186) | def run_experiment(self):
FILE: rigl/rl/sparse_utils.py
function get_total_params (line 36) | def get_total_params(model):
function get_pruning_sparsities (line 56) | def get_pruning_sparsities(
function get_pruning_params (line 86) | def get_pruning_params(mode,
function maybe_prune_layer (line 113) | def maybe_prune_layer(layer, params, filter_fn=None):
function get_wrap_fn (line 121) | def get_wrap_fn(mode):
function update_prune_step (line 139) | def update_prune_step(model, step):
function update_prune_masks (line 150) | def update_prune_masks(model):
function get_all_layers (line 157) | def get_all_layers(model, filter_fn=lambda _: True):
function get_all_variables_and_masks (line 168) | def get_all_variables_and_masks(model):
function get_all_pruning_layers (line 179) | def get_all_pruning_layers(model):
function log_sparsities (line 185) | def log_sparsities(model):
class SparseOptTf2Mixin (line 197) | class SparseOptTf2Mixin:
method compute_gradients (line 200) | def compute_gradients(self, *args, **kwargs):
method set_model (line 204) | def set_model(self, model):
method get_weights (line 207) | def get_weights(self):
method get_masks (line 213) | def get_masks(self):
method get_masked_weights (line 219) | def get_masked_weights(self):
class UpdatedSETOptimizer (line 227) | class UpdatedSETOptimizer(SparseOptTf2Mixin,
method _before_apply_gradients (line 230) | def _before_apply_gradients(self, grads_and_vars):
class UpdatedRigLOptimizer (line 235) | class UpdatedRigLOptimizer(SparseOptTf2Mixin,
method _before_apply_gradients (line 238) | def _before_apply_gradients(self, grads_and_vars):
function init_masks (line 245) | def init_masks(model,
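The bookkeeping indexed above for `rigl/rl/sparse_utils.py` (`get_total_params`, `log_sparsities`) amounts to counting dense versus surviving parameters across masked layers. A minimal sketch, assuming each layer is represented as a (weight, mask) pair; the names and data layout here are hypothetical:

```python
import numpy as np

def get_total_params(weight_mask_pairs):
    """Return (dense_count, nonzero_count) over (weight, mask) pairs."""
    dense = sum(w.size for w, _ in weight_mask_pairs)
    nonzero = sum(int(m.sum()) for _, m in weight_mask_pairs)
    return dense, nonzero

def overall_sparsity(weight_mask_pairs):
    """Fraction of parameters pruned across all masked layers."""
    dense, nonzero = get_total_params(weight_mask_pairs)
    return 1.0 - nonzero / dense
```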
FILE: rigl/rl/tfagents/dqn_train_eval.py
class SparseDqnAgent (line 75) | class SparseDqnAgent(dqn_agent.DqnAgent):
method __init__ (line 78) | def __init__(self, *args, **kwargs):
method _train (line 95) | def _train(self, experience, weights):
function _scale_width (line 151) | def _scale_width(num_units, width):
function build_network (line 156) | def build_network(
function train_eval (line 200) | def train_eval(
function main (line 404) | def main(_):
FILE: rigl/rl/tfagents/ppo_train_eval.py
function _normalize_advantages (line 99) | def _normalize_advantages(advantages, axes=(0,), variance_epsilon=1e-8):
class SparsePPOAgent (line 112) | class SparsePPOAgent(ppo_clip_agent.PPOClipAgent):
method __init__ (line 115) | def __init__(self,
method _process_experience_weights (line 168) | def _process_experience_weights(self, experience, weights):
method _train (line 233) | def _train(self, experience, weights):
method get_loss (line 424) | def get_loss(self,
method value_estimation_loss (line 541) | def value_estimation_loss(self,
method policy_gradient_loss (line 644) | def policy_gradient_loss(
method entropy_regularization_loss (line 801) | def entropy_regularization_loss(
class ReverbFixedLengthSequenceObserver (line 844) | class ReverbFixedLengthSequenceObserver(reverb_utils.ReverbAddTrajectory...
method __call__ (line 857) | def __call__(self, trajectory):
function train_eval (line 874) | def train_eval(
function main (line 1175) | def main(_):
FILE: rigl/rl/tfagents/sac_train_eval.py
function create_fc_layers (line 81) | def create_fc_layers(layer_units, width=1.0, weight_decay=0):
function create_identity_layer (line 90) | def create_identity_layer():
function create_sequential_critic_network (line 94) | def create_sequential_critic_network(obs_fc_layer_units,
class _TanhNormalProjectionNetworkWrapper (line 176) | class _TanhNormalProjectionNetworkWrapper(
method __init__ (line 180) | def __init__(self, sample_spec, predefined_outer_rank=1, weight_decay=...
method call (line 186) | def call(self, inputs, network_state=(), **kwargs):
function create_sequential_actor_network (line 194) | def create_sequential_actor_network(actor_fc_layers,
class SparseSacAgent (line 234) | class SparseSacAgent(sac_agent.SacAgent):
method __init__ (line 237) | def __init__(self,
method _train (line 316) | def _train(self, experience, weights):
function train_eval (line 455) | def train_eval(
function main (line 698) | def main(_):
FILE: rigl/rl/tfagents/sparse_encoding_network.py
function _copy_layer (line 46) | def _copy_layer(layer):
class EncodingNetwork (line 79) | class EncodingNetwork(network.Network):
method __init__ (line 82) | def __init__(self,
method call (line 297) | def call(self, observation, step_type=None, network_state=(), training...
FILE: rigl/rl/tfagents/sparse_ppo_actor_network.py
function tanh_and_scale_to_spec (line 30) | def tanh_and_scale_to_spec(inputs, spec):
class PPOActorNetwork (line 38) | class PPOActorNetwork():
method __init__ (line 41) | def __init__(self,
method create_sequential_actor_net (line 53) | def create_sequential_actor_net(self,
FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network.py
function tanh_and_scale_to_spec (line 31) | def tanh_and_scale_to_spec(inputs, spec):
class PPODiscreteActorNetwork (line 39) | class PPODiscreteActorNetwork():
method __init__ (line 42) | def __init__(self, seed_stream_class=tfp.util.SeedStream,
method create_sequential_actor_net (line 57) | def create_sequential_actor_net(self,
FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network_test.py
class DeterministicSeedStream (line 32) | class DeterministicSeedStream(object):
method __init__ (line 35) | def __init__(self, seed, salt=''):
method __call__ (line 39) | def __call__(self):
class PpoActorNetworkTest (line 43) | class PpoActorNetworkTest(parameterized.TestCase, test_utils.TestCase):
method setUp (line 45) | def setUp(self):
method tearDown (line 52) | def tearDown(self):
method _init_network (line 56) | def _init_network(
method test_no_mismatched_shape (line 66) | def test_no_mismatched_shape(self):
method test_is_sparse (line 83) | def test_is_sparse(self, is_sparse, sparse_output_layer, expected_laye...
method test_width_scaling (line 96) | def test_width_scaling(self):
method test_weight_decay (line 115) | def test_weight_decay(self, is_sparse, sparse_output_layer,
FILE: rigl/rl/tfagents/sparse_tanh_normal_projection_network.py
class SparseTanhNormalProjectionNetwork (line 34) | class SparseTanhNormalProjectionNetwork(
method __init__ (line 42) | def __init__(self,
FILE: rigl/rl/tfagents/sparse_value_network.py
class ValueNetwork (line 42) | class ValueNetwork(network.Network):
method __init__ (line 45) | def __init__(self,
method call (line 160) | def call(self, observation, step_type=None, network_state=(), training...
FILE: rigl/rl/tfagents/tf_sparse_utils.py
function log_total_params (line 34) | def log_total_params(networks):
function scale_width (line 43) | def scale_width(num_units, width):
function wrap_all_layers (line 49) | def wrap_all_layers(layers,
function wrap_layer (line 115) | def wrap_layer(layer,
function is_valid_layer_to_wrap (line 144) | def is_valid_layer_to_wrap(layer):
function log_sparsities (line 153) | def log_sparsities(model, model_name='q_net', log_images=False):
function update_prune_step (line 174) | def update_prune_step(model, step):
function flatten_list_of_vars (line 180) | def flatten_list_of_vars(var_list):
function log_snr (line 186) | def log_snr(tape, loss, step, variables_to_train, freq=1000):
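The width-scaling helpers indexed here (`scale_width` in `tf_sparse_utils.py`, `_scale_width` in `dqn_train_eval.py`) shrink or grow layer widths by a multiplier when comparing dense and sparse networks at matched parameter counts. A plausible one-line sketch (the repository's exact rounding may differ):

```python
def scale_width(num_units, width):
    """Scale a layer's unit count by a width multiplier, never below 1 unit."""
    return max(1, int(round(num_units * width)))
```

For example, scaling a `[512, 512]` MLP by `width=0.25` yields `[128, 128]`.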
FILE: rigl/rl/train.py
function create_sparsetrain_runner (line 41) | def create_sparsetrain_runner(base_dir):
function main (line 46) | def main(unused_argv):
FILE: rigl/sparse_optimizers.py
class PruningGetterTf1Mixin (line 46) | class PruningGetterTf1Mixin:
method get_weights (line 49) | def get_weights(self):
method get_masks (line 52) | def get_masks(self):
method get_masked_weights (line 55) | def get_masked_weights(self):
class SparseSETOptimizer (line 59) | class SparseSETOptimizer(PruningGetterTf1Mixin,
class SparseRigLOptimizer (line 64) | class SparseRigLOptimizer(PruningGetterTf1Mixin,
class SparseStaticOptimizer (line 69) | class SparseStaticOptimizer(SparseSETOptimizer):
method __init__ (line 86) | def __init__(self,
method generic_mask_update (line 109) | def generic_mask_update(self, mask, weights, noise_std=1e-5):
class SparseMomentumOptimizer (line 126) | class SparseMomentumOptimizer(SparseSETOptimizer):
method __init__ (line 149) | def __init__(self,
method set_masked_grads (line 176) | def set_masked_grads(self, grads, weights):
method compute_gradients (line 183) | def compute_gradients(self, loss, **kwargs):
method _before_apply_gradients (line 195) | def _before_apply_gradients(self, grads_and_vars):
method generic_mask_update (line 199) | def generic_mask_update(self, mask, weights, noise_std=1e-5):
class SparseSnipOptimizer (line 217) | class SparseSnipOptimizer(tf_optimizer.Optimizer):
method __init__ (line 235) | def __init__(self,
method compute_gradients (line 254) | def compute_gradients(self, loss, **kwargs):
method apply_gradients (line 258) | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
class SparseDNWOptimizer (line 340) | class SparseDNWOptimizer(tf_optimizer.Optimizer):
method __init__ (line 360) | def __init__(self,
method compute_gradients (line 375) | def compute_gradients(self, loss, var_list=None, **kwargs):
method replace_with_masked_weights (line 388) | def replace_with_masked_weights(self, var_list):
method replace_masked_weights (line 397) | def replace_masked_weights(self, grads_and_vars):
method apply_gradients (line 408) | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
method get_weights (line 473) | def get_weights(self):
method get_masks (line 476) | def get_masks(self):
method get_masked_weights (line 479) | def get_masked_weights(self):
FILE: rigl/sparse_optimizers_base.py
function extract_number (line 45) | def extract_number(token):
class SparseSETOptimizerBase (line 62) | class SparseSETOptimizerBase(tf_optimizer.Optimizer):
method __init__ (line 87) | def __init__(self,
method compute_gradients (line 113) | def compute_gradients(self, loss, **kwargs):
method apply_gradients (line 118) | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
method _before_apply_gradients (line 148) | def _before_apply_gradients(self, grads_and_vars):
method cond_mask_update_op (line 152) | def cond_mask_update_op(self, global_step, false_branch):
method get_weights (line 189) | def get_weights(self):
method get_masks (line 192) | def get_masks(self):
method get_masked_weights (line 195) | def get_masked_weights(self):
method is_mask_update_iter (line 198) | def is_mask_update_iter(self, global_step, last_update_step):
method get_drop_fraction (line 232) | def get_drop_fraction(self, global_step, is_mask_update_iter_op):
method generic_mask_update (line 260) | def generic_mask_update(self, mask, weights, noise_std=1e-5):
method _get_update_op (line 276) | def _get_update_op(self,
method reset_momentum (line 345) | def reset_momentum(self, weights, new_connections):
method get_grow_tensor (line 355) | def get_grow_tensor(self, weights, method):
method _random_uniform (line 402) | def _random_uniform(self, *args, **kwargs):
method _random_normal (line 411) | def _random_normal(self, *args, **kwargs):
class SparseRigLOptimizerBase (line 421) | class SparseRigLOptimizerBase(SparseSETOptimizerBase):
method __init__ (line 444) | def __init__(self,
method set_masked_grads (line 471) | def set_masked_grads(self, grads, weights):
method compute_gradients (line 478) | def compute_gradients(self, loss, **kwargs):
method apply_gradients (line 487) | def apply_gradients(self, grads_and_vars, global_step=None, name=None):
method generic_mask_update (line 523) | def generic_mask_update(self, mask, weights, noise_std=1e-5):
method get_grow_tensor (line 540) | def get_grow_tensor(self, weights, method):
method reset_momentum (line 555) | def reset_momentum(self, weights, new_connections):
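The core of `SparseRigLOptimizerBase` above (`generic_mask_update`, `get_grow_tensor`, `get_drop_fraction`) is the RigL update: periodically drop the lowest-magnitude active weights and grow an equal number of inactive connections where the dense gradient is largest, with the drop fraction annealed by a cosine schedule. A NumPy sketch of that logic, illustrative rather than the repository's implementation:

```python
import numpy as np

def rigl_mask_update(mask, weights, grads, drop_fraction):
    """One RigL-style step: drop the smallest-magnitude active weights,
    then grow an equal number of inactive connections with the largest
    dense-gradient magnitude. Total connection count is preserved."""
    n_change = int(drop_fraction * mask.sum())
    flat = mask.ravel().astype(float).copy()
    if n_change == 0:
        return flat.reshape(mask.shape)
    # Drop: among active positions, zero out the smallest |weight|.
    drop_score = np.where(flat > 0, np.abs(weights.ravel()), np.inf)
    flat[np.argsort(drop_score)[:n_change]] = 0.0
    # Grow: among inactive positions, enable the largest |gradient|.
    grow_score = np.where(flat > 0, -np.inf, np.abs(grads.ravel()))
    flat[np.argsort(grow_score)[-n_change:]] = 1.0
    return flat.reshape(mask.shape)

def cosine_drop_fraction(step, end_step, initial=0.3):
    """Anneal the drop fraction to zero with a cosine schedule."""
    return initial / 2.0 * (1.0 + np.cos(np.pi * step / end_step))
```

Newly grown connections are initialized to zero and their momentum is reset (cf. `reset_momentum` above), so growth does not perturb the loss at the step it happens.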
FILE: rigl/sparse_optimizers_test.py
class SparseSETOptimizerTest (line 38) | class SparseSETOptimizerTest(tf.test.TestCase, parameterized.TestCase):
method _setup_graph (line 40) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
method testMaskNonUpdateIterations (line 72) | def testMaskNonUpdateIterations(self, n_inp, n_out, drop_frac):
method testUpdateIterations (line 95) | def testUpdateIterations(self, n_inp, n_out, drop_frac):
method testNoDrop (line 121) | def testNoDrop(self, start_iter, end_iter, freq_iter):
method testNewConnectionZeroInit (line 141) | def testNewConnectionZeroInit(self):
method testShapeOfGetGrowTensor (line 160) | def testShapeOfGetGrowTensor(self, shape, init_type):
method testDtypeOfGetGrowTensor (line 172) | def testDtypeOfGetGrowTensor(self, dtype, init_type):
method testValueErrorOfGetGrowTensor (line 182) | def testValueErrorOfGetGrowTensor(self, method):
class SparseStaticOptimizerTest (line 192) | class SparseStaticOptimizerTest(tf.test.TestCase, parameterized.TestCase):
method _setup_graph (line 194) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
method testMaskStatic (line 226) | def testMaskStatic(self, n_inp, n_out, drop_frac):
class SparseMomentumOptimizerTest (line 247) | class SparseMomentumOptimizerTest(tf.test.TestCase, parameterized.TestCa...
method _setup_graph (line 249) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
method testMomentumUpdate (line 276) | def testMomentumUpdate(self, n_inp, n_out, momentum):
class SparseRigLOptimizerTest (line 297) | class SparseRigLOptimizerTest(tf.test.TestCase, parameterized.TestCase):
method _setup_graph (line 299) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
method testMaskedGradientCalculation (line 331) | def testMaskedGradientCalculation(self, n_inp, n_out):
method testApplyGradients (line 353) | def testApplyGradients(self, start_iter, end_iter, freq_iter, is_incre...
class SparseSnipOptimizerTest (line 370) | class SparseSnipOptimizerTest(tf.test.TestCase, parameterized.TestCase):
method _setup_graph (line 372) | def _setup_graph(self, default_sparsity, mask_init_method,
method testSnipSparsity (line 407) | def testSnipSparsity(self, n_inp, n_out, default_sparsity):
method testGradientUsed (line 422) | def testGradientUsed(self, n_inp, n_out, default_sparsity):
method testInitialMaskIsDense (line 441) | def testInitialMaskIsDense(self, n_inp, n_out, default_sparsity):
method testAfterSnipTraining (line 451) | def testAfterSnipTraining(self, n_inp, n_out, default_sparsity):
class SparseDNWOptimizerTest (line 471) | class SparseDNWOptimizerTest(tf.test.TestCase, parameterized.TestCase):
method _setup_graph (line 473) | def _setup_graph(self,
method testDNWSparsity (line 515) | def testDNWSparsity(self, n_inp, n_out, default_sparsity):
method testWeightsUsed (line 529) | def testWeightsUsed(self, n_inp, n_out, default_sparsity):
method testGradientIsDense (line 548) | def testGradientIsDense(self, n_inp, n_out, default_sparsity):
method testDNWUpdates (line 557) | def testDNWUpdates(self, n_inp, n_out, default_sparsity):
method testSparsityAfterDNWUpdates (line 574) | def testSparsityAfterDNWUpdates(self, n_inp, n_out, default_sparsity):
FILE: rigl/sparse_utils.py
function mask_extract_name_fn (line 31) | def mask_extract_name_fn(mask_name):
function get_n_zeros (line 35) | def get_n_zeros(size, sparsity):
function calculate_sparsity (line 39) | def calculate_sparsity(masks):
function get_mask_random_numpy (line 48) | def get_mask_random_numpy(mask_shape, sparsity, random_state=None):
function get_mask_random (line 71) | def get_mask_random(mask, sparsity, dtype, random_state=None):
function get_sparsities_erdos_renyi (line 90) | def get_sparsities_erdos_renyi(all_masks,
function get_sparsities_uniform (line 210) | def get_sparsities_uniform(all_masks,
function get_sparsities_str (line 238) | def get_sparsities_str(all_masks, default_sparsity):
function get_sparsities (line 258) | def get_sparsities(all_masks,
function get_mask_init_fn (line 319) | def get_mask_init_fn(all_masks,
function _get_kernel (line 368) | def _get_kernel(layer):
function get_stats (line 376) | def get_stats(masked_layers,
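The sparsity-allocation helpers indexed above for `rigl/sparse_utils.py` (`get_n_zeros`, `get_sparsities_erdos_renyi`) distribute a global sparsity budget so that each layer's density is proportional to `(n_in + n_out) / (n_in * n_out)`, which keeps small layers denser than large ones. A simplified sketch under that rule; the repository's version also caps densities at 1.0 for very small layers and supports kernel dimensions, which is omitted here:

```python
import numpy as np

def get_n_zeros(size, sparsity):
    """Number of zeroed entries for a layer of `size` parameters
    (illustrative; the repository's exact rounding may differ)."""
    return int(np.floor(size * sparsity))

def erdos_renyi_sparsities(shapes, default_sparsity):
    """Per-layer sparsities with density proportional to
    (n_in + n_out) / (n_in * n_out), scaled so the total number of kept
    parameters matches a uniform `default_sparsity` budget."""
    sizes = {k: int(np.prod(s)) for k, s in shapes.items()}
    raw = {k: sum(s) / np.prod(s) for k, s in shapes.items()}  # relative density
    budget = (1.0 - default_sparsity) * sum(sizes.values())    # params to keep
    eps = budget / sum(raw[k] * sizes[k] for k in shapes)
    return {k: 1.0 - eps * raw[k] for k in shapes}
```

On `{'a': (100, 100), 'b': (10, 10)}` at 90% target sparsity, the small layer `b` ends up far less sparse than `a` while the overall kept-parameter count equals the uniform budget.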
FILE: rigl/sparse_utils_test.py
class GetMaskRandomTest (line 29) | class GetMaskRandomTest(tf.test.TestCase, parameterized.TestCase):
method _setup_session (line 31) | def _setup_session(self):
method testMaskConnectionDeterminism (line 38) | def testMaskConnectionDeterminism(self, shape, sparsity):
method testMaskFraction (line 49) | def testMaskFraction(self, shape, sparsity, expected_ones):
method testMaskDtype (line 58) | def testMaskDtype(self, dtype):
class GetSparsitiesTest (line 65) | class GetSparsitiesTest(tf.test.TestCase, parameterized.TestCase):
method _setup_session (line 67) | def _setup_session(self):
method testSparsityDictRandom (line 74) | def testSparsityDictRandom(self, default_sparsity):
method testSparsityDictErdosRenyiCustom (line 87) | def testSparsityDictErdosRenyiCustom(self, default_sparsity):
method testSparsityDictErdosRenyiError (line 98) | def testSparsityDictErdosRenyiError(self, default_sparsity):
method testSparsityDictErdosRenyiSparsitiesScale (line 113) | def testSparsityDictErdosRenyiSparsitiesScale(
FILE: rigl/str_sparsities.py
function _name_map_str (line 86) | def _name_map_str(k):
function read_all (line 109) | def read_all():
Condensed preview — 141 files, each showing path, character count, and a content snippet.
[
{
"path": "CONTRIBUTING.md",
"chars": 678,
"preview": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project.\n\n- If you want to contribute to"
},
{
"path": "LICENSE",
"chars": 11358,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 12693,
"preview": "# Rigging the Lottery: Making All Tickets Winners\n<img src=\"https://github.com/google-research/rigl/blob/master/imgs/flo"
},
{
"path": "rigl/__init__.py",
"chars": 678,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/cifar_resnet/data_helper.py",
"chars": 3600,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/cifar_resnet/data_helper_test.py",
"chars": 3792,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/cifar_resnet/resnet_model.py",
"chars": 8151,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/cifar_resnet/resnet_train_eval.py",
"chars": 23121,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/README.md",
"chars": 3405,
"preview": "# Weight Symmetry Research Code\nThis code is mostly written by Yani Ioannou.\n\n## Experiment Summary\n\nThere are a number "
},
{
"path": "rigl/experimental/jax/__init__.py",
"chars": 683,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/analysis/plot_summary_json.ipynb",
"chars": 9527,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"colab_type\": \"text\",\n \"id\": \"6iE"
},
{
"path": "rigl/experimental/jax/datasets/__init__.py",
"chars": 593,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/cifar10.py",
"chars": 2838,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/cifar10_test.py",
"chars": 4109,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/dataset_base.py",
"chars": 5243,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/dataset_base_test.py",
"chars": 3976,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/dataset_factory.py",
"chars": 1668,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/dataset_factory_test.py",
"chars": 4041,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/mnist.py",
"chars": 2036,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/datasets/mnist_test.py",
"chars": 3742,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/fixed_param.py",
"chars": 6627,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/fixed_param_test.py",
"chars": 1405,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/__init__.py",
"chars": 593,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/cifar10_cnn.py",
"chars": 4635,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/cifar10_cnn_test.py",
"chars": 2576,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/mnist_cnn.py",
"chars": 5276,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/mnist_cnn_test.py",
"chars": 2048,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/mnist_fc.py",
"chars": 5340,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/mnist_fc_test.py",
"chars": 3425,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/model_factory.py",
"chars": 2351,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/models/model_factory_test.py",
"chars": 1936,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/prune.py",
"chars": 5820,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/prune_test.py",
"chars": 2506,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/__init__.py",
"chars": 593,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/init.py",
"chars": 3228,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/init_test.py",
"chars": 8068,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/mask_factory.py",
"chars": 2116,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/mask_factory_test.py",
"chars": 2523,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/masked.py",
"chars": 27065,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/masked_test.py",
"chars": 43127,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/pruning.py",
"chars": 3601,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/pruning_test.py",
"chars": 8971,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/symmetry.py",
"chars": 5937,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/pruning/symmetry_test.py",
"chars": 23529,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/random_mask.py",
"chars": 6368,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/random_mask_test.py",
"chars": 3658,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/requirements.txt",
"chars": 118,
"preview": "absl-py>=0.10.0\nflax>=0.2.2\njax>=0.2.0\njaxlib>=0.1.55\ntensorboard>=2.3.0\ntensorflow>=2.3.1\ntensorflow_datasets>=3.2.1\n"
},
{
"path": "rigl/experimental/jax/run.sh",
"chars": 1245,
"preview": "# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use thi"
},
{
"path": "rigl/experimental/jax/shuffled_mask.py",
"chars": 6390,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/shuffled_mask_test.py",
"chars": 3394,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/train.py",
"chars": 5952,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/train_test.py",
"chars": 1401,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/training/__init__.py",
"chars": 593,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/training/training.py",
"chars": 18771,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/training/training_test.py",
"chars": 10496,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/utils/__init__.py",
"chars": 593,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/utils/utils.py",
"chars": 5359,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/experimental/jax/utils/utils_test.py",
"chars": 10193,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb",
"chars": 24260,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"colab_type\": \"text\",\n \"id\": \"e5O"
},
{
"path": "rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb",
"chars": 27462,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"colab_type\": \"text\",\n \"id\": \"e5O"
},
{
"path": "rigl/imagenet_resnet/imagenet_train_eval.py",
"chars": 34374,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/mobilenetv1_model.py",
"chars": 13266,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/mobilenetv2_model.py",
"chars": 15571,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/pruning_layers.py",
"chars": 8676,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/resnet_model.py",
"chars": 29329,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/train_test.py",
"chars": 2868,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/utils.py",
"chars": 4331,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/imagenet_resnet/vgg.py",
"chars": 9298,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/mnist/mnist_train_eval.py",
"chars": 19616,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/mnist/visualize_mask_records.py",
"chars": 4153,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/requirements.txt",
"chars": 177,
"preview": "absl-py>=0.6.0\ngin-config\nnumpy>=1.15.4\nsix>=1.12.0\ntensorflow>=1.12.0,<2.0 # change to 'tensorflow-gpu' for gpu suppor"
},
{
"path": "rigl/rigl_tf2/README.md",
"chars": 2978,
"preview": "# Gradient Flow in Sparse Neural Networks and How Lottery Tickets Win\n<img src=\"https://github.com/google-research/rigl/"
},
{
"path": "rigl/rigl_tf2/colabs/MnistProp.ipynb",
"chars": 10346,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"e5O1UdsY202_\"\n },\n \"sou"
},
{
"path": "rigl/rigl_tf2/configs/dense.gin",
"chars": 687,
"preview": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500 # Log every 5"
},
{
"path": "rigl/rigl_tf2/configs/grasp.gin",
"chars": 796,
"preview": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining."
},
{
"path": "rigl/rigl_tf2/configs/hessian.gin",
"chars": 877,
"preview": "hessian.batch_size = 60000\nhessian.rows_at_once = 2\n# range(0,100,5) + range(100,2000,100) + range(2000,11719,500)\nhessi"
},
{
"path": "rigl/rigl_tf2/configs/interpolate.gin",
"chars": 86,
"preview": "interpolate.i_start = -0.20\ninterpolate.i_end = 1.20\ninterpolate.n_interpolation = 29\n"
},
{
"path": "rigl/rigl_tf2/configs/lottery.gin",
"chars": 544,
"preview": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500 # Log every 5"
},
{
"path": "rigl/rigl_tf2/configs/prune.gin",
"chars": 610,
"preview": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500 # Log every 5"
},
{
"path": "rigl/rigl_tf2/configs/rigl.gin",
"chars": 840,
"preview": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining."
},
{
"path": "rigl/rigl_tf2/configs/scratch.gin",
"chars": 646,
"preview": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining."
},
{
"path": "rigl/rigl_tf2/configs/set.gin",
"chars": 839,
"preview": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining."
},
{
"path": "rigl/rigl_tf2/configs/small_dense.gin",
"chars": 682,
"preview": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500 # Log every 5"
},
{
"path": "rigl/rigl_tf2/configs/snip.gin",
"chars": 787,
"preview": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining."
},
{
"path": "rigl/rigl_tf2/init_utils.py",
"chars": 5116,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rigl_tf2/interpolate.py",
"chars": 6584,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rigl_tf2/mask_updaters.py",
"chars": 14036,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rigl_tf2/metainit.py",
"chars": 4451,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rigl_tf2/mlp_configs/dense.gin",
"chars": 488,
"preview": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500 # Log every 5"
},
{
"path": "rigl/rigl_tf2/mlp_configs/lottery.gin",
"chars": 255,
"preview": "# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_path = '/tmp/sparse_spectru"
},
{
"path": "rigl/rigl_tf2/mlp_configs/prune.gin",
"chars": 537,
"preview": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500 # Log every 5"
},
{
"path": "rigl/rigl_tf2/mlp_configs/rigl.gin",
"chars": 485,
"preview": "training.use_metainit = False\n\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_i"
},
{
"path": "rigl/rigl_tf2/mlp_configs/scratch.gin",
"chars": 291,
"preview": "training.use_metainit = False\n\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_i"
},
{
"path": "rigl/rigl_tf2/mlp_configs/set.gin",
"chars": 483,
"preview": "training.use_metainit = False\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_in"
},
{
"path": "rigl/rigl_tf2/mlp_configs/small_dense.gin",
"chars": 610,
"preview": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500 # Log every 5"
},
{
"path": "rigl/rigl_tf2/networks.py",
"chars": 2777,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rigl_tf2/train.py",
"chars": 20547,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rigl_tf2/utils.py",
"chars": 8213,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/README.md",
"chars": 2285,
"preview": "# The State of Sparse Training in Deep Reinforcement Learning\n[**Paper**] [goo.gle/sparserl-paper](https://goo.gle/spars"
},
{
"path": "rigl/rl/dqn_agents.py",
"chars": 18416,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/requirements.txt",
"chars": 242,
"preview": "absl-py>=0.6.0\ndopamine-rl==4.0.5\ngin-config\nmujoco-py<2.2,>=2.1\nnumpy>=1.15.4\nsix>=1.12.0\ntensorflow==2.9.1 # change t"
},
{
"path": "rigl/rl/run.sh",
"chars": 835,
"preview": "# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use thi"
},
{
"path": "rigl/rl/run_experiment.py",
"chars": 8204,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/sparse_utils.py",
"chars": 8944,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_dense.gin",
"chars": 717,
"preview": "include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin'\n\nimport rigl.rl.dqn_agents\n\nDQNAgent.network = @dqn_agents."
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin",
"chars": 796,
"preview": "include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin'\n\nimport rigl.rl.dqn_agents\n\nDQNAgent.network = @dqn_agents."
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_prune.gin",
"chars": 678,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'prune'\n\n\nget_pruning_sparsities.target"
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_prune_impala_net.gin",
"chars": 689,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'prune'\n\n\nget_pruning_sparsi"
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_rigl.gin",
"chars": 856,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'rigl'\n\n# For sparse training methods w"
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_rigl_impala_net.gin",
"chars": 867,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'rigl'\n\n# For sparse trainin"
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_set.gin",
"chars": 850,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'set'\n\n# For sparse training methods we"
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_set_impala_net.gin",
"chars": 861,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'set'\n\n# For sparse training"
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_static.gin",
"chars": 440,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'static'\n\n# For sparse training methods"
},
{
"path": "rigl/rl/sparsetrain_configs/dqn_atari_static_impala_net.gin",
"chars": 451,
"preview": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'static'\n\n# For sparse train"
},
{
"path": "rigl/rl/tfagents/configs/dqn_gym_dense_config.gin",
"chars": 637,
"preview": "# Configs to run DQN training for dense networks on classic control environments.\n\ntrain_eval.env_name='CartPole-v0'\ntra"
},
{
"path": "rigl/rl/tfagents/configs/dqn_gym_pruning_config.gin",
"chars": 720,
"preview": "include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin'\n\n# Configs to run DQN training for pruning on classic contro"
},
{
"path": "rigl/rl/tfagents/configs/dqn_gym_sparse_config.gin",
"chars": 782,
"preview": "include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin'\n\n# Configs to run DQN training for static, set, and rigl on "
},
{
"path": "rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin",
"chars": 989,
"preview": "# Config to run training for dense on mujoco environments.\n\ntrain_eval.env_name='HalfCheetah-v2'\ntrain_eval.actor_fc_lay"
},
{
"path": "rigl/rl/tfagents/configs/ppo_mujoco_pruning_config.gin",
"chars": 649,
"preview": "include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin'\n\ntrain_eval.sparse_output_layer = True\ntrain_eval.train_m"
},
{
"path": "rigl/rl/tfagents/configs/ppo_mujoco_sparse_config.gin",
"chars": 817,
"preview": "include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin'\n\n# Config to run PPO training for static, set, and rigl o"
},
{
"path": "rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin",
"chars": 237,
"preview": "# Config to run SAC training for dense on mujoco environments.\n\ntrain_eval.env_name = 'Humanoid-v2'\ntrain_eval.initial_c"
},
{
"path": "rigl/rl/tfagents/configs/sac_mujoco_pruning_config.gin",
"chars": 742,
"preview": "include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin'\n\n# Configs to run SAC training for pruning on mujoco envi"
},
{
"path": "rigl/rl/tfagents/configs/sac_mujoco_sparse_config.gin",
"chars": 787,
"preview": "include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin'\n\n# Configs to run SAC training for static, set, and rigl "
},
{
"path": "rigl/rl/tfagents/dqn_train_eval.py",
"chars": 14815,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/ppo_train_eval.py",
"chars": 48348,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/sac_train_eval.py",
"chars": 27108,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/sparse_encoding_network.py",
"chars": 12976,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/sparse_ppo_actor_network.py",
"chars": 4277,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/sparse_ppo_discrete_actor_network.py",
"chars": 4899,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/sparse_ppo_discrete_actor_network_test.py",
"chars": 5127,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/sparse_tanh_normal_projection_network.py",
"chars": 2595,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/sparse_value_network.py",
"chars": 6938,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/tfagents/tf_sparse_utils.py",
"chars": 7944,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/rl/train.py",
"chars": 1820,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/sparse_optimizers.py",
"chars": 17956,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/sparse_optimizers_base.py",
"chars": 23625,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/sparse_optimizers_test.py",
"chars": 25606,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/sparse_utils.py",
"chars": 17807,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/sparse_utils_test.py",
"chars": 6085,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "rigl/str_sparsities.py",
"chars": 9969,
"preview": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you "
},
{
"path": "run.sh",
"chars": 765,
"preview": "# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use thi"
}
]
About this extraction
This page contains the full source code of the google-research/rigl GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 141 files (902.2 KB), approximately 232.9k tokens, and a symbol index with 698 extracted functions, classes, methods, constants, and types.
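Every entry in the manifest array above has the same `{path, chars, preview}` shape, so the listing is straightforward to post-process. A minimal sketch of querying such a manifest — the inline sample below is a hand-picked illustration in the same shape, not the full extracted array:

```python
import json

# A small inline sample mirroring the manifest's {path, chars, preview} entries.
manifest = json.loads("""
[
  {"path": "rigl/sparse_optimizers.py", "chars": 17956, "preview": "# coding=utf-8"},
  {"path": "rigl/sparse_utils.py", "chars": 17807, "preview": "# coding=utf-8"},
  {"path": "rigl/requirements.txt", "chars": 177, "preview": "absl-py>=0.6.0"}
]
""")

# Total extracted characters across the sampled entries.
total_chars = sum(entry["chars"] for entry in manifest)

# Entries sorted by size, largest first.
largest = sorted(manifest, key=lambda e: e["chars"], reverse=True)

print(total_chars)         # 35940
print(largest[0]["path"])  # rigl/sparse_optimizers.py
```

The same loop applied to the full array recovers the aggregate figures quoted below (141 files, 902.2 KB).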
Extracted by GitExtract, built by Nikandr Surkov.