Repository: google-research/rigl Branch: master Commit: d39fc7d46505 Files: 141 Total size: 902.2 KB Directory structure: gitextract_rhnta46k/ ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── rigl/ │ ├── __init__.py │ ├── cifar_resnet/ │ │ ├── data_helper.py │ │ ├── data_helper_test.py │ │ ├── resnet_model.py │ │ └── resnet_train_eval.py │ ├── experimental/ │ │ └── jax/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── analysis/ │ │ │ └── plot_summary_json.ipynb │ │ ├── datasets/ │ │ │ ├── __init__.py │ │ │ ├── cifar10.py │ │ │ ├── cifar10_test.py │ │ │ ├── dataset_base.py │ │ │ ├── dataset_base_test.py │ │ │ ├── dataset_factory.py │ │ │ ├── dataset_factory_test.py │ │ │ ├── mnist.py │ │ │ └── mnist_test.py │ │ ├── fixed_param.py │ │ ├── fixed_param_test.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── cifar10_cnn.py │ │ │ ├── cifar10_cnn_test.py │ │ │ ├── mnist_cnn.py │ │ │ ├── mnist_cnn_test.py │ │ │ ├── mnist_fc.py │ │ │ ├── mnist_fc_test.py │ │ │ ├── model_factory.py │ │ │ └── model_factory_test.py │ │ ├── prune.py │ │ ├── prune_test.py │ │ ├── pruning/ │ │ │ ├── __init__.py │ │ │ ├── init.py │ │ │ ├── init_test.py │ │ │ ├── mask_factory.py │ │ │ ├── mask_factory_test.py │ │ │ ├── masked.py │ │ │ ├── masked_test.py │ │ │ ├── pruning.py │ │ │ ├── pruning_test.py │ │ │ ├── symmetry.py │ │ │ └── symmetry_test.py │ │ ├── random_mask.py │ │ ├── random_mask_test.py │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── shuffled_mask.py │ │ ├── shuffled_mask_test.py │ │ ├── train.py │ │ ├── train_test.py │ │ ├── training/ │ │ │ ├── __init__.py │ │ │ ├── training.py │ │ │ └── training_test.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── imagenet_resnet/ │ │ ├── colabs/ │ │ │ ├── MobileNet_Counting.ipynb │ │ │ └── Resnet_50_Param_Flops_Counting.ipynb │ │ ├── imagenet_train_eval.py │ │ ├── mobilenetv1_model.py │ │ ├── mobilenetv2_model.py │ │ ├── pruning_layers.py │ │ ├── resnet_model.py │ │ ├── train_test.py │ │ ├── utils.py │ │ └── vgg.py │ ├── mnist/ │ │ ├── mnist_train_eval.py │ │ └── visualize_mask_records.py │ ├── requirements.txt │ ├── rigl_tf2/ │ │ ├── README.md │ │ ├── colabs/ │ │ │ └── MnistProp.ipynb │ │ ├── configs/ │ │ │ ├── dense.gin │ │ │ ├── grasp.gin │ │ │ ├── hessian.gin │ │ │ ├── interpolate.gin │ │ │ ├── lottery.gin │ │ │ ├── prune.gin │ │ │ ├── rigl.gin │ │ │ ├── scratch.gin │ │ │ ├── set.gin │ │ │ ├── small_dense.gin │ │ │ └── snip.gin │ │ ├── init_utils.py │ │ ├── interpolate.py │ │ ├── mask_updaters.py │ │ ├── metainit.py │ │ ├── mlp_configs/ │ │ │ ├── dense.gin │ │ │ ├── lottery.gin │ │ │ ├── prune.gin │ │ │ ├── rigl.gin │ │ │ ├── scratch.gin │ │ │ ├── set.gin │ │ │ └── small_dense.gin │ │ ├── networks.py │ │ ├── train.py │ │ └── utils.py │ ├── rl/ │ │ ├── README.md │ │ ├── dqn_agents.py │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── run_experiment.py │ │ ├── sparse_utils.py │ │ ├── sparsetrain_configs/ │ │ │ ├── dqn_atari_dense.gin │ │ │ ├── dqn_atari_dense_impala_net.gin │ │ │ ├── dqn_atari_prune.gin │ │ │ ├── dqn_atari_prune_impala_net.gin │ │ │ ├── dqn_atari_rigl.gin │ │ │ ├── dqn_atari_rigl_impala_net.gin │ │ │ ├── dqn_atari_set.gin │ │ │ ├── dqn_atari_set_impala_net.gin │ │ │ ├── dqn_atari_static.gin │ │ │ └── dqn_atari_static_impala_net.gin │ │ ├── tfagents/ │ │ │ ├── configs/ │ │ │ │ ├── dqn_gym_dense_config.gin │ │ │ │ ├── dqn_gym_pruning_config.gin │ │ │ │ ├── dqn_gym_sparse_config.gin │ │ │ │ ├── ppo_mujoco_dense_config.gin │ │ │ │ ├── ppo_mujoco_pruning_config.gin │ │ │ │ ├── ppo_mujoco_sparse_config.gin │ │ │ │ ├── 
sac_mujoco_dense_config.gin │ │ │ │ ├── sac_mujoco_pruning_config.gin │ │ │ │ └── sac_mujoco_sparse_config.gin │ │ │ ├── dqn_train_eval.py │ │ │ ├── ppo_train_eval.py │ │ │ ├── sac_train_eval.py │ │ │ ├── sparse_encoding_network.py │ │ │ ├── sparse_ppo_actor_network.py │ │ │ ├── sparse_ppo_discrete_actor_network.py │ │ │ ├── sparse_ppo_discrete_actor_network_test.py │ │ │ ├── sparse_tanh_normal_projection_network.py │ │ │ ├── sparse_value_network.py │ │ │ └── tf_sparse_utils.py │ │ └── train.py │ ├── sparse_optimizers.py │ ├── sparse_optimizers_base.py │ ├── sparse_optimizers_test.py │ ├── sparse_utils.py │ ├── sparse_utils_test.py │ └── str_sparsities.py └── run.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute We'd love to accept your patches and contributions to this project. - If you want to contribute to the library please check `Issues` tab and feel free to take on any problem/issue you find interesting. - If your `issue` is not reported yet, please create a new one. It is important to discuss the problem/request before implementing the solution. - Reach us at rigl.authors@gmail.com any time! ## Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

================================================
FILE: README.md
================================================

# Rigging the Lottery: Making All Tickets Winners

*(Figure: an 80% sparse ResNet-50.)*

**Paper**: [https://arxiv.org/abs/1911.11134](https://arxiv.org/abs/1911.11134)

**15min Presentation** [[pml4dc](https://pml4dc.github.io/iclr2020/program/pml4dc_7.html)] [[icml](https://icml.cc/virtual/2020/paper/5808)]

**ML Reproducibility Challenge 2020** [report](https://openreview.net/forum?id=riCIeP6LzEE)

## Colabs for Calculating FLOPs of Sparse Models

[MobileNet-v1](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb)

[ResNet-50](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb)

## Best Sparse Models

Parameters are floats, so each parameter is stored in 4 bytes. The uniform sparsity distribution keeps the first layer dense and therefore yields slightly larger model sizes and parameter counts.
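For the ERK (Erdos-Renyi-Kernel) rows below, per-layer density is proportional to `(k_h + k_w + c_in + c_out) / (k_h * k_w * c_in * c_out)`, scaled so that the network hits the global sparsity target. Here is a minimal NumPy sketch of that allocation rule as described in the paper; it is only an illustration, not the repo's actual implementation (which lives in `rigl/sparse_utils.py` behind `--mask_init_method=erdos_renyi_kernel`):

```python
import numpy as np

def erk_densities(kernel_shapes, target_density):
  """Per-layer ERK densities for kernel shapes like (k_h, k_w, c_in, c_out)."""
  n = {i: float(np.prod(s)) for i, s in enumerate(kernel_shapes)}
  raw = {i: sum(s) / np.prod(s) for i, s in enumerate(kernel_shapes)}
  dense = set()  # layers whose density gets capped at 1.0
  while True:
    # Choose eps so the total number of kept parameters matches the target.
    budget = target_density * sum(n.values()) - sum(n[i] for i in dense)
    eps = budget / sum(raw[i] * n[i] for i in n if i not in dense)
    over = [i for i in n if i not in dense and raw[i] * eps > 1.0]
    if not over:
      break
    dense.update(over)  # cap these layers at fully dense and re-solve
  return {i: 1.0 if i in dense else raw[i] * eps for i in n}

# e.g. two conv layers, 3x3x3x16 and 3x3x16x32, at 90% global sparsity:
print(erk_densities([(3, 3, 3, 16), (3, 3, 16, 32)], target_density=0.1))
```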
ERK applies sparsity to all layers, except in the 99% sparse model, where we set the first layer to be dense, since otherwise we observe much worse performance.

### Extended Training Results

Performance of RigL increases significantly with extended training iterations. In this section we extend the training of the sparse models by 5x. Note that sparse models require far fewer FLOPs per training iteration, and therefore most of the extended training runs cost fewer FLOPs than baseline dense training.

Observing these improvements, we wanted to understand where the performance of sparse networks saturates. The longest training we ran was 100x the length of the original 100-epoch ImageNet training. This training costs 5.8x the original dense training FLOPs, and the resulting 99% sparse ResNet-50 achieves an impressive 68.15% top-1 accuracy (vs. 61.86% with 5x training).

| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-----------------|-----------|--------------|
| - (DENSE) | 0 | 3.2e18 | 8.2e9 | 102.122 | 76.8 | - |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 77.17 | [link](https://storage.googleapis.com/gresearch/rigl/s80erk5x.tar.gz) |
| Uniform | 0.8 | 1.14x | 0.23x | 23.685 | 76.71 | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform5x.tar.gz) |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 76.42 | [link](https://storage.googleapis.com/gresearch/rigl/s90erk5x.tar.gz) |
| Uniform | 0.9 | 0.66x | 0.13x | 13.532 | 75.73 | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform5x.tar.gz) |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.63 | [link](https://storage.googleapis.com/gresearch/rigl/s95erk5x.tar.gz) |
| Uniform | 0.95 | 0.42x | 0.08x | 8.433 | 73.22 | [link](https://storage.googleapis.com/gresearch/rigl/s95uniform5x.tar.gz) |
| ERK | 0.965 | 0.45x | 0.09x | 6.904 | 72.77 | [link](https://storage.googleapis.com/gresearch/rigl/s965erk5x.tar.gz) |
| Uniform | 0.965 | 0.34x | 0.07x | 6.904 | 71.31 | [link](https://storage.googleapis.com/gresearch/rigl/s965uniform5x.tar.gz) |
| ERK | 0.99 | 0.29x | 0.05x | 4.354 | 61.86 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk5x.tar.gz) |
| ERK | 0.99 | 0.58x | 0.05x | 4.354 | 63.89 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk10x.tar.gz) |
| ERK | 0.99 | 2.32x | 0.05x | 4.354 | 66.94 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk40x.tar.gz) |
| ERK | **0.99** | 5.8x | 0.05x | 4.354 | **68.15** | [link](https://storage.googleapis.com/gresearch/rigl/s99erk100x.tar.gz) |
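The relative FLOPs columns are multiples of the dense baseline row, so absolute costs follow by simple multiplication; for example, using the dense row above:

```python
dense_train_flops = 3.2e18  # dense ResNet-50 training cost, from the table
dense_infer_flops = 8.2e9   # dense ResNet-50 per-example inference cost

# 80% sparse ERK with 5x training: 2.09x training / 0.42x inference.
print(f'{2.09 * dense_train_flops:.2e}')  # ~6.69e+18 total training FLOPs
print(f'{0.42 * dense_infer_flops:.2e}')  # ~3.44e+09 FLOPs per forward pass
```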
We also ran extended training runs with MobileNet-v1. Even training 100x longer, we were not able to saturate performance: training longer consistently achieved better results.

| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-----------------|-----------|--------------|
| - (DENSE) | 0 | 4.5e17 | 1.14e9 | 16.864 | 72.1 | - |
| ERK | 0.89 | 1.39x | 0.21x | 2.392 | 69.31 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk10x.tar.gz) |
| ERK | 0.89 | 2.79x | 0.21x | 2.392 | 70.63 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk50x.tar.gz) |
| Uniform | 0.89 | 1.25x | 0.09x | 2.392 | 69.28 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform10x.tar.gz) |
| Uniform | 0.89 | 6.25x | 0.09x | 2.392 | 70.25 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform50x.tar.gz) |
| Uniform | 0.89 | 12.5x | 0.09x | 2.392 | 70.59 | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform100x.tar.gz) |

### 1x Training Results

| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-----------------|-----------|--------------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.12 | [link](https://storage.googleapis.com/gresearch/rigl/s80erk1x.tar.gz) |
| Uniform | 0.8 | 0.23x | 0.23x | 23.685 | 74.60 | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform1x.tar.gz) |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.07 | [link](https://storage.googleapis.com/gresearch/rigl/s90erk1x.tar.gz) |
| Uniform | 0.9 | 0.13x | 0.13x | 13.532 | 72.02 | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform1x.tar.gz) |

### Results w/o label smoothing

| S. Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt |
|-----------------|-----------|----------------|-----------------|-----------------|-----------|--------------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.02 | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 76.17 | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_5x.tar.gz) |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.4 | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 75.9 | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_5x.tar.gz) |
| ERK | 0.95 | 0.13x | 0.12x | 8.399 | 70.39 | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_1x.tar.gz) |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.36 | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_5x.tar.gz) |

### Evaluating checkpoints

Download the checkpoints and run the evaluation on ERK checkpoints with the following:

```bash
python imagenet_train_eval.py --mode=eval_once --output_dir=path/to/ckpt/folder \
  --eval_once_ckpt_prefix=model.ckpt-3200000 --use_folder_stub=False \
  --training_method=rigl --mask_init_method=erdos_renyi_kernel \
  --first_layer_sparsity=-1
```

When running checkpoints with a uniform sparsity distribution, use `--mask_init_method=random` and `--first_layer_sparsity=0`. Set `--model_architecture=mobilenet_v1` when evaluating MobileNet checkpoints.

## Sparse Training Algorithms

In this repository we implement the following dynamic sparsity strategies:

1. [SET](https://www.nature.com/articles/s41467-018-04316-3): Implements Sparse Evolutionary Training (SET), which replaces low-magnitude connections randomly with new ones.
2. [SNFS](https://arxiv.org/abs/1907.04840): Implements momentum-based training, *without* sparsity re-distribution.
3. [RigL](https://arxiv.org/abs/1911.11134): Our method, RigL, removes a fraction of connections based on weight magnitudes and activates new ones using instantaneous gradient information (see the sketch below).

And the following one-shot pruning algorithm:

1. [SNIP](https://arxiv.org/abs/1810.02340): Single-shot Network Pruning based on connection sensitivity; prunes the least salient connections before training.
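As a rough illustration of one RigL mask update (drop the lowest-magnitude active weights, grow the highest-gradient inactive ones), here is a toy NumPy sketch for a single weight matrix. The real, TensorFlow-integrated logic is `SparseRigLOptimizer` in `rigl/sparse_optimizers.py`, and details such as gradient accumulation and scheduling differ:

```python
import numpy as np

def rigl_mask_update(weights, mask, grads, drop_fraction=0.3):
  """One toy drop/grow step on a weight matrix (illustrative only)."""
  n_update = int(drop_fraction * mask.sum())
  inactive_before = np.flatnonzero(mask == 0)
  # Drop: deactivate the active connections with the smallest magnitudes.
  active = np.flatnonzero(mask)
  drop = active[np.argsort(np.abs(weights.flat[active]))[:n_update]]
  mask.flat[drop] = 0
  # Grow: activate the previously inactive connections whose (dense)
  # gradients have the largest magnitudes.
  order = np.argsort(-np.abs(grads.flat[inactive_before]))
  grow = inactive_before[order[:n_update]]
  mask.flat[grow] = 1
  weights.flat[grow] = 0.0  # grown connections start at zero
  return weights * mask, mask
```

Initializing grown connections to zero matches the default `--grow_init=zeros` flag used by the training scripts in this repository.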
We have code for the following settings:

- [Imagenet2012](https://github.com/google-research/rigl/tree/master/rigl/imagenet_resnet): TPU-compatible code with ResNet-50 and MobileNet-v1/v2.
- [CIFAR-10](https://github.com/google-research/rigl/tree/master/rigl/cifar_resnet) with WideResNets.
- [MNIST](https://github.com/google-research/rigl/tree/master/rigl/mnist) with a 2-layer fully connected network.

## Setup

First clone this repo.

```bash
git clone https://github.com/google-research/rigl.git
cd rigl
```

We use the [Neurips 2019 MicroNet Challenge](https://micronet-challenge.github.io/) code for counting the operations and size of our networks. Let's clone the google_research repo and add the current folder to the Python path.

```bash
git clone https://github.com/google-research/google-research.git
mv google-research/ google_research/
export PYTHONPATH=$PYTHONPATH:$PWD
```

Now we can run some tests. The following script creates a virtual environment, installs the necessary libraries and, finally, runs a few tests.

```bash
bash run.sh
```

We need to activate the virtual environment before running an experiment. With that, we are ready to run some trivial MNIST experiments.

```bash
source env/bin/activate
python rigl/mnist/mnist_train_eval.py
```

You can load and verify the performance of the ResNet-50 checkpoints as follows.

```bash
python rigl/imagenet_resnet/imagenet_train_eval.py --mode=eval_once --training_method=baseline --eval_batch_size=100 --output_dir=/path/to/folder --eval_once_ckpt_prefix=s80_model.ckpt-1280000 --use_folder_stub=False
```

We use the [Official TPU Code](https://github.com/tensorflow/tpu/tree/master/models/official/resnet) for loading ImageNet data. First clone the tensorflow/tpu repo and then add the models/ folder to the Python path.

```bash
git clone https://github.com/tensorflow/tpu.git
export PYTHONPATH=$PYTHONPATH:$PWD/tpu/models/
```

## Other Implementations

- [Graphcore-TF-MNIST](https://github.com/graphcore/examples/tree/master/applications/tensorflow/dynamic_sparsity/mnist_rigl): with sparse matrix ops!
- [Pytorch implementation](https://github.com/McCrearyD/rigl-torch) by Dyllan McCreary.
- [Micrograd-Pure Python](https://evcu.github.io/ml/sparse-micrograd/): This is a toy example with a pure Python sparse implementation. Caution: very slow, but fun.

## Citation

```
@incollection{rigl,
  author = {Evci, Utku and Gale, Trevor and Menick, Jacob and Castro, Pablo Samuel and Elsen, Erich},
  booktitle = {Proceedings of Machine Learning and Systems 2020},
  pages = {471--481},
  title = {Rigging the Lottery: Making All Tickets Winners},
  year = {2020}
}
```

## Disclaimer

This is not an official Google product.

================================================
FILE: rigl/__init__.py
================================================

# coding=utf-8
# Copyright 2022 RigL Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
"""This repo contains the code for training sparse neural networks."""
name = 'rigl'

================================================
FILE: rigl/cifar_resnet/data_helper.py
================================================

# coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
"""Helper functions for the CIFAR10 data input pipeline. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
IMG_SIZE = 32
def pad_input(x, crop_dim=4): """Concatenates sides of image with pixels cropped from the border of image. Args: x: Input image float32 tensor. crop_dim: Number of pixels to crop from the edge of the image. Cropped pixels are then concatenated to the original image. Returns: x: input image float32 tensor. Transformed by padding edges with cropped pixels. """ x = tf.concat( [x[:crop_dim, :, :][::-1], x, x[-crop_dim:, :, :][::-1]], axis=0) x = tf.concat( [x[:, :crop_dim, :][:, ::-1], x, x[:, -crop_dim:, :][:, ::-1]], axis=1) return x
def preprocess_train(x, width, height): """Pre-processing applied to the training data set. Args: x: Input image float32 tensor. width: int specifying intended width in pixels of image after preprocessing. height: int specifying intended height in pixels of image after preprocessing. Returns: x: transformed input with random crops, flips and reflection. """ x = pad_input(x, crop_dim=4) x = tf.random_crop(x, [width, height, 3]) x = tf.image.random_flip_left_right(x) return x
def input_fn(params): """Provides batches of CIFAR data. Args: params: A dictionary with a set of arguments, namely: * batch_size (int32), specifies data points in a batch * data_split (string), designates train or eval * data_directory (string), specifies directory location of input dataset Returns: images: A float32 `Tensor` of size [batch_size, 32, 32, 3]. labels: An int32 `Tensor` of size [batch_size].
""" def parse_serialized_example(record): """Parses a CIFAR10 example.""" image = record['image'] label = tf.cast(record['label'], tf.int32) image = tf.cast(image, tf.float32) image = tf.image.per_image_standardization(image) if data_split == 'train': image = preprocess_train(image, IMG_SIZE, IMG_SIZE) return image, label data_split = params['data_split'] batch_size = params['batch_size'] if data_split == 'eval': data_split = 'test' dataset = tfds.load('cifar10:3.*.*', split=data_split) # we only repeat an example and shuffle inputs during training if data_split == 'train': dataset = dataset.repeat().shuffle(buffer_size=50000) # deserialize record into tensors and apply pre-processing. dataset = dataset.map(parse_serialized_example).prefetch(batch_size) # at test time, for the final batch we drop remaining examples so that no # example is seen twice. dataset = dataset.batch(batch_size) images_batch, labels_batch = tf.data.make_one_shot_iterator( dataset).get_next() return (tf.reshape(images_batch, [batch_size, IMG_SIZE, IMG_SIZE, 3]), tf.reshape(labels_batch, [batch_size])) ================================================ FILE: rigl/cifar_resnet/data_helper_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Tests for the data_helper input pipeline and the training process. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os from absl import flags from absl import logging import absl.testing.parameterized as parameterized from rigl.cifar_resnet import resnet_train_eval from rigl.cifar_resnet.data_helper import input_fn import tensorflow.compat.v1 as tf from tensorflow.contrib.model_pruning.python import pruning FLAGS = flags.FLAGS BATCH_SIZE = 1 NUM_IMAGES = 1 JITTER_MULTIPLIER = 2 class DataHelperTest(tf.test.TestCase, parameterized.TestCase): def get_next(self): data_directory = FLAGS.data_directory # we pass the updated eval and train string to the params dictionary. 
params = { 'mode': 'test', 'data_split': 'eval', 'batch_size': BATCH_SIZE, 'data_directory': data_directory } test_inputs, test_labels = input_fn(params) return test_inputs, test_labels def testInputPipeline(self): tf.reset_default_graph() g = tf.Graph() with g.as_default(): test_inputs, test_labels = self.get_next() with self.test_session() as sess: test_images_out, test_labels_out = sess.run([test_inputs, test_labels]) self.assertAllEqual(test_images_out.shape, [BATCH_SIZE, 32, 32, 3]) self.assertAllEqual(test_labels_out.shape, [BATCH_SIZE]) @parameterized.parameters( { 'training_method': 'baseline', }, { 'training_method': 'threshold', }, { 'training_method': 'rigl', }, ) def testTrainingStep(self, training_method): tf.reset_default_graph() g = tf.Graph() with g.as_default(): images, labels = self.get_next() global_step, _, _, logits = resnet_train_eval.build_model( mode='train', images=images, labels=labels, training_method=training_method, num_classes=FLAGS.num_classes, depth=FLAGS.resnet_depth, width=FLAGS.resnet_width) tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) total_loss = tf.losses.get_total_loss(add_regularization_losses=True) learning_rate = 0.1 opt = tf.train.MomentumOptimizer( learning_rate, momentum=FLAGS.momentum, use_nesterov=True) if training_method in ['threshold']: # Create a pruning object using the pruning hyperparameters pruning_obj = pruning.Pruning() logging.info('starting mask update op') mask_update_op = pruning_obj.conditional_mask_update_op() # Create the training op update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = opt.minimize(total_loss, global_step) init_op = tf.global_variables_initializer() with self.test_session() as sess: # test that we can train successfully for 1 step sess.run(init_op) for _ in range(1): sess.run(train_op) if training_method in ['threshold']: sess.run(mask_update_op) if __name__ == '__main__': tf.test.main() ================================================ FILE: rigl/cifar_resnet/resnet_model.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Model implementation of wide resnet model. Implements masking layer if pruning method is selected. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from rigl.imagenet_resnet.pruning_layers import sparse_conv2d from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected import tensorflow.compat.v1 as tf from tensorflow.contrib import layers as contrib_layers _BN_EPS = 1e-5 _BN_MOMENTUM = 0.9 class WideResNetModel(object): """Implements WideResNet model.""" def __init__(self, is_training, regularizer=None, data_format='channels_last', pruning_method='baseline', droprate=0.3, prune_first_layer=True, prune_last_layer=True): """WideResnet as described in https://arxiv.org/pdf/1605.07146.pdf. 
Args: is_training: Boolean, True during model training, False for evaluation/inference. regularizer: A regularization function (mapping variables to regularization losses), or None. data_format: A string that indicates whether the channels are the second or last index in the matrix. 'channels_first' or 'channels_last'. pruning_method: str, 'threshold' or 'baseline'. droprate: float, dropout rate to apply to activations. prune_first_layer: bool, if True the first layer is pruned. prune_last_layer: bool, if True the last layer is pruned. """ self._training = is_training self._regularizer = regularizer self._data_format = data_format self._pruning_method = pruning_method self._droprate = droprate self._prune_first_layer = prune_first_layer self._prune_last_layer = prune_last_layer if data_format == 'channels_last': self._channel_axis = -1 elif data_format == 'channels_first': self._channel_axis = 1
def build(self, inputs, depth, width, num_classes, name=None): """Model architecture to train the model. The configuration of the resnet blocks requires that depth should be 6n+4 where n is the number of resnet blocks desired. Args: inputs: A 4D float tensor containing the model inputs. depth: Number of convolutional layers in the network. width: Widening factor for the number of filters in the residual blocks. num_classes: Positive integer number of possible classes. name: Optional string, the name of the resulting op in the TF graph. Returns: A 2D float logits tensor of shape (batch_size, num_classes). Raises: ValueError: if (depth - 4) is not divisible by 6. """ if (depth - 4) % 6 != 0: raise ValueError('Depth of ResNet specified not sufficient.') resnet_blocks = (depth - 4) // 6 with tf.variable_scope(name, 'resnet_model'): first_layer_technique = self._pruning_method if not self._prune_first_layer: first_layer_technique = 'baseline' net = self._conv( inputs, 'conv_1', output_size=16, sparsity_technique=first_layer_technique) net = self._residual_block( net, 'conv_2', 16 * width, subsample=False, blocks=resnet_blocks) net = self._residual_block( net, 'conv_3', 32 * width, subsample=True, blocks=resnet_blocks) net = self._residual_block( net, 'conv_4', 64 * width, subsample=True, blocks=resnet_blocks) # Put the final BN and relu before the average pooling. with tf.name_scope('Pooling'): net = self._batch_norm(net) net = tf.nn.relu(net) net = tf.layers.average_pooling2d( net, pool_size=8, strides=1, data_format=self._data_format) net = contrib_layers.flatten(net) last_layer_technique = self._pruning_method if not self._prune_last_layer: last_layer_technique = 'baseline' net = self._dense( net, num_classes, 'logits', sparsity_technique=last_layer_technique) return net
def _batch_norm(self, net, name=None): """Adds batchnorm to the model. Input gradients cannot be computed with fused batch norm; this causes a recursive loop of tf.gradient calls. If a regularizer is specified, fused batchnorm must be set to False (the default setting). Args: net: Pre-batch norm tensor activations. name: Specified name for batch normalization layer. Returns: batch norm layer: Activations from the batch normalization layer.
""" return tf.layers.batch_normalization( inputs=net, fused=False, training=self._training, axis=self._channel_axis, momentum=_BN_MOMENTUM, epsilon=_BN_EPS, name=name) def _dense(self, net, num_units, name=None, sparsity_technique='baseline'): return sparse_fully_connected( x=net, units=num_units, sparsity_technique=sparsity_technique, kernel_regularizer=self._regularizer, name=name) def _conv(self, net, name, output_size, strides=(1, 1), padding='SAME', sparsity_technique='baseline'): """returns conv layer.""" return sparse_conv2d( x=net, units=output_size, activation=None, kernel_size=[3, 3], use_bias=False, kernel_initializer=None, kernel_regularizer=self._regularizer, bias_initializer=None, biases_regularizer=None, sparsity_technique=sparsity_technique, normalizer_fn=None, strides=strides, padding=padding, data_format=self._data_format, name=name) def _residual_block(self, net, name, output_size, subsample, blocks): """Adds a residual block to the model.""" with tf.name_scope(name): for n in range(blocks): with tf.name_scope('res_%d' % n): # when subsample is true + first block a larger stride is used. if subsample and n == 0: strides = [2, 2] else: strides = [1, 1] # Create the skip connection skip = net end_point = 'skip_%s' % name net = self._batch_norm(net) net = tf.nn.relu(net) if net.get_shape()[3].value != output_size: skip = sparse_conv2d( x=net, units=output_size, activation=None, kernel_size=[1, 1], use_bias=False, kernel_initializer=None, kernel_regularizer=self._regularizer, bias_initializer=None, biases_regularizer=None, sparsity_technique=self._pruning_method, normalizer_fn=None, strides=strides, padding='VALID', data_format=self._data_format, name=end_point) # Create residual net = self._conv( net, '%s_%d_1' % (name, n), output_size, strides, sparsity_technique=self._pruning_method) net = self._batch_norm(net) net = tf.nn.relu(net) net = tf.keras.layers.Dropout(self._droprate)(net, self._training) net = self._conv( net, '%s_%d_2' % (name, n), output_size, sparsity_technique=self._pruning_method) # Combine the residual and the skip connection net += skip return net ================================================ FILE: rigl/cifar_resnet/resnet_train_eval.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""This script trains a ResNet model that implements various pruning methods. 
Specify the pruning method to use during training via FLAGS.training_method. - To train a model with no pruning, specify FLAGS.training_method='baseline'. Specify the desired end sparsity using FLAGS.end_sparsity. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
from rigl import sparse_optimizers
from rigl import sparse_utils
from rigl.cifar_resnet.data_helper import input_fn
from rigl.cifar_resnet.resnet_model import WideResNetModel
from rigl.imagenet_resnet import utils
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import training as contrib_training
from tensorflow.contrib.model_pruning.python import pruning
flags.DEFINE_string('master', 'local', 'BNS name of the TensorFlow runtime to use.')
flags.DEFINE_integer('ps_task', 0, 'Task id of the replica running the training.')
flags.DEFINE_integer('keep_checkpoint_max', 5, 'Number of checkpoints to save, set 0 for all.')
flags.DEFINE_string('pruning_hparams', '', 'Comma separated list of pruning-related hyperparameters')
flags.DEFINE_string('train_dir', '/tmp/cifar10/', 'Directory where to write event logs and checkpoint.')
flags.DEFINE_string( 'load_mask_dir', '', 'Directory of a trained model from which to load only the mask')
flags.DEFINE_string( 'initial_value_checkpoint', '', 'Directory of a model from which to load only the parameters')
flags.DEFINE_integer( 'seed', default=0, help=('Sets the random seed.'))
flags.DEFINE_float('momentum', 0.9, 'The momentum value.')
# 250 Epochs
flags.DEFINE_integer('max_steps', 97656, 'Number of steps to run.')
flags.DEFINE_float('l2', 5e-4, 'Scale factor for L2 weight decay.')
flags.DEFINE_integer('resnet_depth', 16, 'Number of core convolutional layers ' 'in the network.')
flags.DEFINE_integer('resnet_width', 4, 'Width of the residual blocks.')
flags.DEFINE_string( 'data_directory', '', 'data directory where cifar10 records are stored')
flags.DEFINE_integer('num_classes', 10, 'Number of classes.')
flags.DEFINE_integer('dataset_size', 50000, 'Size of training dataset.')
flags.DEFINE_integer('batch_size', 128, 'Batch size.')
flags.DEFINE_integer('checkpoint_steps', 5000, 'Specifies step interval for ' 'saving model checkpoints.')
flags.DEFINE_integer( 'summaries_steps', 300, 'Specifies interval in steps for ' 'saving model summaries.')
flags.DEFINE_bool('per_class_metrics', True, 'Whether to add per-class ' 'performance summaries.')
flags.DEFINE_enum('mode', 'train', ('train_and_eval', 'train', 'eval'), 'String that specifies either inference or training')
# pruning flags
flags.DEFINE_integer('sparsity_begin_step', 20000, 'Step to begin pruning at.')
flags.DEFINE_integer('sparsity_end_step', 75000, 'Step to end pruning at.')
flags.DEFINE_integer('pruning_frequency', 1000, 'Step interval between pruning steps.')
flags.DEFINE_float('end_sparsity', 0.9, 'Target sparsity desired by end of training.')
flags.DEFINE_enum( 'training_method', 'baseline', ('scratch', 'set', 'baseline', 'momentum', 'rigl', 'static', 'snip', 'prune'), 'Method used for training sparse network. `scratch` means initial mask is ' 'kept during training. `set` is for sparse evolutionary training and ' '`baseline` is for dense baseline.')
flags.DEFINE_bool('prune_first_layer', False, 'Whether or not to apply sparsification to the first layer')
flags.DEFINE_bool('prune_last_layer', True, 'Whether or not to apply sparsification to the last layer')
flags.DEFINE_float('drop_fraction', 0.3, 'When changing the mask dynamically, this fraction decides how ' 'much of the connections are dropped and regrown at each update.')
flags.DEFINE_string('drop_fraction_anneal', 'constant', 'If not empty the drop fraction is annealed during sparse' ' training. One of the following: `constant`, `cosine` or ' '`exponential_(\\d*\\.?\\d*)$`. For example: ' '`exponential_3`, `exponential_.3`, `exponential_0.3`. ' 'The number after `exponential` defines the exponent.')
flags.DEFINE_string('grow_init', 'zeros', 'Passed to the SparseInitializer, one of: zeros, ' 'initial_value, random_normal, random_uniform.')
flags.DEFINE_float('s_momentum', 0.9, 'Momentum values for exponential moving average of ' 'gradients. Used when training_method="momentum".')
flags.DEFINE_float('rigl_acc_scale', 0., 'Used to scale initial accumulated gradients for new ' 'connections.')
flags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.')
flags.DEFINE_integer('maskupdate_end_step', 75000, 'Step to end mask updates.')
flags.DEFINE_integer('maskupdate_frequency', 100, 'Step interval between mask updates.')
flags.DEFINE_string( 'mask_init_method', default='random', help='If not empty string and mask is not loaded from a checkpoint, ' 'indicates the method used for mask initialization. One of the following: ' '`random`, `erdos_renyi`.')
flags.DEFINE_float('training_steps_multiplier', 1.0, 'Training schedule is shortened or extended with the ' 'multiplier, if it is not 1.')
FLAGS = flags.FLAGS
PARAM_SUFFIXES = ('gamma', 'beta', 'weights', 'biases')
MASK_SUFFIX = 'mask'
CLASSES = [ 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck' ]
def create_eval_metrics(labels, logits): """Creates the evaluation metrics for the model.""" eval_metrics = {} label_keys = CLASSES predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32) eval_metrics['eval_accuracy'] = tf.metrics.accuracy( labels=labels, predictions=predictions) if FLAGS.per_class_metrics: with tf.name_scope('class_level_summaries') as scope: for i in range(len(label_keys)): labels = tf.cast(labels, tf.int64) name = scope + '/' + label_keys[i] eval_metrics[('class_level_summaries/precision/' + label_keys[i])] = tf.metrics.precision_at_k( labels=labels, predictions=logits, class_id=i, k=1, name=name) eval_metrics[('class_level_summaries/recall/' + label_keys[i])] = tf.metrics.recall_at_k( labels=labels, predictions=logits, class_id=i, k=1, name=name) return eval_metrics
def train_fn(training_method, global_step, total_loss, train_dir, accuracy, top_5_accuracy): """Training function for the resnet model. Args: training_method: specifies the method used to sparsify networks. global_step: the current step of training/eval. total_loss: tensor float32 of the cross entropy + regularization losses. train_dir: string specifying the directory where summaries are saved. accuracy: tensor float32 batch classification accuracy. top_5_accuracy: tensor float32 batch classification accuracy (top-5 classes). Returns: hooks: summary tensors to be computed at each training step. eval_metrics: set to None during training. train_op: the optimization term. """ # The learning rate roughly drops every 30k steps.
boundaries = [30000, 60000, 90000] if FLAGS.training_steps_multiplier != 1.0: multiplier = FLAGS.training_steps_multiplier boundaries = [int(x * multiplier) for x in boundaries] tf.logging.info( 'Learning Rate boundaries are updated with multiplier:%.2f', multiplier) learning_rate = tf.train.piecewise_constant( global_step, boundaries, values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)], name='lr_schedule') optimizer = tf.train.MomentumOptimizer( learning_rate, momentum=FLAGS.momentum, use_nesterov=True) if training_method == 'set': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseSETOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif training_method == 'static': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseStaticOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif training_method == 'momentum': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseMomentumOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, grow_init=FLAGS.grow_init, drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False) elif training_method == 'rigl': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseRigLOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal, initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False) elif training_method == 'snip': optimizer = sparse_optimizers.SparseSnipOptimizer( optimizer, mask_init_method=FLAGS.mask_init_method, default_sparsity=FLAGS.end_sparsity, use_tpu=False) elif training_method in ('scratch', 'baseline', 'prune'): pass else: raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method) # Create the training op update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(total_loss, global_step) if training_method == 'prune': # construct the necessary hparams string from the FLAGS hparams_string = ('begin_pruning_step={0},' 'sparsity_function_begin_step={0},' 'end_pruning_step={1},' 'sparsity_function_end_step={1},' 'target_sparsity={2},' 'pruning_frequency={3},' 'threshold_decay=0,' 'use_tpu={4}'.format( FLAGS.sparsity_begin_step, FLAGS.sparsity_end_step, FLAGS.end_sparsity, FLAGS.pruning_frequency, False, )) # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string) # Create a pruning object using the pruning hyperparameters pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) tf.logging.info('starting mask update op') # We override the train op to also update the mask. 
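# Note: tf.contrib.model_pruning anneals the sparsity from 0 to
# target_sparsity between begin_pruning_step and end_pruning_step using its
# default cubic schedule (Zhu & Gupta, 2017):
#   sparsity(t) = target_sparsity * (1 - (1 - t/n)^3),
# where t is the number of steps since begin_pruning_step and n is the
# begin-to-end span; conditional_mask_update_op() applies it every
# pruning_frequency steps.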
with tf.control_dependencies([train_op]): train_op = pruning_obj.conditional_mask_update_op() masks = pruning.get_masks() mask_metrics = utils.mask_summaries(masks) for name, tensor in mask_metrics.items(): tf.summary.scalar(name, tensor) tf.summary.scalar('learning_rate', learning_rate) tf.summary.scalar('accuracy', accuracy) tf.summary.scalar('total_loss', total_loss) tf.summary.scalar('top_5_accuracy', top_5_accuracy)
# Logging drop_fraction if dynamic sparse training.
if training_method in ('set', 'momentum', 'rigl', 'static'): tf.summary.scalar('drop_fraction', optimizer.drop_fraction)
summary_op = tf.summary.merge_all() summary_hook = tf.train.SummarySaverHook( save_secs=300, output_dir=train_dir, summary_op=summary_op) hooks = [summary_hook] eval_metrics = None return hooks, eval_metrics, train_op
def build_model(mode, images, labels, training_method='baseline', num_classes=10, depth=10, width=4): """Builds the wide ResNet model for training or eval. If a regularizer is specified, a regularizer term is added to the loss function. The regularizer term is computed using either the pre-softmax activation or an auxiliary network logits layer based upon activations earlier in the network after the first resnet block. Args: mode: String for whether training or evaluation is taking place. images: A 4D float32 tensor containing the model input images. labels: An int32 tensor of size (batch_size,) containing the model labels. training_method: The method used to sparsify the network weights. num_classes: The number of distinct labels in the dataset. depth: Number of core convolutional layers in the network. width: The widening factor for the convolutional filters in the resnet blocks. Returns: global_step: The current global step tensor. accuracy: A float32 accuracy tensor. top_5_accuracy: A float32 top-5 accuracy tensor. logits: A 2D float32 logits tensor. Raises: ValueError: if (depth - 4) is not divisible by 6. """ regularizer_term = tf.constant(FLAGS.l2, tf.float32) kernel_regularizer = contrib_layers.l2_regularizer(scale=regularizer_term)
# depth should be 6n+4 where n is the desired number of resnet blocks
# e.g. n=1,depth=10 n=2,depth=16 n=3,depth=22 n=5,depth=34 n=7,depth=46
if (depth - 4) % 6 != 0: raise ValueError('Depth of ResNet specified not sufficient.')
if mode == 'train': is_training = True else: is_training = False
# 'threshold' would create layers with mask.
pruning_method = 'baseline' if training_method == 'baseline' else 'threshold'
model = WideResNetModel( is_training=is_training, regularizer=kernel_regularizer, data_format='channels_last', pruning_method=pruning_method, prune_first_layer=FLAGS.prune_first_layer, prune_last_layer=FLAGS.prune_last_layer)
logits = model.build( images, depth=depth, width=width, num_classes=num_classes)
global_step = tf.train.get_or_create_global_step() predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32) accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32)) in_top_5 = tf.cast( tf.nn.in_top_k(predictions=logits, targets=labels, k=5), tf.float32) top_5_accuracy = tf.cast(tf.reduce_mean(in_top_5), tf.float32)
return global_step, accuracy, top_5_accuracy, logits
def wide_resnet_w_pruning(features, labels, mode, params): """The model_fn for wide ResNet with pruning. Args: features: A float32 batch of images. labels: An int32 batch of labels. mode: Specifies whether training or evaluation. params: Dictionary of parameters passed to the model.
Returns: An EstimatorSpec for the model. Raises: ValueError: if mode is not recognized as train or eval. """ if isinstance(features, dict): features = features['feature']
train_dir = params['train_dir'] training_method = params['training_method']
global_step, accuracy, top_5_accuracy, logits = build_model( mode=mode, images=features, labels=labels, training_method=training_method, num_classes=FLAGS.num_classes, depth=FLAGS.resnet_depth, width=FLAGS.resnet_width)
if mode == tf_estimator.ModeKeys.PREDICT: predictions = { 'classes': tf.argmax(logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } return tf_estimator.EstimatorSpec( mode=mode, predictions=predictions, export_outputs={ 'classify': tf_estimator.export.PredictOutput(predictions) })
with tf.name_scope('computing_cross_entropy_loss'): entropy_loss = tf.losses.sparse_softmax_cross_entropy( labels=labels, logits=logits) tf.summary.scalar('cross_entropy_loss', entropy_loss)
with tf.name_scope('computing_total_loss'): total_loss = tf.losses.get_total_loss(add_regularization_losses=True)
if mode == tf_estimator.ModeKeys.TRAIN: hooks, eval_metrics, train_op = train_fn(training_method, global_step, total_loss, train_dir, accuracy, top_5_accuracy) elif mode == tf_estimator.ModeKeys.EVAL: hooks = None train_op = None with tf.name_scope('summaries'): eval_metrics = create_eval_metrics(labels, logits) else: raise ValueError('mode not recognized as training or eval.')
# If given, load parameter values.
if FLAGS.initial_value_checkpoint: tf.logging.info('Loading initial values from: %s', FLAGS.initial_value_checkpoint) utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint, FLAGS.train_dir, PARAM_SUFFIXES)
# Load or randomly initialize masks.
if (FLAGS.load_mask_dir and FLAGS.training_method not in ('snip', 'baseline', 'prune')): # Init masks. tf.logging.info('Loading masks from %s', FLAGS.load_mask_dir) utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir, FLAGS.train_dir, MASK_SUFFIX) scaffold = tf.train.Scaffold() elif (FLAGS.mask_init_method and FLAGS.training_method not in ('snip', 'baseline', 'scratch', 'prune')): tf.logging.info('Initializing masks using method: %s', FLAGS.mask_init_method) all_masks = pruning.get_masks() assigner = sparse_utils.get_mask_init_fn( all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, {}) def init_fn(scaffold, session): """A callable for restoring variables from a checkpoint.""" del scaffold # Unused. session.run(assigner) scaffold = tf.train.Scaffold(init_fn=init_fn) else: assert FLAGS.training_method in ('snip', 'baseline', 'prune') scaffold = None tf.logging.info('No mask is set, starting dense.')
return tf_estimator.EstimatorSpec( mode=mode, training_hooks=hooks, loss=total_loss, train_op=train_op, eval_metric_ops=eval_metrics, scaffold=scaffold)
def main(argv): del argv # Unused. tf.set_random_seed(FLAGS.seed) if FLAGS.training_steps_multiplier != 1.0: multiplier = FLAGS.training_steps_multiplier FLAGS.max_steps = int(FLAGS.max_steps * multiplier) FLAGS.maskupdate_begin_step = int(FLAGS.maskupdate_begin_step * multiplier) FLAGS.maskupdate_end_step = int(FLAGS.maskupdate_end_step * multiplier) FLAGS.sparsity_begin_step = int(FLAGS.sparsity_begin_step * multiplier) FLAGS.sparsity_end_step = int(FLAGS.sparsity_end_step * multiplier) tf.logging.info( 'Training schedule is updated with multiplier: %.2f', multiplier)
# Configures train directories based upon the hyperparameters used.
if FLAGS.training_method == 'prune': folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity), str(FLAGS.sparsity_begin_step), str(FLAGS.sparsity_end_step), str(FLAGS.pruning_frequency)) elif FLAGS.training_method in ('set', 'momentum', 'rigl', 'static'): folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity), str(FLAGS.maskupdate_begin_step), str(FLAGS.maskupdate_end_step), str(FLAGS.maskupdate_frequency)) elif FLAGS.training_method in ('baseline', 'snip', 'scratch'): folder_stub = os.path.join(FLAGS.training_method, str(0.0), str(0.0), str(0.0), str(0.0)) else: raise ValueError('Training method is not known: %s' % FLAGS.training_method) train_dir = os.path.join(FLAGS.train_dir, folder_stub) # We pass the updated eval and train directories to the params dictionary. params = {} params['train_dir'] = train_dir params['data_split'] = FLAGS.mode params['batch_size'] = FLAGS.batch_size params['data_directory'] = FLAGS.data_directory params['mode'] = FLAGS.mode params['training_method'] = FLAGS.training_method run_config = tf_estimator.RunConfig( model_dir=train_dir, keep_checkpoint_max=FLAGS.keep_checkpoint_max, save_summary_steps=FLAGS.summaries_steps, save_checkpoints_steps=FLAGS.checkpoint_steps, log_step_count_steps=100) classifier = tf_estimator.Estimator( model_fn=wide_resnet_w_pruning, model_dir=train_dir, config=run_config, params=params) if FLAGS.mode == 'eval': eval_steps = 10000 // FLAGS.batch_size # Run evaluation whenever there's a new checkpoint. for ckpt in contrib_training.checkpoints_iterator(train_dir): print('Starting to evaluate.') try: classifier.evaluate( input_fn=input_fn, steps=eval_steps, checkpoint_path=ckpt, name='eval') # Terminate the eval job when the final checkpoint is reached. global_step = int(os.path.basename(ckpt).split('-')[1]) if global_step >= FLAGS.max_steps: print('Evaluation finished after training step %d' % global_step) break except tf.errors.NotFoundError: print('Checkpoint no longer exists, skipping checkpoint.') else: print('Starting training...') if FLAGS.mode == 'train': classifier.train(input_fn=input_fn, max_steps=FLAGS.max_steps) if __name__ == '__main__': tf.app.run(main) ================================================ FILE: rigl/experimental/jax/README.md ================================================ # Weight Symmetry Research Code This code is mostly written by Yani Ioannou. ## Experiment Summary There are a number of experiment drivers defined in the base directory: ### Experiment Types {#experiment-types} random_mask : Random Variable Sparsity Masks : This experiment generates random masks of a given type (see [Mask Types](#mask-types)) within a *given sparsity range*, and trains the models, tracking mask statistics and training details. Masks are generated with a random number of connections and randomly shuffled. shuffled_mask : Random Fixed Sparsity Masks : This experiment generates random masks of a given type (see [Mask Types](#mask-types)) *of a fixed sparsity*, and trains the models, tracking mask statistics and training details. Masks are generated with a fixed number of connections and simply shuffled. fixed_param : Train models with an (approximately) fixed number of parameters, but varying depth/width. : Train models with an (approximately) fixed number of parameters, but varying depth/width, with a shuffled mask (as in the shuffled_mask driver), and only the MNIST_FC model type.
prune : Simple Pruning/Training Driver : This experiment trains a dense model, pruning it either iteratively or one-shot, tracking mask statistics and training details. train : Simple Training Driver (Without Masking/Pruning) : This experiment simply trains a dense model, tracking mask statistics and training details. ### Mask Types {#mask-types} symmetric : Structured Mask. : The mask is a structured random mask. random : Unstructured Mask. : The mask as a whole is a random mask of a given sparsity, with some neurons having fewer/more connections than others. per-neuron : Unstructured Mask. : Each neuron has the same sparsity (# of masked connections), but is shuffled randomly. per-neuron-no-input-ablation : Unstructured Mask. : As with per-neuron, each neuron has the same sparsity, but randomly shuffled connections. Also, at least one connection is maintained to each of the input neurons (i.e. the input neurons are not effectively ablated), although these connections are also randomly shuffled amongst the neurons of a given layer. ### Model Types {#model-types} MNIST_FC : A small fully-connected model, accepting number of neurons and depth as parameters. No batch normalization, configurable drop-out rate (default: 0). MNIST_CNN : A small convolutional model designed for MNIST, accepting number of filters for each layer and depth as parameters. Uses batch normalization and configurable drop-out rate (default: 0). CIFAR10_CNN : A larger convolutional model designed for CIFAR10, accepting number of filters for each layer and depth as parameters. No batch normalization, configurable drop-out rate (default: 0). ### Dataset Types {#dataset-types} MNIST : Wrapper of the Tensorflow Datasets (TFDS) MNIST dataset. CIFAR10 : Wrapper of the Tensorflow Datasets (TFDS) CIFAR10 dataset. ## Running Experiments ### Running on a Workstation Train (where `EXPERIMENT_TYPE` is one of the drivers listed in [Experiment Types](#experiment-types)): ```shell python -m rigl.experimental.jax.${EXPERIMENT_TYPE} ``` ## Result Processing/Analysis ### Plotting Results from a JSON Summary File You can convert the results from a JSON summary file to a Pandas dataframe for plotting/analysis using the example colab in `analysis/plot_summary_json.ipynb`. ================================================ FILE: rigl/experimental/jax/__init__.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
"""This module contains code for weight symmetry experiments.""" name = 'weight_symmetry' ================================================ FILE: rigl/experimental/jax/analysis/plot_summary_json.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "6iEEw5OwSlnz" }, "source": [ "# Plot Results from an Experiment Summary JSON File", "Licensed under the Apache License, Version 2.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Eg6FmoCaTCHM" }, "source": [ "## Parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "ML0hUJMzYF0W" }, "outputs": [], "source": [ "from google.colab import files\n", "\n", "# Experiment summary filenames (one per experiment)\n", "SUMMARY_FILES = files.upload()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "MHubbscQSLGm" }, "outputs": [], "source": [ "# Labels to use for each of the summaries listed above (in the same order!)\n", "XID_LABELS=['structured', 'unstructured'] #@param" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "x0jDBWKdU_2A" }, "source": [ "## Loading of JSON Summary/Conversion to Pandas Dataframe" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "Lz-HwS1tU-ie" }, "outputs": [], "source": [ "import json\n", "import pandas as pd\n", "import os\n", "\n", "from colabtools.interactive_widgets import ProgressIter\n", "\n", "dfs = []\n", "for i, summary_file in enumerate(SUMMARY_FILES):\n", " with open(summary_file) as summary_file:\n", " data = json.load(summary_file)\n", " dataframe = pd.DataFrame.from_dict(data, orient='index')\n", " dataframe['experiment_label'] = XID_LABELS[i]\n", " dfs.append(dataframe)\n", "\n", "df=pd.concat(dfs)\n", "\n", "print('Loaded {} rows for experiment'.format(len(data)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DhO6oT1nVpTV" }, "source": [ "## Measurements and Labels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "XFRR3XrXVopB" }, "outputs": [], "source": [ "DATA_LABELS={\n", " 'best_train_loss/test_accuracy': 'Test Accuracy (of best train loss)',\n", " 'best_train_loss/train_accuracy': 'Train Accuracy (of best train loss)',\n", " 'best_train_loss/test_avg_loss': 'Test Loss (of best train loss)',\n", " 'best_train_loss/train_avg_loss': 'Train Loss (of best train loss)',\n", " 'best_train_loss/step': 'Training Iterations (of best train loss)',\n", " 'best_train_loss/cumulative_gradient_norm': 'Cumulative Gradient Norm. (of best train loss)',\n", " 'best_train_loss/vector_difference_norm': 'Vector Difference Norm. (of best train loss)',\n", " 'best_train_loss/cosine_distance': 'Cosine Similarity (of best train loss)',\n", " 'best_test_acc/test_accuracy': 'Test Accuracy (of best test acc.)',\n", " 'best_test_acc/train_accuracy': 'Train Accuracy (of best test acc.)',\n", " 'best_test_acc/test_avg_loss': 'Test Loss (of best test acc.)',\n", " 'best_test_acc/train_avg_loss': 'Train Loss (of best test acc.)',\n", " 'best_test_acc/step': 'Training Iterations (of best test acc.)',\n", " 'best_test_acc/cumulative_gradient_norm': 'Cumulative Gradient Norm. 
(of best Test Acc.)',\n", " 'best_test_acc/cosine_distance': 'Cosine Similarity (of best Test Acc.)',\n", " 'best_test_acc/vector_difference_norm': 'Vector Difference Norm. (of best Test Acc.)',\n", " 'mask/sparsity': 'Sparsity',\n", " 'mask/unique_neurons': '# Unique Neurons',\n", " 'mask/zeroed_neurons': '# Zeroed Neurons',\n", " 'mask/permutation_log10': 'log10(1 + Permutations)',\n", " 'mask/permutation_num_digits': 'Permutation # of Digits',\n", " 'mask/permutations': 'Permutation',\n", " 'mask/total_neurons': 'Total # of Neurons',\n", " 'propagated_mask/sparsity': 'Mask Sparsity',\n", " 'propagated_mask/unique_neurons': '# Unique Neurons (prop.)',\n", " 'propagated_mask/zeroed_neurons': '# Zeroed Neurons (prop.)',\n", " 'propagated_mask/permutation_log10': 'log10(1 + Permutations) (prop.)',\n", " 'propagated_mask/permutation_num_digits': 'Permutation # of Digits (prop.)',\n", " 'propagated_mask/permutations': 'Mask Permutations',\n", " 'propagated_mask/total_neurons': 'Total # of Neurons (prop.)',\n", " 'training/train_avg_loss': 'Train Loss',\n", "}" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HAVkz8ZzV0Hd" }, "source": [ "# Seaborn Plot Example" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "eoxoJH4gWHbb" }, "outputs": [], "source": [ "# Choose the X/Y/Z labels from the parameter list above.\n", "X_LABEL='propagated_mask/sparsity' #@param {type:\"string\"}\n", "Y_LABEL='best_train_loss/cumulative_gradient_norm' #@param {type:\"string\"}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "pudAXLl1VzFl" }, "outputs": [], "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# Seaborn style - remove outer plot ticks, white plot background.\n", "np.set_printoptions(linewidth=128, precision=3, edgeitems=5)\n", "sns.set_style(\"whitegrid\")\n", "sns.color_palette(\"muted\")\n", "sns.set_context(\"paper\", font_scale=1, rc={\n", " \"lines.linewidth\": 1.2,\n", " \"xtick.major.size\": 0,\n", " \"xtick.minor.size\": 0,\n", " \"ytick.major.size\": 0,\n", " \"ytick.minor.size\": 0\n", "})\n", "\n", "# Higher resolution plots\n", "%config InlineBackend.figure_format = 'retina'" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "lYUK9xi_aym3" }, "source": [ "### Plot Raw Data Points" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "uWcT76L6Wbv6" }, "outputs": [], "source": [ "\n", "plt.figure(figsize=(16,8))\n", "axis = sns.scatterplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', s=50, alpha=.5)\n", "axis.set_ylabel(DATA_LABELS[Y_LABEL])\n", "axis.set_xlabel(DATA_LABELS[X_LABEL])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Kws6tjfTa7h0" }, "source": [ "### Plot Mean/StdDev" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "jR04tmMnaxjG" }, "outputs": [], "source": [ "plt.figure(figsize=(16,8))\n", "axis = sns.lineplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', alpha=.5, ci=\"sd\", markers=True)\n", "axis.set_ylabel(DATA_LABELS[Y_LABEL])\n", "axis.set_xlabel(DATA_LABELS[X_LABEL])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "jyNFtKQfajiq" }, "outputs": [], "source": [ "# Code to save output files for publication.\n", 
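"# NB: with the inline backend a figure may be closed once its cell has run;\n", "# if the saved files come out empty, call plt.savefig in the same cell that draws the plot.\n",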
"PARAM_STR=X_LABEL.replace('/', '-')+'_'+Y_LABEL.replace('/', '-')\n", "\n", "OUT_FILE_PDF=f'/tmp/{PARAM_STR}.pdf'\n", "OUT_FILE_SVG=f'/tmp/{PARAM_STR}.svg'\n", "OUT_FILE_PNG=f'/tmp/{PARAM_STR}.png'\n", "\n", "plt.savefig(OUT_FILE_PDF, pi=600)\n", "files.download(OUT_FILE_PDF)\n", "\n", "plt.savefig(OUT_FILE_SVG)\n", "files.download(OUT_FILE_SVG)\n", "\n", "plt.savefig(OUT_FILE_PNG)\n", "files.download(OUT_FILE_PNG)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "plot_summary_json", "provenance": [ { "file_id": "1g2aTwv76XMrLfEwryfj_tGzNnvZWjIVl", "timestamp": 1600990155741 } ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: rigl/experimental/jax/datasets/__init__.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: rigl/experimental/jax/datasets/cifar10.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """CIFAR10 Dataset. Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS) with JAX/FLAX, by defining a bunch of wrappers, including preprocessing. In this case, the CIFAR10 dataset. """ from typing import MutableMapping, Sequence from rigl.experimental.jax.datasets import dataset_base import tensorflow.compat.v2 as tf class CIFAR10Dataset(dataset_base.ImageDataset): """CIFAR10 dataset. Attributes: NAME: The Tensorflow Dataset's dataset name. """ NAME: str = 'cifar10' # Computed from the training set by taking the per-channel mean/std-dev # over sample, height and width axes of all training samples. MEAN_RGB: Sequence[float] = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255] STDDEV_RGB: Sequence[float] = [0.2470 * 255, 0.2435 * 255, 0.2616 * 255] def __init__(self, batch_size, batch_size_test, shuffle_buffer_size = 1024, seed = 42): """CIFAR10 dataset. Args: batch_size: The batch size to use for the training datasets. batch_size_test: The batch size used for the test dataset. shuffle_buffer_size: The buffer size to use for dataset shuffling. seed: The random seed used to shuffle. Returns: Dataset: A dataset object. Raises: ValueError: If the test dataset is not evenly divisible by the test batch size. 
""" super().__init__(CIFAR10Dataset.NAME, batch_size, batch_size_test, shuffle_buffer_size, seed) if self.get_test_len() % batch_size_test != 0: raise ValueError( 'Test data not evenly divisible by batch size: {} % {} != 0.'.format( self.get_test_len(), batch_size_test)) def preprocess( self, data): """Normalizes CIFAR10 images: `uint8` -> `float32`. Args: data: Data sample. Returns: Data after being augmented/normalized/transformed. """ data = super().preprocess(data) mean_rgb = tf.constant(self.MEAN_RGB, shape=[1, 1, 3], dtype=tf.float32) std_rgb = tf.constant(self.STDDEV_RGB, shape=[1, 1, 3], dtype=tf.float32) data['image'] = (tf.cast(data['image'], tf.float32) - mean_rgb) / std_rgb return data ================================================ FILE: rigl/experimental/jax/datasets/cifar10_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.datasets.cifar10.""" from absl.testing import absltest import numpy as np from rigl.experimental.jax.datasets import cifar10 class CIFAR10DatasetTest(absltest.TestCase): """Test cases for CIFAR10 Dataset.""" def setUp(self): """Common setup routines/variables for test cases.""" super().setUp() self._batch_size = 16 self._batch_size_test = 10 self._shuffle_buffer_size = 8 self._dataset = cifar10.CIFAR10Dataset( self._batch_size, batch_size_test=self._batch_size_test, shuffle_buffer_size=self._shuffle_buffer_size) def test_create_dataset(self): """Tests creation of dataset.""" self.assertIsInstance(self._dataset, cifar10.CIFAR10Dataset) def test_train_image_dims_content(self): """Tests dimensions and contents of test data.""" iterator = self._dataset.get_train() sample = next(iterator) image, label = sample['image'], sample['label'] with self.subTest(name='DataShape'): self.assertTupleEqual(image.shape, (self._batch_size, 32, 32, 3)) with self.subTest(name='DataType'): self.assertTrue(np.issubdtype(image.dtype, float)) with self.subTest(name='DataValues'): # Normalized by stddev., expect nothing to fall outside 3 stddev. self.assertTrue((image >= -3.).all() and (image <= 3.).all()) with self.subTest(name='LabelShape'): self.assertLen(label, self._batch_size) with self.subTest(name='LabelType'): self.assertTrue(np.issubdtype(label.dtype, int)) with self.subTest(name='LabelValues'): self.assertTrue((label >= 0).all() and (label <= self._dataset.num_classes).all()) def test_test_image_dims_content(self): """Tests dimensions and contents of train data.""" iterator = self._dataset.get_test() sample = next(iterator) image, label = sample['image'], sample['label'] with self.subTest(name='DataShape'): self.assertTupleEqual(image.shape, (self._batch_size_test, 32, 32, 3)) with self.subTest(name='DataType'): self.assertTrue(np.issubdtype(image.dtype, float)) with self.subTest(name='DataValues'): # Normalized by stddev., expect nothing to fall outside 3 stddev. 
self.assertTrue((image >= -3.).all() and (image <= 3.).all()) with self.subTest(name='LabelShape'): self.assertLen(label, self._batch_size_test) with self.subTest(name='LabelType'): self.assertTrue(np.issubdtype(label.dtype, int)) with self.subTest(name='LabelValues'): self.assertTrue((label >= 0).all() and (label <= self._dataset.num_classes).all()) def test_train_data_length(self): """Tests length of training dataset.""" total_count = 0 for batch in self._dataset.get_train(): total_count += len(batch['label']) self.assertEqual(total_count, self._dataset.get_train_len()) def test_test_data_length(self): """Tests length of test dataset.""" total_count = 0 for batch in self._dataset.get_test(): total_count += len(batch['label']) self.assertEqual(total_count, self._dataset.get_test_len()) def test_dataset_nonevenly_divisible_batch_size(self): """Tests non-evenly divisible test batch size.""" with self.assertRaisesRegex( ValueError, 'Test data not evenly divisible by batch size: .*'): self._dataset = cifar10.CIFAR10Dataset( self._batch_size, batch_size_test=101) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/datasets/dataset_base.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Dataset Classes. Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS) with JAX/FLAX, by defining a bunch of wrappers, including preprocessing. """ import abc from typing import MutableMapping, Optional import tensorflow.compat.v2 as tf import tensorflow_datasets as tfds class Dataset(metaclass=abc.ABCMeta): """Base class for datasets. Attributes: DATAKEY: The key used for the data component of a Tensorflow Dataset (TFDS) sample, e.g. 'image' for image datasets. LABELKEY: The key used for the label component of a Tensorflow Dataset sample, i.e. 'label'. name: The TFDS name of the dataset. batch_size: The batch size to use for the training dataset. batch_size_test: The batch size to use for the test dataset. num_classes: The number of supervised classes in the dataset. shape: The shape of an input data array. """ DATAKEY: Optional[str] = None LABELKEY: str = 'label' def __init__(self, name, batch_size, batch_size_test, shuffle_buffer_size, prefetch_size = 1, seed = None): # pytype: disable=annotation-type-mismatch """Base class for datasets. Args: name: The TFDS name of the dataset. batch_size: The batch size to use for the training dataset. batch_size_test: The batch size to use for the test dataset. shuffle_buffer_size: The buffer size to use for dataset shuffling. prefetch_size: The number of mini-batches to prefetch. seed: The random seed used to shuffle. Returns: A Dataset object.
""" super().__init__() self.name = name self.batch_size = batch_size self.batch_size_test = batch_size_test self._shuffle_buffer_size = shuffle_buffer_size self._prefetch_size = prefetch_size self._train_ds, self._train_info = tfds.load( self.name, split=tfds.Split.TRAIN, data_dir=self._dataset_dir(), with_info=True) self._train_ds = self._train_ds.shuffle( self._shuffle_buffer_size, seed).map(self.preprocess).cache().map(self.augment).batch( self.batch_size, drop_remainder=True).prefetch(self._prefetch_size) self._test_ds, self._test_info = tfds.load( self.name, split=tfds.Split.TEST, data_dir=self._dataset_dir(), with_info=True) self._test_ds = self._test_ds.map(self.preprocess).cache().batch( self.batch_size_test).prefetch(self._prefetch_size) self.num_classes = self._train_info.features['label'].num_classes self.shape = self._train_info.features['image'].shape def _dataset_dir(self): """Returns the dataset path for the TFDS data.""" return None def get_train(self): """Returns the training dataset.""" return iter(tfds.as_numpy(self._train_ds)) def get_train_len(self): """Returns the length of the training dataset.""" return self._train_info.splits['train'].num_examples def get_test(self): """Returns the test dataset.""" return iter(tfds.as_numpy(self._test_ds)) def get_test_len(self): """Returns the length of the test dataset.""" return self._test_info.splits['test'].num_examples def preprocess( self, data): """Preprocessing fn used by TFDS map for normalization. This function is for transformations that can be cached, e.g. normalization/whitening. Args: data: Data sample. Returns: Data after being normalized/transformed. """ return data def augment( self, data): """Preprocessing fn used by TFDS map for augmentation at training time. This function is for transformations that should not be cached, e.g. random augmentation that should change for every sample, and are only applied at training time. Args: data: Data sample. Returns: Data after being augmented/transformed. """ return data class ImageDataset(Dataset): """Base class for image datasets.""" DATAKEY = 'image' def preprocess( self, data): """Preprocessing function used by TFDS map for normalization. This function is for transformations that can be cached, e.g. normalization/whitening. Args: data: Data sample. Returns: Data after being normalized/transformed. """ data = super().preprocess(data) # Ensure we only provide the image and label, stripping out other keys. return dict((key, val) for key, val in data.items() if key in [self.LABELKEY, self.DATAKEY]) ================================================ FILE: rigl/experimental/jax/datasets/dataset_base_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.datasets.dataset_base.""" from absl.testing import absltest from rigl.experimental.jax.datasets import dataset_base class DummyDataset(dataset_base.ImageDataset): """A dummy implementation of the abstract dataset class. 
Attributes: NAME: The Tensorflow Dataset's dataset name. """ NAME: str = 'mnist' def __init__(self, batch_size, batch_size_test, shuffle_buffer_size = 1024, seed = 42): """Dummy MNIST dataset. Args: batch_size: The batch size to use for the training datasets. batch_size_test: The batch size to use for the test dataset. shuffle_buffer_size: The buffer size to use for dataset shuffling. seed: The random seed used to shuffle. Returns: Dataset: A dataset object. """ super().__init__(DummyDataset.NAME, batch_size, batch_size_test, shuffle_buffer_size, seed) class DummyDatasetTest(absltest.TestCase): """Test cases for dummy dataset.""" def setUp(self): """Common setup routines/variables for test cases.""" super().setUp() self._batch_size = 16 self._batch_size_test = 10 self._shuffle_buffer_size = 8 self._dataset = DummyDataset( self._batch_size, batch_size_test=self._batch_size_test, shuffle_buffer_size=self._shuffle_buffer_size) def test_create_dataset(self): """Tests creation of dataset.""" self.assertIsInstance(self._dataset, DummyDataset) def test_train_image_dims_content(self): """Tests dimensions and contents of train data.""" iterator = iter(self._dataset.get_train()) sample = next(iterator) image, label = sample['image'], sample['label'] with self.subTest(name='data_shape'): self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1)) with self.subTest(name='data_values'): self.assertBetween(image.all(), 0, 256) with self.subTest(name='label_shape'): self.assertLen(label, self._batch_size) with self.subTest(name='label_values'): self.assertBetween(label.all(), 0, self._dataset.num_classes) def test_test_image_dims_content(self): """Tests dimensions and contents of test data.""" iterator = iter(self._dataset.get_test()) sample = next(iterator) image, label = sample['image'], sample['label'] with self.subTest(name='data_shape'): self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1)) with self.subTest(name='data_values'): self.assertBetween(image.all(), 0, 256) with self.subTest(name='label_shape'): self.assertLen(label, self._batch_size_test) with self.subTest(name='label_values'): self.assertBetween(label.all(), 0, self._dataset.num_classes) def test_train_data_length(self): """Tests length of training dataset.""" total_count = 0 for batch in self._dataset.get_train(): total_count += len(batch['label']) self.assertEqual(total_count, self._dataset.get_train_len()) def test_test_data_length(self): """Tests length of test dataset.""" total_count = 0 for batch in self._dataset.get_test(): total_count += len(batch['label']) self.assertEqual(total_count, self._dataset.get_test_len()) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/datasets/dataset_factory.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Dataset Factory.
Dataset factory to allow us to easily use tensorflow datasets (TFDS) with JAX/FLAX, by defining a bunch of wrappers, including preprocessing. Attributes: DATASETS: A list of the datasets that can be created. """ from typing import Any, Mapping, Type from rigl.experimental.jax.datasets import cifar10 from rigl.experimental.jax.datasets import dataset_base from rigl.experimental.jax.datasets import mnist import tensorflow.compat.v2 as tf DATASETS: Mapping[str, Type[dataset_base.Dataset]] = { 'MNIST': mnist.MNISTDataset, 'CIFAR10': cifar10.CIFAR10Dataset, } def create_dataset(name, *args, **kwargs): """Creates a Tensorflow datasets (TFDS) dataset. Args: name: The TFDS name of the dataset. *args: Dataset arguments. **kwargs: Dataset keyword arguments. Returns: Dataset: An abstracted dataset object. Raises: ValueError if a dataset with the given name does not exist. """ if name not in DATASETS: raise ValueError(f'No such dataset: {name}') return DATASETS[name](*args, **kwargs) ================================================ FILE: rigl/experimental/jax/datasets/dataset_factory_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.datasets.dataset_common.""" from absl.testing import absltest from absl.testing import parameterized import numpy as np from rigl.experimental.jax.datasets import dataset_base from rigl.experimental.jax.datasets import dataset_factory class DatasetCommonTest(parameterized.TestCase): def setUp(self): super().setUp() self._batch_size = 32 self._batch_size_test = 10 self._shuffle_buffer_size = 128 def _create_dataset(self, dataset_name): """Helper function for creating a dataset.""" return dataset_factory.create_dataset( dataset_name, self._batch_size, self._batch_size_test, shuffle_buffer_size=self._shuffle_buffer_size) def test_dataset_supported(self): """Tests supported datasets.""" for dataset_name in dataset_factory.DATASETS: dataset = self._create_dataset(dataset_name) self.assertIsInstance(dataset, dataset_base.Dataset) @parameterized.parameters(*dataset_factory.DATASETS.keys()) def test_dataset_train_iterators(self, dataset_name): """Tests dataset's train iterator.""" dataset = self._create_dataset(dataset_name) sample = next(dataset.get_train()) with self.subTest(name='{}_sample'.format(dataset_name)): self.assertNotEmpty(sample) with self.subTest(name='{}_label_type'.format(dataset_name)): self.assertIsInstance(sample['label'], np.ndarray) with self.subTest(name='{}_label_batch_size'.format(dataset_name)): self.assertLen(sample['label'], self._batch_size) with self.subTest(name='{}_image_type'.format(dataset_name)): self.assertIsInstance(sample['image'], np.ndarray) with self.subTest(name='{}_image_shape'.format(dataset_name)): self.assertLen(sample['image'].shape, 4) with self.subTest(name='{}_image_batch_size'.format(dataset_name)): self.assertEqual(sample['image'].shape[0], self._batch_size) with self.subTest( 
name='{}_non_zero_image_dimensions'.format(dataset_name)): self.assertGreater(sample['image'].shape[1], 1) @parameterized.parameters(*dataset_factory.DATASETS.keys()) def test_dataset_test_iterators(self, dataset_name): """Tests dataset's test iterator.""" dataset = self._create_dataset(dataset_name) sample = next(dataset.get_test()) with self.subTest(name='{}_sample'.format(dataset_name)): self.assertNotEmpty(sample) with self.subTest(name='{}_label_type'.format(dataset_name)): self.assertIsInstance(sample['label'], np.ndarray) with self.subTest(name='{}_label_batch_size'.format(dataset_name)): self.assertLen(sample['label'], self._batch_size_test) with self.subTest(name='{}_image_type'.format(dataset_name)): self.assertIsInstance(sample['image'], np.ndarray) with self.subTest(name='{}_image_shape'.format(dataset_name)): self.assertLen(sample['image'].shape, 4) with self.subTest(name='{}_image_batch_size'.format(dataset_name)): self.assertEqual(sample['image'].shape[0], self._batch_size_test) with self.subTest( name='{}_non_zero_image_dimensions'.format(dataset_name)): self.assertGreater(sample['image'].shape[1], 1) def test_dataset_unsupported(self): """Tests unsupported datasets.""" with self.assertRaisesRegex(ValueError, 'No such dataset: unsupported'): self._create_dataset('unsupported') if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/datasets/mnist.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MNIST Dataset. Dataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS) with JAX/FLAX, by defining a bunch of wrappers, including preprocessing. In this case, the MNIST dataset. """ from typing import MutableMapping from rigl.experimental.jax.datasets import dataset_base import tensorflow.compat.v2 as tf class MNISTDataset(dataset_base.ImageDataset): """MNIST dataset. Attributes: NAME: The Tensorflow Dataset's dataset name. """ NAME: str = 'mnist' def __init__(self, batch_size, batch_size_test, shuffle_buffer_size = 1024, seed = 42): """MNIST dataset. Args: batch_size: The batch size to use for the training datasets. batch_size_test: The batch size to use for the test dataset. shuffle_buffer_size: The buffer size to use for dataset shuffling. seed: The random seed used to shuffle. Returns: Dataset: A dataset object. """ super().__init__(MNISTDataset.NAME, batch_size, batch_size_test, shuffle_buffer_size, seed) def preprocess( self, data): """Normalizes MNIST images: `uint8` -> `float32`. Args: data: Data sample. Returns: Data after being augmented/normalized/transformed. """ data = super().preprocess(data) data['image'] = (tf.cast(data['image'], tf.float32) / 255.) - 0.5 return data ================================================ FILE: rigl/experimental/jax/datasets/mnist_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.datasets.mnist.""" from absl.testing import absltest import numpy as np from rigl.experimental.jax.datasets import mnist class MNISTDatasetTest(absltest.TestCase): """Test cases for MNIST Dataset.""" def setUp(self): """Common setup routines/variables for test cases.""" super().setUp() self._batch_size = 16 self._batch_size_test = 10 self._shuffle_buffer_size = 8 self._dataset = mnist.MNISTDataset( self._batch_size, batch_size_test=self._batch_size_test, shuffle_buffer_size=self._shuffle_buffer_size) def test_create_dataset(self): """Tests creation of dataset.""" self.assertIsInstance(self._dataset, mnist.MNISTDataset) def test_train_image_dims_content(self): """Tests dimensions and contents of train data.""" iterator = self._dataset.get_train() sample = next(iterator) image, label = sample['image'], sample['label'] with self.subTest(name='data_shape'): self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1)) with self.subTest(name='data_values'): self.assertTrue((image >= -1.).all() and (image <= 1.).all()) with self.subTest(name='data_type'): self.assertTrue(np.issubdtype(image.dtype, float)) with self.subTest(name='label_shape'): self.assertLen(label, self._batch_size) with self.subTest(name='label_type'): self.assertTrue(np.issubdtype(label.dtype, int)) with self.subTest(name='label_values'): self.assertTrue((label >= 0).all() and (label <= self._dataset.num_classes).all()) def test_test_image_dims_content(self): """Tests dimensions and contents of test data.""" iterator = self._dataset.get_test() sample = next(iterator) image, label = sample['image'], sample['label'] with self.subTest(name='data_shape'): self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1)) with self.subTest(name='data_type'): self.assertTrue(np.issubdtype(image.dtype, float)) # TODO: Find a better approach to testing with JAX arrays. with self.subTest(name='data_values'): self.assertTrue((image >= -1.).all() and (image <= 1.).all()) with self.subTest(name='label_shape'): self.assertLen(label, self._batch_size_test) with self.subTest(name='label_type'): self.assertTrue(np.issubdtype(label.dtype, int)) with self.subTest(name='label_values'): self.assertTrue((label >= 0).all() and (label <= self._dataset.num_classes).all()) def test_train_data_length(self): """Tests length of training dataset.""" total_count = 0 for batch in self._dataset.get_train(): total_count += len(batch['label']) self.assertEqual(total_count, self._dataset.get_train_len()) def test_test_data_length(self): """Tests length of test dataset.""" total_count = 0 for batch in self._dataset.get_test(): total_count += len(batch['label']) self.assertEqual(total_count, self._dataset.get_test_len()) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/fixed_param.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Weight Symmetry: Train models with fixed param, but diff. depth and width.""" import ast import functools import operator from os import path from typing import List, Sequence import uuid from absl import app from absl import flags from absl import logging import flax from flax.metrics import tensorboard from flax.training import lr_schedule import jax import jax.numpy as jnp from rigl.experimental.jax.datasets import dataset_factory from rigl.experimental.jax.models import mnist_fc from rigl.experimental.jax.models import model_factory from rigl.experimental.jax.pruning import masked from rigl.experimental.jax.pruning import symmetry from rigl.experimental.jax.training import training from rigl.experimental.jax.utils import utils experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id)) logging.info('Saving experimental results to %s', experiment_dir) host_count = jax.host_count() local_device_count = jax.local_device_count() logging.info('Device count: %d, host count: %d, local device count: %d', jax.device_count(), host_count, local_device_count) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(experiment_dir) dataset = dataset_factory.create_dataset( FLAGS.dataset, FLAGS.batch_size, FLAGS.batch_size_test, shuffle_buffer_size=FLAGS.shuffle_buffer_size) logging.info('Training %s on the %s dataset...', MODEL, FLAGS.dataset) rng = jax.random.PRNGKey(FLAGS.random_seed) input_shape = (1,) + dataset.shape input_len = functools.reduce(operator.mul, dataset.shape) features = mnist_fc.feature_dim_for_param( input_len, FLAGS.param_count, FLAGS.depth) logging.info('Model Configuration: %s', str(features)) base_model, _ = model_factory.create_model( MODEL, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes, features=features) model_param_count = utils.count_param(base_model, ('kernel',)) logging.info( 'Model Config: param.: %d, depth: %d. max width: %d, min width: %d', model_param_count, len(features), max(features), min(features)) logging.info('Generating random mask based on model') # Re-initialize the RNG to maintain same training pattern (as in prune code). mask_rng = jax.random.PRNGKey(FLAGS.random_seed) mask = masked.shuffled_mask( base_model, rng=mask_rng, sparsity=FLAGS.mask_sparsity) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. 
except (OverflowError, ValueError): summary_writer.text(f'mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label) if FLAGS.dump_json: mask_stats['permutations'] = str(mask_stats['permutations']) utils.dump_dict_json( mask_stats, path.join(experiment_dir, 'mask_stats.json')) model_stats = { 'depth': len(features), 'max_width': max(features), 'min_width': min(features), } model_stats.update( {'feature_{}'.format(i): value for i, value in enumerate(features)}) if FLAGS.dump_json: utils.dump_dict_json(model_stats, path.join(experiment_dir, 'model_stats.json')) model, initial_state = model_factory.create_model( 'MNIST_FC', rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes, features=features, masks=mask) if FLAGS.opt == 'Adam': optimizer = flax.optim.Adam( learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif FLAGS.opt == 'Momentum': optimizer = flax.optim.Momentum( learning_rate=FLAGS.lr, beta=FLAGS.momentum, weight_decay=FLAGS.weight_decay, nesterov=False) else: raise ValueError('Unknown Optimizer: {}'.format(FLAGS.opt)) steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size if FLAGS.lr_schedule == 'constant': lr_fn = lr_schedule.create_constant_learning_rate_schedule( FLAGS.lr, steps_per_epoch) elif FLAGS.lr_schedule == 'stepped': lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps) lr_fn = lr_schedule.create_stepped_learning_rate_schedule( FLAGS.lr, steps_per_epoch, lr_schedule_steps) elif FLAGS.lr_schedule == 'cosine': lr_fn = lr_schedule.create_cosine_learning_rate_schedule( FLAGS.lr, steps_per_epoch, FLAGS.epochs) else: raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule)) if jax.host_id() == 0: trainer = training.Trainer( optimizer, model, initial_state, dataset, rng, summary_writer=summary_writer, ) else: trainer = training.Trainer(optimizer, model, initial_state, dataset, rng) _, best_metrics = trainer.train( FLAGS.epochs, lr_fn=lr_fn, update_iter=FLAGS.update_iterations, update_epoch=FLAGS.update_epoch, ) logging.info('Best metrics: %s', str(best_metrics)) if jax.host_id() == 0: if FLAGS.dump_json: utils.dump_dict_json(best_metrics, path.join(experiment_dir, 'best_metrics.json')) for label, value in best_metrics.items(): summary_writer.scalar('best/{}'.format(label), value, FLAGS.epochs * steps_per_epoch) summary_writer.close() def main(argv: List[str]): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') run_training() if __name__ == '__main__': app.run(main) ================================================ FILE: rigl/experimental/jax/fixed_param_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
"""Tests for weight_symmetry.fixed_param.""" import glob from os import path import tempfile from absl.testing import absltest from absl.testing import flagsaver from rigl.experimental.jax import fixed_param class FixedParamTest(absltest.TestCase): def test_run(self): """Tests if the driver for shuffled training runs correctly.""" experiment_dir = tempfile.mkdtemp() eval_flags = dict( epochs=1, experiment_dir=experiment_dir, ) with flagsaver.flagsaver(**eval_flags): fixed_param.main([]) with self.subTest(name='tf_summary_file_exists'): outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/models/__init__.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: rigl/experimental/jax/models/cifar10_cnn.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """CIFAR10 CNN. A small CNN for the CIFAR10 dataset, consists of a number of convolutional layers (determined by length of filters parameter), followed by a fully-connected layer. """ from typing import Callable, Mapping, Optional, Sequence from absl import logging import flax import jax.numpy as jnp from rigl.experimental.jax.pruning import init from rigl.experimental.jax.pruning import masked class CIFAR10CNN(flax.deprecated.nn.Module): """Small CIFAR10 CNN.""" def apply(self, inputs, num_classes, filter_shape = (3, 3), filters = (32, 32, 64, 64, 128, 128), init_fn=flax.deprecated.nn.initializers.kaiming_normal, train=True, activation_fn = flax.deprecated.nn.relu, masks = None, masked_layer_indices = None): """Applies a convolution to the inputs. Args: inputs: Input data with dimensions (batch, spatial_dims..., features). num_classes: Number of classes in the dataset. filter_shape: Shape of the convolutional filters. filters: Number of filters in each convolutional layer, and number of conv layers (given by length of sequence). init_fn: Initialization function used for convolutional layers. train: If model is being evaluated in training mode or not. activation_fn: Activation function to be used for convolutional layers. masks: Masks of the layers in this model, in the same form as module params, or None. 
masked_layer_indices: The layer indices of layers in model to be masked. Returns: A tensor of shape (batch, num_classes), containing the logit output. Raises: ValueError if the number of pooling layers is too many for the given input size, or if the provided mask is not of the correct depth for the model. """ # Note: First dim is batch, last dim is channels, other dims are "spatial". if not all([(dim >= 2**(len(filters)//2)) for dim in inputs.shape[1:-2]]): raise ValueError( 'Input spatial size, {}, does not allow {} pooling layers.'.format( str(inputs.shape[1:-2]), len(filters)) ) depth = 1 + len(filters) masks = masked.generate_model_masks(depth, masks, masked_layer_indices) batch_norm = flax.deprecated.nn.BatchNorm.partial( use_running_average=not train, momentum=0.99, epsilon=1e-5) for i, filter_num in enumerate(filters): if f'MaskedModule_{i}' in masks: logging.info('Layer %d is masked in model', i) mask = masks[f'MaskedModule_{i}'] inputs = masked.masked(flax.deprecated.nn.Conv, mask)( inputs, features=filter_num, kernel_size=filter_shape, kernel_init=init.sparse_init( init_fn(), mask['kernel'] if mask is not None else None)) else: inputs = flax.deprecated.nn.Conv( inputs, features=filter_num, kernel_size=filter_shape, kernel_init=init_fn()) inputs = batch_norm(inputs, name='bn_conv_{}'.format(i)) inputs = activation_fn(inputs) if i % 2 == 1: inputs = flax.deprecated.nn.max_pool( inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID') # Global average pooling if we have spatial dimensions left. inputs = flax.deprecated.nn.avg_pool( inputs, window_shape=(inputs.shape[1:-1]), padding='VALID') inputs = inputs.reshape((inputs.shape[0], -1)) # This is effectively a Dense layer, but we cast it as a convolution layer # to allow us to easily propagate masks, avoiding b/156135283. inputs = flax.deprecated.nn.Conv( inputs, features=num_classes, kernel_size=inputs.shape[1:-1], kernel_init=flax.deprecated.nn.initializers.xavier_normal()) inputs = batch_norm(inputs, name='bn_dense_1') inputs = jnp.squeeze(inputs) return flax.deprecated.nn.log_softmax(inputs) ================================================ FILE: rigl/experimental/jax/models/cifar10_cnn_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for weight_symmetry.models.cifar10_cnn.""" from absl.testing import absltest import flax import jax import jax.numpy as jnp from rigl.experimental.jax.models import cifar10_cnn class CIFAR10CNNTest(absltest.TestCase): """Tests the CIFAR10CNN model.""" def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._num_classes = 10 self._batch_size = 2 self._input_shape = ((self._batch_size, 32, 32, 3), jnp.float32) self._input = jnp.zeros(*self._input_shape) def test_output_shapes(self): """Tests the output shapes of the model.""" with flax.deprecated.nn.stateful() as initial_state: _, initial_params = cifar10_cnn.CIFAR10CNN.init_by_shape( self._rng, (self._input_shape,), num_classes=self._num_classes) model = flax.deprecated.nn.Model(cifar10_cnn.CIFAR10CNN, initial_params) with flax.deprecated.nn.stateful(initial_state, mutable=False): logits = model(self._input, num_classes=self._num_classes, train=False) self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes)) def test_invalid_spatial_dimensions(self): """Tests model with an invalid spatial dimension parameters.""" with self.assertRaisesRegex(ValueError, 'Input spatial size, '): cifar10_cnn.CIFAR10CNN.init_by_shape( self._rng, (self._input_shape,), num_classes=self._num_classes, filters=20 * (32,)) def test_invalid_masks_depth(self): """Tests model mask with the incorrect depth for the given model.""" invalid_masks = { 'MaskedModule_0': { 'kernel': jnp.zeros((self._batch_size, 3, 3, 32)) } } with self.assertRaisesRegex( ValueError, 'Mask is invalid for model.'): cifar10_cnn.CIFAR10CNN.init_by_shape( self._rng, (self._input_shape,), num_classes=self._num_classes, masks=invalid_masks) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/models/mnist_cnn.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MNIST CNN. A small CNN for the MNIST dataset, consists of a number of convolutional layers (determined by length of filters parameter), followed by a fully-connected layer. """ from typing import Callable, Mapping, Optional, Sequence from absl import logging import flax import jax.numpy as jnp from rigl.experimental.jax.pruning import init from rigl.experimental.jax.pruning import masked class MNISTCNN(flax.deprecated.nn.Module): """Small MNIST CNN.""" def apply(self, inputs, num_classes, filter_shape = (5, 5), filters = (16, 32), dense_size = 64, train=True, init_fn = flax.deprecated.nn.initializers.kaiming_normal, activation_fn = flax.deprecated.nn.relu, masks = None, masked_layer_indices = None): """Applies a convolution to the inputs. Args: inputs: Input data with dimensions (batch, spatial_dims..., features). num_classes: Number of classes in the dataset. filter_shape: Shape of the convolutional filters. filters: Number of filters in each convolutional layer, and number of conv layers (given by length of sequence). 
      dense_size: Number of units in the fully-connected (dense) layer that
        follows the convolutional layers.
      train: If model is being evaluated in training mode or not.
      init_fn: Initialization function used for convolutional layers.
      activation_fn: Activation function to be used for convolutional layers.
      masks: Masks of the layers in this model, in the same form as module
        params, or None.
      masked_layer_indices: The layer indices of layers in model to be masked.

    Returns:
      A tensor of shape (batch, num_classes), containing the logit output.

    Raises:
      ValueError: If the number of pooling layers is too many for the given
        input size.
    """
    # Note: First dim is batch, last dim is channels, other dims are "spatial".
    if not all([(dim >= 2**len(filters)) for dim in inputs.shape[1:-2]]):
      raise ValueError(
          'Input spatial size, {}, does not allow {} pooling layers.'.format(
              str(inputs.shape[1:-2]), len(filters)))

    depth = 2 + len(filters)
    masks = masked.generate_model_masks(depth, masks, masked_layer_indices)

    batch_norm = flax.deprecated.nn.BatchNorm.partial(
        use_running_average=not train, momentum=0.99, epsilon=1e-5)

    for i, filter_num in enumerate(filters):
      if f'MaskedModule_{i}' in masks:
        logging.info('Layer %d is masked in model', i)
        mask = masks[f'MaskedModule_{i}']
        inputs = masked.masked(flax.deprecated.nn.Conv, mask)(
            inputs,
            features=filter_num,
            kernel_size=filter_shape,
            kernel_init=init.sparse_init(
                init_fn(), mask['kernel'] if mask is not None else None))
      else:
        inputs = flax.deprecated.nn.Conv(
            inputs,
            features=filter_num,
            kernel_size=filter_shape,
            kernel_init=init_fn())
      inputs = batch_norm(inputs, name='bn_conv_{}'.format(i))
      inputs = activation_fn(inputs)
      if i < len(filters) - 1:
        inputs = flax.deprecated.nn.max_pool(
            inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID')

    # Global average pool at end of convolutional layers.
    inputs = flax.deprecated.nn.avg_pool(
        inputs, window_shape=inputs.shape[1:-1], padding='VALID')

    # This is effectively a Dense layer, but we cast it as a convolution layer
    # to allow us to easily propagate masks, avoiding b/156135283.
    if f'MaskedModule_{depth - 2}' in masks:
      mask_dense_1 = masks[f'MaskedModule_{depth - 2}']
      inputs = masked.masked(flax.deprecated.nn.Conv, mask_dense_1)(
          inputs,
          features=dense_size,
          kernel_size=inputs.shape[1:-1],
          kernel_init=init.sparse_init(
              init_fn(),
              mask_dense_1['kernel'] if mask_dense_1 is not None else None))
    else:
      inputs = flax.deprecated.nn.Conv(
          inputs,
          features=dense_size,
          kernel_size=inputs.shape[1:-1],
          kernel_init=init_fn())
    inputs = batch_norm(inputs, name='bn_dense_1')
    inputs = activation_fn(inputs)

    inputs = flax.deprecated.nn.Dense(
        inputs,
        features=num_classes,
        kernel_init=flax.deprecated.nn.initializers.xavier_normal())
    inputs = batch_norm(inputs, name='bn_dense_2')
    inputs = jnp.squeeze(inputs)

    return flax.deprecated.nn.log_softmax(inputs)


================================================
FILE: rigl/experimental/jax/models/mnist_cnn_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for weight_symmetry.models.mnist_cnn."""
from absl.testing import absltest
import flax
import jax
import jax.numpy as jnp

from rigl.experimental.jax.models import mnist_cnn


class MNISTCNNTest(absltest.TestCase):
  """Tests the MNISTCNN model."""

  def setUp(self):
    super().setUp()
    self._rng = jax.random.PRNGKey(42)
    self._num_classes = 10
    self._batch_size = 2
    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)
    self._input = jnp.zeros(*self._input_shape)

  def test_output_shapes(self):
    """Tests the output shapes of the model."""
    with flax.deprecated.nn.stateful() as initial_state:
      _, initial_params = mnist_cnn.MNISTCNN.init_by_shape(
          self._rng, (self._input_shape,), num_classes=self._num_classes)
      model = flax.deprecated.nn.Model(mnist_cnn.MNISTCNN, initial_params)

    with flax.deprecated.nn.stateful(initial_state, mutable=False):
      logits = model(self._input, num_classes=self._num_classes, train=False)

    self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes))

  def test_invalid_depth(self):
    """Tests the model with too many pooling layers for the input size."""
    with self.assertRaisesRegex(ValueError, 'Input spatial size, '):
      mnist_cnn.MNISTCNN.init_by_shape(
          self._rng, (self._input_shape,),
          num_classes=self._num_classes,
          filters=10 * (32,))


if __name__ == '__main__':
  absltest.main()


================================================
FILE: rigl/experimental/jax/models/mnist_fc.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MNIST Fully-Connected Neural Network.

A fully-connected model for the MNIST dataset, consisting of a number of dense
layers (determined by the length of the features parameter).
"""
import math
from typing import Callable, Mapping, Optional, Sequence, Tuple

from absl import logging
import flax
import jax.numpy as jnp

from rigl.experimental.jax.pruning import init
from rigl.experimental.jax.pruning import masked


def feature_dim_for_param(input_len, param_count, depth, depth_mult = 2.):
  """Calculates feature dimensions for a fixed parameter count and depth.

  This is calculated for the specific case of a fully-connected neural
  network, where each layer consists of l * a**i neurons, where a is a
  multiplier for each layer. Assume x is the input size, a is the depth
  multiplier, l is the initial layer width, and d is the depth. The total
  number of parameters, n, is then given by

  $$n = x l + l^2 \sum_{i=2}^{d} a^{2i-3}.$$

  Args:
    input_len: Input size.
    param_count: Number of parameters model should maintain.
    depth: Depth of the model.
    depth_mult: The layer width multiplier w.r.t. depth.

  Returns:
    The feature specification for a fully-connected model, as a tuple of layer
    widths.

  Raises:
    ValueError: If the given number of parameters is too low for the given
      depth and input size.
  """
  # Calculate the initial width for the first layer.
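  # A sketch of the algebra behind the branches below (added for exposition):
  # with c = sum_{i=2}^d a^{2i-3}, the count n = x*l + c*l^2 is quadratic in
  # l, i.e. c*l^2 + x*l - n = 0, whose positive root is
  # l = (sqrt(x^2 + 4*c*n) - x) / (2*c). For depth == 1 only the input term
  # x*l remains, so l = n / x.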
if depth == 1: initial_width = param_count / input_len else: # l = ((x^2 + 4cn)^{1/2} - x)/(2c) where c = sum_{i=2}^d a^{2i-3}. depth_sum = sum(depth_mult**(2 * i - 3) for i in range(2, depth + 1)) initial_width = (math.sqrt(input_len**2 + 4 * depth_sum * param_count) - input_len) / (2 * depth_sum) if initial_width < 1: raise ValueError( 'Expected parameter count too low for given depth and input size.') return tuple(int(int(initial_width) * depth_mult**i) for i in range(depth)) class MNISTFC(flax.deprecated.nn.Module): """MNIST Fully-Connected Neural Network.""" def apply(self, inputs, num_classes, features = (32, 32), train=True, init_fn = flax.deprecated.nn.initializers.kaiming_normal, activation_fn = flax.deprecated.nn.relu, masks = None, masked_layer_indices = None, dropout_rate = 0.): """Applies fully-connected neural network to the inputs. Args: inputs: Input data with dimensions (batch, features), if features has more than one dimension, it is flattened. num_classes: Number of classes in the dataset. features: Number of neurons in each layer, and number of layers (given by length of sequence) + one layer for softmax. train: If model is being evaluated in training mode or not. init_fn: Initialization function used for dense layers. activation_fn: Activation function to be used for dense layers. masks: Masks of the layers in this model, in the same form as module params, or None. masked_layer_indices: The layer indices of layers in model to be masked. dropout_rate: Dropout rate, if 0 then dropout is not used (default). Returns: A tensor of shape (batch, num_classes), containing the logit output. """ batch_norm = flax.deprecated.nn.BatchNorm.partial( use_running_average=not train, momentum=0.99, epsilon=1e-5) depth = 1 + len(features) masks = masked.generate_model_masks(depth, masks, masked_layer_indices) # If inputs are in image dimensions, flatten image. inputs = inputs.reshape(inputs.shape[0], -1) for i, feature_num in enumerate(features): if f'MaskedModule_{i}' in masks: logging.info('Layer %d is masked in model', i) mask = masks[f'MaskedModule_{i}'] inputs = masked.masked(flax.deprecated.nn.Dense, mask)( inputs, features=feature_num, kernel_init=init.sparse_init( init_fn(), mask['kernel'] if mask is not None else None)) else: inputs = flax.deprecated.nn.Dense( inputs, features=feature_num, kernel_init=init_fn()) inputs = batch_norm(inputs, name=f'bn_conv_{i}') inputs = activation_fn(inputs) if dropout_rate > 0.0: inputs = flax.deprecated.nn.dropout( inputs, dropout_rate, deterministic=not train) inputs = flax.deprecated.nn.Dense( inputs, features=num_classes, kernel_init=flax.deprecated.nn.initializers.xavier_normal()) return flax.deprecated.nn.log_softmax(inputs) ================================================ FILE: rigl/experimental/jax/models/mnist_fc_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for weight_symmetry.models.mnist_fc.""" from typing import Sequence from absl.testing import absltest from absl.testing import parameterized import flax import jax import jax.numpy as jnp from rigl.experimental.jax.models import mnist_fc from rigl.experimental.jax.utils import utils PARAM_COUNT_PARAM: Sequence[str] = ('kernel',) class MNISTFCTest(parameterized.TestCase): """Tests the MNISTFC model.""" def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._num_classes = 10 self._batch_size = 2 self._input_len = 28*28*1 self._input_shape = ((self._batch_size, self._input_len), jnp.float32) self._input = jnp.zeros((self._batch_size, self._input_len), jnp.float32) self._param_count = 1e7 def test_output_shapes(self): """Tests the output shape from the model.""" with flax.deprecated.nn.stateful() as initial_state: _, initial_params = mnist_fc.MNISTFC.init_by_shape( self._rng, (self._input_shape,), num_classes=self._num_classes) model = flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params) with flax.deprecated.nn.stateful(initial_state, mutable=False): logits = model(self._input, num_classes=self._num_classes, train=False) self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes)) def test_invalid_masks_depth(self): """Tests a model with an invalid mask.""" invalid_masks = { 'MaskedModule_0': { 'kernel': jnp.zeros((self._batch_size, 5 * 5 * 16)) } } with self.assertRaisesRegex( ValueError, 'Mask is invalid for model.'): mnist_fc.MNISTFC.init_by_shape( self._rng, (self._input_shape,), num_classes=self._num_classes, masks=invalid_masks) def _create_model(self, features): """Convenience fn to create a FLAX model .""" _, initial_params = mnist_fc.MNISTFC.init_by_shape( self._rng, (self._input_shape,), num_classes=self._num_classes, features=features) return flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params) @parameterized.parameters(*range(1, 6)) def test_feature_dim_for_param_depth(self, depth): """Tests feature_dim_for_param with multiple depths.""" features = mnist_fc.feature_dim_for_param(self._input_len, self._param_count, depth) model = self._create_model(features) total_size = utils.count_param(model, PARAM_COUNT_PARAM) with self.subTest(name='FeatureDimLen'): self.assertLen(features, depth) with self.subTest(name='FeatureDimParamCount'): self.assertBetween(total_size, self._param_count * 0.95, self._param_count * 1.05) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/models/model_factory.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Factory for neural network models. Attributes: MODELS: A list of the models that can be created. 
""" from typing import Any, Callable, Mapping, Sequence, Tuple, Type import flax import jax.numpy as jnp from rigl.experimental.jax.models import cifar10_cnn from rigl.experimental.jax.models import mnist_cnn from rigl.experimental.jax.models import mnist_fc MODELS: Mapping[str, Type[flax.deprecated.nn.Model]] = { 'MNIST_CNN': mnist_cnn.MNISTCNN, 'MNIST_FC': mnist_fc.MNISTFC, 'CIFAR10_CNN': cifar10_cnn.CIFAR10CNN, } def create_model( name, rng, input_specs, **kwargs ): """Creates a Model. Args: name: the name of the model to instantiate. rng : the random number generator to use for init. input_specs: an iterable of (shape, dtype) pairs specifying the inputs. **kwargs: list of model specific keyword arguments. Returns: A tuple of FLAX model (flax.deprecated.nn.Model), and initial model state. Raises: ValueError if a model with the given name does not exist. """ if name not in MODELS: raise ValueError('No such model: {}'.format(name)) with flax.deprecated.nn.stateful() as init_state: with flax.deprecated.nn.stochastic(rng): model_class = MODELS[name].partial(**kwargs) _, params = model_class.init_by_shape(rng, input_specs) return flax.deprecated.nn.Model(model_class, params), init_state def update_model(model, **kwargs): """Updates a model to use different model arguments, but same parameters. Args: model: The model to update. **kwargs: List of model specific keyword arguments. Returns: A FLAX model. """ return flax.deprecated.nn.Model(model.module.partial(**kwargs), model.params) ================================================ FILE: rigl/experimental/jax/models/model_factory_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for weight_symmetry.models.model_factory.""" from absl.testing import absltest from absl.testing import parameterized import flax import jax import jax.numpy as jnp from rigl.experimental.jax.models import model_factory class ModelCommonTest(parameterized.TestCase): """Tests the model factory.""" def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._input_shape = ((1, 28, 28, 1), jnp.float32) self._num_classes = 10 def _create_model(self, model_name): return model_factory.create_model( model_name, self._rng, (self._input_shape,), num_classes=self._num_classes) @parameterized.parameters(*model_factory.MODELS.keys()) def test_model_supported(self, model_name): """Tests supported models.""" model, state = self._create_model(model_name) with self.subTest(name='test_model_supported_model_instance'): self.assertIsInstance(model, flax.deprecated.nn.Model) with self.subTest(name='test_model_supported_collection_instance'): self.assertIsInstance(state, flax.deprecated.nn.Collection) def test_model_unsupported(self): """Tests unsupported models.""" with self.assertRaisesRegex(ValueError, 'No such model: unsupported'): self._create_model('unsupported') if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/prune.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Weight Symmetry: Iteratively Prune Model during Training. 
Command for training and pruning an MNIST fully-connected model for 10 epochs with a fixed pruning rate of 0.95: prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10 --pruning_rate=0.95 Command for training and pruning an MNIST fully-connected model for 10 epochs, with pruning rates 0.3, 0.6 and 0.95 at epochs 2, 5, and 8 respectively for all layers: prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10 --pruning_schedule='[(2, 0.3), (5, 0.6), (8, 0.95)]' Command for doing the same, but performing pruning only on the second layer: prune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10 --pruning_schedule="{'1': [(2, 0.3), (5, 0.6), (8, 0.95)]}" """ import ast from collections import abc import functools from os import path from typing import List import uuid from absl import app from absl import flags from absl import logging import flax from flax.metrics import tensorboard from flax.training import lr_schedule import jax import jax.numpy as jnp from rigl.experimental.jax.datasets import dataset_factory from rigl.experimental.jax.models import model_factory from rigl.experimental.jax.training import training from rigl.experimental.jax.utils import utils experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id)) logging.info('Saving experimental results to %s', experiment_dir) host_count = jax.host_count() local_device_count = jax.local_device_count() logging.info('Device count: %d, host count: %d, local device count: %d', jax.device_count(), host_count, local_device_count) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(experiment_dir) dataset = dataset_factory.create_dataset( FLAGS.dataset, FLAGS.batch_size, FLAGS.batch_size_test, shuffle_buffer_size=FLAGS.shuffle_buffer_size) logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset) rng = jax.random.PRNGKey(FLAGS.random_seed) input_shape = (1,) + dataset.shape base_model, _ = model_factory.create_model( FLAGS.model, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes) initial_model, initial_state = model_factory.create_model( FLAGS.model, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes, masked_layer_indices=FLAGS.masked_layer_indices) if FLAGS.optimizer == 'Adam': optimizer = flax.optim.Adam( learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif FLAGS.optimizer == 'Momentum': optimizer = flax.optim.Momentum( learning_rate=FLAGS.lr, beta=FLAGS.momentum, weight_decay=FLAGS.weight_decay, nesterov=False) steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size if FLAGS.lr_schedule == LR_SCHEDULE_CONSTANT: lr_fn = lr_schedule.create_constant_learning_rate_schedule( FLAGS.lr, steps_per_epoch) elif FLAGS.lr_schedule == LR_SCHEDULE_STEPPED: lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps) lr_fn = lr_schedule.create_stepped_learning_rate_schedule( FLAGS.lr, steps_per_epoch, lr_schedule_steps) elif FLAGS.lr_schedule == LR_SCHEDULE_COSINE: lr_fn = lr_schedule.create_cosine_learning_rate_schedule( FLAGS.lr, steps_per_epoch, FLAGS.epochs) else: raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}') # Reuses the FLAX learning rate schedule framework for pruning rate schedule. 
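# Sketch of the intended semantics (assuming the stepped-schedule behaviour of
# flax.training.lr_schedule): with FLAGS.pruning_rate as the base value, each
# schedule entry is an (epoch, factor) pair applied to that base, so a
# schedule of [(2, 0.3), (5, 0.6), (8, 0.95)] yields a pruning rate of
# FLAGS.pruning_rate * 0.3 from epoch 2, * 0.6 from epoch 5, and * 0.95 from
# epoch 8 onwards.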
pruning_fn_p = functools.partial( lr_schedule.create_stepped_learning_rate_schedule, FLAGS.pruning_rate, steps_per_epoch) if FLAGS.pruning_schedule: pruning_schedule = ast.literal_eval(FLAGS.pruning_schedule) if isinstance(pruning_schedule, abc.Mapping): pruning_rate_fn = { f'MaskedModule_{layer_num}': pruning_fn_p(schedule) for layer_num, schedule in pruning_schedule.items() } else: pruning_rate_fn = pruning_fn_p(pruning_schedule) else: pruning_rate_fn = lr_schedule.create_constant_learning_rate_schedule( FLAGS.pruning_rate, steps_per_epoch) if jax.host_id() == 0: trainer = training.Trainer( optimizer, initial_model, initial_state, dataset, rng, summary_writer=summary_writer, ) else: trainer = training.Trainer( optimizer, initial_model, initial_state, dataset, rng) _, best_metrics = trainer.train( FLAGS.epochs, lr_fn=lr_fn, pruning_rate_fn=pruning_rate_fn, update_iter=FLAGS.update_iterations, update_epoch=FLAGS.update_epoch, ) logging.info('Best metrics: %s', str(best_metrics)) if jax.host_id() == 0: if FLAGS.dump_json: utils.dump_dict_json(best_metrics, path.join(experiment_dir, 'best_metrics.json')) for label, value in best_metrics.items(): summary_writer.scalar(f'best/{label}', value, FLAGS.epochs * steps_per_epoch) summary_writer.close() def main(argv: List[str]): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') run_training() if __name__ == '__main__': app.run(main) ================================================ FILE: rigl/experimental/jax/prune_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for weight_symmetry.prune.""" import glob from os import path from absl.testing import absltest from absl.testing import flagsaver from rigl.experimental.jax import prune class PruneTest(absltest.TestCase): def test_prune_fixed_schedule(self): """Tests training/pruning driver with a fixed global sparsity.""" experiment_dir = self.create_tempdir().full_path eval_flags = dict( epochs=1, pruning_rate=0.95, experiment_dir=experiment_dir, ) with flagsaver.flagsaver(**eval_flags): prune.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_prune_global_pruning_schedule(self): """Tests training/pruning driver with a global sparsity schedule.""" experiment_dir = self.create_tempdir().full_path eval_flags = dict( epochs=10, pruning_schedule='[(5, 0.33), (7, 0.66), (9, 0.95)]', experiment_dir=experiment_dir, ) with flagsaver.flagsaver(**eval_flags): prune.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_prune_local_pruning_schedule(self): """Tests training/pruning driver with a single layer sparsity schedule.""" experiment_dir = self.create_tempdir().full_path eval_flags = dict( epochs=10, pruning_schedule='{1:[(5, 0.33), (7, 0.66), (9, 0.95)]}', experiment_dir=experiment_dir, ) with flagsaver.flagsaver(**eval_flags): prune.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/pruning/__init__.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: rigl/experimental/jax/pruning/init.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for initialization of masked models.""" import functools from typing import Callable, Sequence, Optional import flax import jax import jax.numpy as jnp def sparse_init( base_init, mask, dtype=jnp.float32): """Weight initializer with correct fan in/fan out for a masked model. 
The weight initializer uses any dense initializer to correctly initialize a masked weight matrix by calling the given initialization method with the correct fan in/fan out for every neuron in the layer. If the mask is None, it reverts to the original initialization method. Args: base_init: The base (dense) initialization method to use. mask: The layer's mask, or None. dtype: The weight array jnp.dtype. Returns: An initialization method that is mask aware for the given layer and mask. """ def init(rng, shape, dtype=dtype): if mask is None: return base_init(rng, shape, dtype) # Find the ablated neurons in the mask, to determine correct fan_out. neuron_weight_count = jnp.sum( jnp.reshape(mask, (-1, mask.shape[-1])), axis=0) non_zero_neurons = jnp.sum(neuron_weight_count != 0) # Special case of completely ablated weight matrix/layer. if jnp.sum(non_zero_neurons) == 0: print('Empty weight mask!') return jnp.zeros(shape, dtype) # Neurons have different fan_in w/mask, build up initialization per-unit. init_cols = [] rng, *split_rngs = jax.random.split(rng, mask.shape[-1] + 1) for i in range(mask.shape[-1]): # Special case of ablated neuron. if neuron_weight_count[i] == 0: init_cols.append(jnp.zeros(shape[:-1] + (1,), dtype)) continue # Fake shape of weight matrix with correct fan_in, and fan_out. sparse_shape = (int(neuron_weight_count[i]), int(non_zero_neurons)) # Use only the first column of init from initializer, since faked fan_out. init = base_init(split_rngs[i], sparse_shape, dtype)[Ellipsis, 0] # Expand out to full sparse array. expanded_init = jnp.zeros( mask[Ellipsis, i].shape, dtype).flatten().at[jnp.where(mask[Ellipsis, i].flatten() == 1)].set(init) expanded_init = jnp.reshape(expanded_init, mask[Ellipsis, i].shape) init_cols.append(expanded_init[Ellipsis, jnp.newaxis]) return jnp.concatenate(init_cols, axis=-1) return init xavier_sparse_normal = glorot_sparse_normal = functools.partial( sparse_init, flax.deprecated.nn.initializers.xavier_normal()) kaiming_sparse_normal = he_sparse_normal = functools.partial( sparse_init, flax.deprecated.nn.initializers.kaiming_normal()) ================================================ FILE: rigl/experimental/jax/pruning/init_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for weight_symmetry.pruning.init.""" from typing import Any, Mapping, Optional from absl.testing import absltest import flax import jax import jax.numpy as jnp from rigl.experimental.jax.pruning import init from rigl.experimental.jax.pruning import masked class MaskedDense(flax.deprecated.nn.Module): """Single-layer Dense Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) layer_mask = mask['MaskedModule_0'] if mask else None return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Dense, mask=layer_mask, kernel_init=flax.deprecated.nn.initializers.kaiming_normal()) class MaskedDenseSparseInit(flax.deprecated.nn.Module): """Single-layer Dense Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, *args, mask = None, **kwargs): inputs = inputs.reshape(inputs.shape[0], -1) layer_mask = mask['MaskedModule_0'] if mask else None return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Dense, mask=layer_mask, kernel_init=init.kaiming_sparse_normal( layer_mask['kernel'] if layer_mask is not None else None), **kwargs) class MaskedCNN(flax.deprecated.nn.Module): """Single-layer CNN Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, mask = None): layer_mask = mask['MaskedModule_0'] if mask else None return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Conv, kernel_size=(3, 3), mask=layer_mask, kernel_init=flax.deprecated.nn.initializers.kaiming_normal()) class MaskedCNNSparseInit(flax.deprecated.nn.Module): """Single-layer CNN Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, *args, mask = None, **kwargs): layer_mask = mask['MaskedModule_0'] if mask else None return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Conv, kernel_size=(3, 3), mask=layer_mask, kernel_init=init.kaiming_sparse_normal( layer_mask['kernel'] if layer_mask is not None else None), **kwargs) class InitTest(absltest.TestCase): def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._batch_size = 2 self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32) self._input = jnp.ones(*self._input_shape) def test_init_kaiming_sparse_normal_output(self): """Tests the output shape/type of kaiming normal sparse initialization.""" input_array = jnp.ones((64, 16), jnp.float32) mask = jax.random.bernoulli(self._rng, shape=(64, 16)) base_init = flax.deprecated.nn.initializers.kaiming_normal()( self._rng, input_array.shape, input_array.dtype) sparse_init = init.kaiming_sparse_normal(mask)(self._rng, input_array.shape, input_array.dtype) with self.subTest(name='test_sparse_init_output_shape'): self.assertSequenceEqual(sparse_init.shape, base_init.shape) with self.subTest(name='test_sparse_init_output_dtype'): self.assertEqual(sparse_init.dtype, base_init.dtype) with self.subTest(name='test_sparse_init_output_notallzero'): self.assertTrue((sparse_init != 0).any()) def test_dense_no_mask(self): """Checks that in the special case of no mask, init is same as base_init.""" _, initial_params = MaskedDense.init_by_shape(self._rng, (self._input_shape,)) self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params) _, initial_params = MaskedDenseSparseInit.init_by_shape( jax.random.PRNGKey(42), (self._input_shape,), mask=None) self._masked_model_sparse_init = flax.deprecated.nn.Model( MaskedDenseSparseInit, initial_params) 
self.assertTrue( jnp.isclose( self._masked_model_sparse_init.params['MaskedModule_0']['unmasked'] ['kernel'], self._unmasked_model.params['MaskedModule_0'] ['unmasked']['kernel']).all()) def test_dense_sparse_init_kaiming(self): """Checks kaiming normal sparse initialization for dense layer.""" _, initial_params = MaskedDense.init_by_shape(self._rng, (self._input_shape,)) self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params) mask = masked.simple_mask(self._unmasked_model, jnp.ones, masked.WEIGHT_PARAM_NAMES) _, initial_params = MaskedDenseSparseInit.init_by_shape( jax.random.PRNGKey(42), (self._input_shape,), mask=mask) self._masked_model_sparse_init = flax.deprecated.nn.Model( MaskedDenseSparseInit, initial_params) mean_init = jnp.mean( self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel']) stddev_init = jnp.std( self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel']) mean_sparse_init = jnp.mean( self._masked_model_sparse_init.params['MaskedModule_0']['unmasked'] ['kernel']) stddev_sparse_init = jnp.std( self._masked_model_sparse_init.params['MaskedModule_0']['unmasked'] ['kernel']) with self.subTest(name='test_cnn_sparse_init_mean'): self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init, mean_init + 2 * stddev_init) with self.subTest(name='test_cnn_sparse_init_stddev'): self.assertBetween(stddev_sparse_init, 0.5 * stddev_init, 1.5 * stddev_init) def test_cnn_sparse_init_kaiming(self): """Checks kaiming normal sparse initialization for convolutional layer.""" _, initial_params = MaskedCNN.init_by_shape(self._rng, (self._input_shape,)) self._unmasked_model = flax.deprecated.nn.Model(MaskedCNN, initial_params) mask = masked.simple_mask(self._unmasked_model, jnp.ones, masked.WEIGHT_PARAM_NAMES) _, initial_params = MaskedCNNSparseInit.init_by_shape( jax.random.PRNGKey(42), (self._input_shape,), mask=mask) self._masked_model_sparse_init = flax.deprecated.nn.Model( MaskedCNNSparseInit, initial_params) mean_init = jnp.mean( self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel']) stddev_init = jnp.std( self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel']) mean_sparse_init = jnp.mean( self._masked_model_sparse_init.params['MaskedModule_0']['unmasked'] ['kernel']) stddev_sparse_init = jnp.std( self._masked_model_sparse_init.params['MaskedModule_0']['unmasked'] ['kernel']) with self.subTest(name='test_cnn_sparse_init_mean'): self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init, mean_init + 2 * stddev_init) with self.subTest(name='test_cnn_sparse_init_stddev'): self.assertBetween(stddev_sparse_init, 0.5 * stddev_init, 1.5 * stddev_init) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/pruning/mask_factory.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Pruning mask factory. Attributes: MaskFnType: A type alias for functions to create sparse masks. 
MASK_TYPES: Masks types that can be created. """ from typing import Any, Callable, Mapping import flax import jax.numpy as jnp from rigl.experimental.jax.pruning import masked # A function to create a mask, takes as arguments: a flax model, JAX PRNG Key, # sparsity level as a float in [0, 1]. MaskFnType = Callable[ [flax.deprecated.nn.Model, Callable[[int], jnp.array], float], masked.MaskType] MASK_TYPES: Mapping[str, MaskFnType] = { 'random': masked.shuffled_mask, 'per_neuron': masked.shuffled_neuron_mask, 'per_neuron_no_input_ablation': masked.shuffled_neuron_no_input_ablation_mask, 'symmetric': masked.symmetric_mask, } def create_mask(mask_type, base_model, rng, sparsity, **kwargs): """Creates a Mask of the given type. Args: mask_type: the name of the type of mask to instantiate. base_model: the model to create a mask for. rng : the random number generator to use for init. sparsity: the mask sparsity. **kwargs: list of model specific keyword arguments. Returns: A mask for a FLAX model. Raises: ValueError if a model with the given name does not exist. """ if mask_type not in MASK_TYPES: raise ValueError(f'Unknown mask type: {mask_type}') return MASK_TYPES[mask_type](base_model, rng, sparsity, **kwargs) ================================================ FILE: rigl/experimental/jax/pruning/mask_factory_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.models.model_factory.""" from typing import Mapping, Optional from absl.testing import absltest from absl.testing import parameterized import flax import jax import jax.numpy as jnp from rigl.experimental.jax.pruning import mask_factory from rigl.experimental.jax.pruning import masked class MaskedDense(flax.deprecated.nn.Module): """Single-layer Dense Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_0'] if mask else None) class MaskFactoryTest(parameterized.TestCase): def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._input_shape = ((1, 28, 28, 1), jnp.float32) self._num_classes = 10 self._sparsity = 0.9 _, initial_params = MaskedDense.init_by_shape(self._rng, (self._input_shape,)) # Use the same initialization for both masked/unmasked models. 
    self._model = flax.deprecated.nn.Model(MaskedDense, initial_params)

  def _create_mask(self, mask_type):
    return mask_factory.create_mask(
        mask_type, self._model, self._rng, self._sparsity)

  @parameterized.parameters(*mask_factory.MASK_TYPES.keys())
  def test_mask_supported(self, mask_type):
    """Tests supported mask types."""
    mask = self._create_mask(mask_type)

    with self.subTest(name='test_mask_type'):
      self.assertIsInstance(mask, dict)

  def test_mask_unsupported(self):
    """Tests unsupported mask types."""
    with self.assertRaisesRegex(ValueError, 'Unknown mask type: unsupported'):
      self._create_mask('unsupported')


if __name__ == '__main__':
  absltest.main()


================================================
FILE: rigl/experimental/jax/pruning/masked.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masked wrapper for FLAX modules.

Attributes:
  WEIGHT_PARAM_NAMES: The name of the weight parameters to use.
  MaskType: Model mask type for static type checking.
  MaskLayerType: Mask layer type for static type checking.
  MutableMaskType: Mutable model mask type for static type checking.
  MutableMaskLayerType: Mutable mask layer type for static type checking.
"""
import functools
import operator
from typing import Any, Callable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple, Type

from absl import logging
import flax
import jax
import jax.numpy as jnp
import jax.ops

# Model weight param names, e.g. 'kernel' (as opposed to batch norm params,
# etc.).
WEIGHT_PARAM_NAMES = ('kernel',)  # Note: Bias is not typically masked.

# Mask layer type for static type checking.
MaskLayerType = Mapping[str, Optional[jnp.array]]
# Model mask type for static type checking.
MaskType = Mapping[str, Optional[MaskLayerType]]
# Mutable mask layer type for static type checking.
MutableMaskLayerType = MutableMapping[str, Optional[jnp.array]]
# Mutable model mask type for static type checking.
MutableMaskType = MutableMapping[str, MutableMaskLayerType]


class MaskedModule(flax.deprecated.nn.Module):
  """Generic FLAX Masking Module.

  Masks a FLAX module, given a mask for params of each layer.

  Attributes:
    UNMASKED: The key to use for the unmasked parameter dictionary.
  """
  UNMASKED = 'unmasked'

  def apply(self, *args, wrapped_module, mask = None, **kwargs):
    """Apply the wrapped module, while applying the given masks to its params.

    Args:
      *args: The positional arguments for the wrapped module.
      wrapped_module: The module class to be wrapped.
      mask: The mask nested dictionary containing masks for the wrapped
        module's params, in the same format/with the same keys as the module
        param dict (or None if not to mask).
      **kwargs: The keyword arguments for the wrapped module.

    Returns:
      The output of the wrapped module, with the given mask applied to its
      params.

    Raises:
      ValueError: If the given mask is not valid for the wrapped module, i.e.
        the pytrees do not match.
    """
    # Explicitly create the parameters of the wrapped module.
    def init_fn(rng, input_shape):
      del input_shape  # Unused.
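      # Note: the shape argument supplied by self.param can be ignored here,
      # since the wrapped module's own init below derives all parameter shapes
      # from the concrete *args captured by this apply call.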
# Call init to get the params of the wrapped module. _, params = wrapped_module.init(rng, *args, **kwargs) return params unmasked_params = self.param(self.UNMASKED, None, init_fn) if mask is not None: try: masked_params = jax.tree_util.tree_map( lambda x, *xs: x if xs[0] is None else x * xs[0], unmasked_params, mask) except ValueError as err: raise ValueError('Mask is invalid for model.') from err # Call the wrapped module with the masked params. return wrapped_module.call(masked_params, *args, **kwargs) else: logging.warning('Using masked module without mask!') # Call the wrapped module with the unmasked params. return wrapped_module.call(unmasked_params, *args, **kwargs) def masked(module, mask): """Convenience function for masking a FLAX module with MaskedModule.""" return MaskedModule.partial(wrapped_module=module, mask=mask) def generate_model_masks( depth, mask = None, masked_layer_indices = None): """Creates empty masks for this model, or initializes with existing mask. Args: depth: Number of layers in the model. mask: Existing model mask for layers in this model, if not given, all module masks are initialized to None. masked_layer_indices: The layer indices of layers in model to be masked, or all if None. Returns: A model mask, with None where no mask is given for a model layer, or that specific layer is indicated as not to be masked by the masked_layer_indices parameter. """ if depth <= 0: raise ValueError(f'Invalid model depth: {depth}') if mask is None: mask = {f'MaskedModule_{i}': None for i in range(depth)} # Have to explicitly check for None to differentiate from empty array. if masked_layer_indices is not None: # Check none of the indices are outside of model's layer bounds. if any(i < 0 or i >= depth for i in masked_layer_indices): raise ValueError( f'Invalid indices for given depth ({depth}): {masked_layer_indices}') mask = { f'MaskedModule_{i}': mask[f'MaskedModule_{i}'] for i in masked_layer_indices } return mask def _filter_param(param_names, invert = False): """Convenience function for filtering maskable parameters from paths. Args: param_names: Names of parameters we are looking for. invert: Inverts filter to exclude, rather than include, given parameters. Returns: A function to use with flax.deprecated.nn.optim.ModelParamTraversal for filtering. """ def filter_fn(path, value): del value # Unused. parameter_found = any([ '{}/{}'.format(MaskedModule.UNMASKED, param_name) in path for param_name in param_names ]) return not parameter_found if invert else parameter_found return filter_fn def mask_map(model, fn): """Convenience function to create a mask for a model. Args: model: The Flax model, with at least one MaskedModule layer. fn: The function to call on each masked parameter, to create the mask for that parameter, takes the parameter name, and parameter value as arguments and returns the new parameter value. Returns: A model parameter dictionary, with all masked parameters set by the given function, and all other parameters set to None. Raises: ValueError: If the given model does not support masking, i.e. none of the layers are wrapped by a MaskedModule. """ maskable = False for layer_key, layer in model.params.items(): if MaskedModule.UNMASKED not in layer: logging.warning( 'Layer \'%s\' does not support masking, i.e. it is not ' 'wrapped by a MaskedModule', layer_key) else: maskable = True if not maskable: raise ValueError('Model does not support masking, i.e. 
no layers are ' 'wrapped by a MaskedModule.') # First set all non-masked params to None in copy of model pytree. filter_non_masked = _filter_param(WEIGHT_PARAM_NAMES, invert=True) nonmasked_traversal = flax.optim.ModelParamTraversal(filter_non_masked) # pytype: disable=module-attr mask_model = nonmasked_traversal.update(lambda _: None, model) # Then find params to mask, and set to array. for param_name in WEIGHT_PARAM_NAMES: filter_masked = _filter_param(WEIGHT_PARAM_NAMES) mask_traversal = flax.optim.ModelParamTraversal(filter_masked) # pytype: disable=module-attr mask_model = mask_traversal.update( functools.partial(fn, param_name), mask_model) mask = mask_model.params # Remove unneeded unmasked param for mask. for layer_key, layer in mask.items(): if MaskedModule.UNMASKED in layer: mask[layer_key] = layer[MaskedModule.UNMASKED] return mask def iterate_mask( mask, param_names = None ): """Iterate over the parameters in as mask. Args: mask: The model mask. param_names: The parameter names to iterate over in each layer, if None iterates over all parameters of all layers. Yields: An iterator of tuples containing the parameter path and parameter value in sorted order of layer parameters matching the names in param_names (or all parameters if None). """ flat_mask = flax.traverse_util.flatten_dict(mask) for key, value in flat_mask.items(): if param_names is None or key in param_names: path = '/' + '/'.join(key) yield path, value def shuffled_mask(model, rng, sparsity): """Returns a randomly shuffled mask with a given sparsity for all layers. Returns a random weight mask for a model param array, by randomly shuffling a mask with a fixed number of non-zero/zero entries, given by the sparsity. Args: model: Flax model that contains masked modules. rng: Random number generator, i.e. jax.random.PRNGKey. sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will mask all weights, while 0 will mask none. Returns: A randomly shuffled weight mask, in the same form as flax.Module.params. Raises: ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are maskable, i.e. is wrapped by MaskedModule. """ if sparsity > 1 or sparsity < 0: raise ValueError( 'Given sparsity, {}, is not in range [0, 1]'.format(sparsity)) def create_shuffled_mask(param_name, param): del param_name # Unused. mask = jnp.arange(param.size) mask = jnp.where(mask >= sparsity * param.size, jnp.ones_like(mask), jnp.zeros_like(mask)) mask = jax.random.permutation(rng, mask) return mask.reshape(param.shape) return mask_map(model, create_shuffled_mask) def random_mask(model, rng, mean_sparsity = 0.5): """Returns a random weight mask for a masked model. Args: model: Flax model that contains masked modules. rng: Random number generator, i.e. jax.random.PRNGKey. mean_sparsity: The mean number of 0's in the mask, i.e. mean = (1 - mean_sparsity) for the Bernoulli distribution to sample from. Returns: A random weight mask, in the same form as flax.Module.params Raises: ValueError: If the sparsity is beyond the bounds [0, 1], or if a layer to mask is not maskable, i.e. is not wrapped by MaskedModule. """ if mean_sparsity > 1 or mean_sparsity < 0: raise ValueError( 'Given sparsity, {}, is not in range [0, 1]'.format(mean_sparsity)) # Invert mean_sparsity to get mean for Bernoulli distribution. mean = 1. - mean_sparsity def create_random_mask(param_name, param): del param_name # Unused. return jax.random.bernoulli( rng, p=mean, shape=param.shape).astype(jnp.int32) # TPU doesn't support uint8. 
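  # mask_map walks the model's maskable params, applying create_random_mask to
  # every weight array and leaving all other params as None (see mask_map
  # above).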
return mask_map(model, create_random_mask) def simple_mask(model, init_fn, masked_param): """Creates a mask given a model and numpy initialization function. Args: model: The model to create a mask for. init_fn: The numpy initialization function, e.g. numpy.ones. masked_param: The list of parameters to mask. Returns: A mask for the model. """ def create_init_fn_mask(param_name, param): if param_name in masked_param: return init_fn(param.shape) return None return mask_map(model, create_init_fn_mask) def symmetric_mask(model, rng, sparsity = 0.5): """Generates a random weight mask that's symmetric, i.e. structurally pruned. Args: model: Flax model that contains masked modules. rng: Random number generator, i.e. jax.random.PRNGKey. sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), in the range [0, 1]: 1.0 will mask all weights, while 0 will mask none. Returns: A symmetric random weight mask, in the same form as flax.Module.params. """ if sparsity > 1 or sparsity < 0: raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]') def create_neuron_symmetric_mask(param_name, param): del param_name # Unused. neuron_length = functools.reduce(operator.mul, param.shape[:-1]) neuron_mask = jnp.arange(neuron_length) neuron_mask = jnp.where(neuron_mask >= sparsity * neuron_mask.size, jnp.ones_like(neuron_mask), jnp.zeros_like(neuron_mask)) neuron_mask = jax.random.shuffle(rng, neuron_mask) mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1) return mask.reshape(param.shape) return mask_map(model, create_neuron_symmetric_mask) class _PerNeuronShuffle: """This class is needed to get around the fact that JAX RNG is stateless.""" def __init__(self, init_rng, sparsity): """Creates the per-neuron shuffle class, with initial RNG state. Args: init_rng: The initial random number generator state to use. sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will mask all weights, while 0 will mask none. """ self._rng = init_rng self._sparsity = sparsity def __call__(self, param_name, param): """Shuffles the weight matrix/mask for a given parameter, per-neuron. This is to be used with mask_map, and accepts the standard mask_map function parameters. Args: param_name: The parameter's name. param: The parameter's weight or mask matrix. Returns: A shuffled weight/mask matrix, with each neuron shuffled independently. """ del param_name # Unused. neuron_length = functools.reduce(operator.mul, param.shape[:-1]) neuron_mask = jnp.arange(neuron_length) neuron_mask = jnp.where(neuron_mask >= self._sparsity * neuron_mask.size, jnp.ones_like(neuron_mask), jnp.zeros_like(neuron_mask)) mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1) self._rng, rng_input = jax.random.split(self._rng) mask = jax.random.shuffle(rng_input, mask, axis=0) return mask.reshape(param.shape) def shuffled_neuron_mask(model, rng, sparsity): """Returns a shuffled mask with a given fixed sparsity for all neurons/layers. Returns a randomly shuffled weight mask for a model param array, by setting a fixed sparsity (i.e. number of ones/zeros) for every neuron's weight vector in the model, and then randomly shuffling each neuron's weight mask with a fixed number of non-zero/zero entries, given by the sparsity. This ensures no neuron is ablated for a non-zero sparsity. Note: This is much more complicated for convolutional layers due to the receptive field being different for every pixel! 
  We only take into account channel-wise masks and not spatial ablations in
  propagation in that case.

  Args:
    model: Flax model that contains masked modules.
    rng: Random number generator, i.e. jax.random.PRNGKey.
    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will
      mask all weights, while 0 will mask none.

  Returns:
    A randomly shuffled weight mask, in the same form as flax.Module.params.

  Raises:
    ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are
      maskable, i.e. none are wrapped by a MaskedModule.
  """
  if sparsity > 1 or sparsity < 0:
    raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')

  return mask_map(model, _PerNeuronShuffle(rng, sparsity))


def _fill_diagonal_wrap(shape, value, dtype = jnp.uint8):
  """Fills the diagonal of a 2D array, while also wrapping tall arrays.

  For a matrix of dimensions (N x M):

  if N <= M, i.e. the array is wide rectangular, the array's diagonal is
  filled, for example:

  _fill_diagonal_wrap((2, 3), 1)
  > [[1, 0, 0],
     [0, 1, 0]]

  if N > M, i.e. the array is tall rectangular, the array's diagonal, and
  offset diagonals are filled. This differs from
  numpy.fill_diagonal(..., wrap=True), in that it does not include a single
  row gap between the diagonals, and it returns a new array rather than
  filling one in-place. For example,

  _fill_diagonal_wrap((3, 2), 1)
  > [[1, 0],
     [0, 1],
     [1, 0]]

  Args:
    shape: The shape of the 2D array to return with the diagonal filled.
    value: The value to fill in for the diagonal, and offset diagonals.
    dtype: The datatype of the jax numpy array to return.

  Returns:
    A new array of the given shape with the main diagonal filled, and offset
    diagonals filled if the shape is tall.
  """
  if len(shape) != 2:
    raise ValueError(
        f'Expected a 2D array, however array has dimensions: {shape}')

  array = jnp.zeros(shape, dtype=dtype)
  rows, cols = shape

  def diagonal_indices(offset):  # Returns jax.ops._Indexable.
    """Returns slice of the nth diagonal of an array, where n is offset."""
    # This is a numpy-style advanced slice of the form [start:end:step] that
    # gives you the offset (vertically) diagonal of an array. If it were the
    # main diagonal of a (flattened) n x n square matrix, it would be
    # 0:n**2:n+1, i.e. start at 0, look at every n+1th element, and end when
    # you get to the end of the array. We need to look at vertically-offset
    # diagonals as well, which is handled by offset.
    return jnp.index_exp[cols * offset:cols * (offset + cols):cols + 1]

  # Fills (square) matrix diagonals with the given value, tiling over tall
  # rectangular arrays by offsetting the filled diagonals by multiples of the
  # height of the square arrays.
  diagonals = [
      array.ravel().at[diagonal_indices(offset)].set(value).reshape(
          array.shape) for offset in range(0, rows, cols)
  ]

  return functools.reduce(jnp.add, diagonals)


def _random_neuron_mask(neuron_length, unmasked_count, rng,
                        dtype = jnp.uint32):
  """Generates a random mask for a neuron.

  Args:
    neuron_length: The length of the neuron's weight vector.
    unmasked_count: The number of elements that should be unmasked.
    rng: A jax.random.PRNGKey random seed.
    dtype: Type of array to create.

  Returns:
    A random neuron weight vector mask.
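
  For example (illustrative; the exact arrangement depends on the PRNG key),
  _random_neuron_mask(4, 2, jax.random.PRNGKey(0)) shuffles the vector
  [1, 1, 0, 0] and might return [0, 1, 0, 1].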
""" if unmasked_count > neuron_length: raise ValueError('unmasked_count cannot be greater that neuron_length: ' f'{unmasked_count} > {neuron_length}') neuron_mask = jnp.concatenate( (jnp.ones(unmasked_count), jnp.zeros(neuron_length - unmasked_count)), axis=0) neuron_mask = jax.random.shuffle(rng, neuron_mask) return neuron_mask.astype(dtype) class _PerNeuronNoInputAblationShuffle: """This class is needed to get around the fact that JAX RNG is stateless.""" def __init__(self, init_rng, sparsity): """Creates the per-neuron shuffle class, with initial RNG state. Args: init_rng: The initial random number generator state to use. sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will mask all weights, while 0 will mask none. """ self._rng = init_rng self._sparsity = sparsity def _get_rng(self): """Creates a new JAX RNG, while updating RNG state.""" self._rng, rng_input = jax.random.split(self._rng) return rng_input def __call__(self, param_name, param): """Shuffles the weight matrix/mask for a given parameter, per-neuron. This is to be used with mask_map, and accepts the standard mask_map function parameters. Args: param_name: The parameter's name. param: The parameter's weight or mask matrix. Returns: A shuffled weight/mask matrix, with each neuron shuffled independently. """ del param_name # Unused. incoming_connections = jnp.prod(jnp.array(param.shape[:-1])) num_neurons = param.shape[-1] # Ensure each input neuron has at least one connection unmasked. mask = _fill_diagonal_wrap((incoming_connections, num_neurons), 1, dtype=jnp.uint8) # Randomly shuffle which of the neurons have these connections. mask = jax.random.shuffle(self._get_rng(), mask, axis=0) # Add extra required random connections to mask to satisfy sparsity. mask_cols = [] for col in range(mask.shape[-1]): neuron_mask = mask[:, col] off_diagonal_count = max( round((1 - self._sparsity) * incoming_connections) - jnp.count_nonzero(neuron_mask), 0) zero_indices = jnp.flatnonzero(neuron_mask == 0) random_entries = _random_neuron_mask( len(zero_indices), off_diagonal_count, self._get_rng()) neuron_mask = neuron_mask.at[zero_indices].set(random_entries) mask_cols.append(neuron_mask) return jnp.column_stack(mask_cols).reshape(param.shape) def shuffled_neuron_no_input_ablation_mask(model, rng, sparsity): """Returns a shuffled mask with a given fixed sparsity for all neurons/layers. Returns a randomly shuffled weight mask for a model param array, by setting a fixed sparsity (i.e. number of ones/zeros) for every neuron's weight vector in the model, and then randomly shuffling each neuron's weight mask with a fixed number of non-zero/zero entries, given by the sparsity. This ensures no neuron is ablated for a non-zero sparsity. This function also ensures that no neurons in the previous layer are effectively ablated, by ensuring that each neuron has at least one connection. Note: This is much more complicated for convolutional layers due to the receptive field being different for every pixel! We only take into account channel-wise masks and not spatial ablations in propagation in that case. Args: model: Flax model that contains masked modules. rng: Random number generator, i.e. jax.random.PRNGKey. sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will mask all weights, except for the minimum number required to maintain, connectivity with the input layer, while 0 will mask none. Returns: A randomly shuffled weight mask, in the same form as flax.Module.params. 
def propagate_masks(mask, param_names=WEIGHT_PARAM_NAMES):
  """Accounts for implicitly pruned neurons in a model's weight masks.

  When neurons are randomly ablated in one layer, they can effectively ablate
  neurons in the next layer, if all of a neuron's incoming weights are in
  effect zero. This method accounts for this by propagating mask information
  forward through the entire model.

  Args:
    mask: Model masks to check, in same pytree structure as Model.params.
    param_names: List of param keys in mask to count.

  Returns:
    A refined model mask with weights that are effectively ablated in the
    original mask set to zero.
  """
  flat_mask = flax.traverse_util.flatten_dict(mask)
  mask_layer_list = list(flat_mask.values())
  mask_layer_keys = list(flat_mask.keys())
  mask_layer_param_names = [layer_param[-1] for layer_param in mask_layer_keys]

  for param_name in param_names:
    # Find which of the param arrays correspond to leaf nodes with this name.
    param_indices = [
        i for i, names in enumerate(mask_layer_param_names)
        if param_name in names
    ]

    for i in range(1, len(param_indices)):
      last_weight_mask = mask_layer_list[param_indices[i - 1]]
      weight_mask = mask_layer_list[param_indices[i]]

      if last_weight_mask is None or weight_mask is None:
        continue

      last_weight_mask_reshaped = jnp.reshape(
          last_weight_mask, (-1, last_weight_mask.shape[-1]))
      # Neurons with any outgoing weights from previous layer.
      alive_incoming = jnp.sum(last_weight_mask_reshaped, axis=0) != 0

      # Combine effective mask of previous layer with neuron's current mask.
      if len(weight_mask.shape) > 2:
        # Convolutional layer, only consider channel-wise masks: if any
        # spatial weight is non-zero, that channel is considered non-masked.
        spatial_dim = len(weight_mask.shape) - 2
        new_weight_mask = alive_incoming[:, jnp.newaxis] * jnp.amax(
            weight_mask, axis=tuple(range(spatial_dim)))
        new_weight_mask = jnp.tile(new_weight_mask,
                                   weight_mask.shape[:-2] + (1, 1))
      else:
        # Check for the case of a dense layer following a convolution, i.e.
        # spatial input into dense, to prevent b/156135283. Must use a
        # convolution for these layers.
        if len(last_weight_mask.shape) > 2:
          raise ValueError(
              'propagate_masks requires knowledge of the spatial '
              'dimensions of the previous layer. Use a functionally equivalent '
              'conv. layer in place of a dense layer in a model with a mixed '
              'conv/dense setting.')
        new_weight_mask = alive_incoming[:, jnp.newaxis] * weight_mask

      mask_layer_list[param_indices[i]] = jnp.reshape(
          new_weight_mask, mask_layer_list[param_indices[i]].shape)

  return flax.traverse_util.unflatten_dict(
      dict(zip(mask_layer_keys, mask_layer_list)))


def mask_layer_sparsity(mask_layer):
  """Calculates the sparsity of a single layer's mask.

  Args:
    mask_layer: The mask layer to calculate the sparsity of.

  Returns:
    The sparsity of the mask.
  """
  parameter_count = 0
  masked_count = 0
  for key in mask_layer:
    if mask_layer[key] is not None and key in WEIGHT_PARAM_NAMES:
      parameter_count += mask_layer[key].size
      masked_count += jnp.sum(mask_layer[key])

  if parameter_count == 0:
    return 0.

  return 1.
- masked_count/parameter_count def mask_sparsity( mask, param_names = None): """Calculates the sparsity of the given parameters over a model mask. Args: mask: Model mask to calculate sparsity over. param_names: List of param keys in mask to count. Returns: The overall sparsity of the mask. """ if param_names is None: param_names = WEIGHT_PARAM_NAMES parameter_count = 0 masked_count = 0 for path, value in iterate_mask(mask): if value is not None and any( param_name in path for param_name in param_names): parameter_count += value.size masked_count += jnp.sum(value.flatten()) if parameter_count == 0: return 0. return 1.0 - float(masked_count / parameter_count) ================================================ FILE: rigl/experimental/jax/pruning/masked_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.pruning.masked.""" from typing import Mapping, Optional, Sequence from absl.testing import absltest from absl.testing import parameterized import flax import jax import jax.numpy as jnp import numpy as np from rigl.experimental.jax.pruning import masked class Dense(flax.deprecated.nn.Module): """Single-layer Dense Non-Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs): inputs = inputs.reshape(inputs.shape[0], -1) return flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES) class MaskedDense(flax.deprecated.nn.Module): """Single-layer Dense Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_0'] if mask else None) class DenseTwoLayer(flax.deprecated.nn.Module): """Two-layer Dense Non-Masked Network.""" NUM_FEATURES: Sequence[int] = (32, 64) def apply(self, inputs): inputs = inputs.reshape(inputs.shape[0], -1) inputs = flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[0]) return flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[1]) class MaskedTwoLayerDense(flax.deprecated.nn.Module): """Two-layer Dense Masked Network.""" NUM_FEATURES: Sequence[int] = (32, 64) def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[0], wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_0'] if mask else None) return masked.MaskedModule( inputs, features=self.NUM_FEATURES[1], wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_1'] if mask else None) class MaskedConv(flax.deprecated.nn.Module): """Single-layer Conv Masked Network.""" NUM_FEATURES: int = 16 def apply(self, inputs, mask = None): return masked.MaskedModule( inputs, features=self.NUM_FEATURES, kernel_size=(3, 3), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_0'] if mask is not None else None) class MaskedTwoLayerConv(flax.deprecated.nn.Module): """Two-layer Conv Masked 
Network.""" NUM_FEATURES: Sequence[int] = (16, 32) def apply(self, inputs, mask = None): inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[0], kernel_size=(5, 5), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_0'] if mask is not None else None) return masked.MaskedModule( inputs, features=self.NUM_FEATURES[1], kernel_size=(3, 3), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_1'] if mask is not None else None) class MaskedThreeLayerConvDense(flax.deprecated.nn.Module): """Three-layer Conv Masked Network with Dense layer.""" NUM_FEATURES: Sequence[int] = (16, 32, 64) def apply(self, inputs, mask = None): inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[0], kernel_size=(5, 5), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_0'] if mask is not None else None) inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[1], kernel_size=(3, 3), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_1'] if mask is not None else None) return masked.MaskedModule( inputs, features=self.NUM_FEATURES[2], kernel_size=inputs.shape[1:-1], wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_2'] if mask is not None else None) class MaskedTwoLayerMixedConvDense(flax.deprecated.nn.Module): """Two-layer Mixed Conv/Dense Masked Network.""" NUM_FEATURES: Sequence[int] = (16, 32) def apply(self, inputs, mask = None): inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[0], kernel_size=(5, 5), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_0'] if mask is not None else None) return masked.MaskedModule( inputs, features=self.NUM_FEATURES[1], wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_1'] if mask is not None else None) class MaskedTest(parameterized.TestCase): """Tests the flax layer mask.""" def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._batch_size = 2 self._input_dimensions = (28, 28, 1) self._input_shape = ((self._batch_size,) + self._input_dimensions, jnp.float32) self._input = jnp.ones(*self._input_shape) _, initial_params = Dense.init_by_shape(self._rng, (self._input_shape,)) self._unmasked_model = flax.deprecated.nn.Model(Dense, initial_params) self._unmasked_output = self._unmasked_model(self._input) # Use the same initialization for both masked/unmasked models. masked_initial_params = { 'MaskedModule_0': { 'unmasked': initial_params['Dense_0'] } } self._masked_model = flax.deprecated.nn.Model(MaskedDense, masked_initial_params) _, initial_params = DenseTwoLayer.init_by_shape(self._rng, (self._input_shape,)) self._unmasked_model_twolayer = flax.deprecated.nn.Model( DenseTwoLayer, initial_params) self._unmasked_output_twolayer = self._unmasked_model_twolayer(self._input) # Use the same initialization for both masked/unmasked models. 
masked_initial_params = { 'MaskedModule_0': { 'unmasked': initial_params['Dense_0'] }, 'MaskedModule_1': { 'unmasked': initial_params['Dense_1'] }, } _, initial_params = MaskedTwoLayerDense.init_by_shape( self._rng, (self._input_shape,)) self._masked_model_twolayer = flax.deprecated.nn.Model( MaskedTwoLayerDense, masked_initial_params) _, initial_params = MaskedConv.init_by_shape(self._rng, (self._input_shape,)) self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv, initial_params) _, initial_params = MaskedTwoLayerConv.init_by_shape( self._rng, (self._input_shape,)) self._masked_conv_model_twolayer = flax.deprecated.nn.Model( MaskedTwoLayerConv, initial_params) _, initial_params = MaskedTwoLayerMixedConvDense.init_by_shape( self._rng, (self._input_shape,)) self._masked_mixed_model_twolayer = flax.deprecated.nn.Model( MaskedTwoLayerMixedConvDense, initial_params) _, initial_params = MaskedThreeLayerConvDense.init_by_shape( self._rng, (self._input_shape,)) self._masked_conv_fc_model_threelayer = flax.deprecated.nn.Model( MaskedThreeLayerConvDense, initial_params) def test_fully_masked_layer(self): """Tests masked module with full-sparsity mask.""" full_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel']) masked_output = self._masked_model(self._input, mask=full_mask) with self.subTest(name='fully_masked_dense_values'): self.assertTrue((masked_output == 0).all()) with self.subTest(name='fully_masked_dense_shape'): self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape) def test_no_mask_masked_layer(self): """Tests masked module with no mask.""" masked_output = self._masked_model(self._input, mask=None) with self.subTest(name='no_mask_masked_dense_values'): self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all()) with self.subTest(name='no_mask_masked_dense_shape'): self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape) def test_empty_mask_masked_layer(self): """Tests masked module with an empty (not sparse) mask.""" empty_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel']) masked_output = self._masked_model(self._input, mask=empty_mask) with self.subTest(name='empty_mask_masked_dense_values'): self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all()) with self.subTest(name='empty_mask_masked_dense_shape'): self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape) def test_invalid_mask(self): """Tests using an invalid mask.""" invalid_mask = { 'MaskedModule_0': { 'not_kernel': jnp.ones(self._unmasked_model.params['Dense_0']['kernel'].shape) } } with self.assertRaisesRegex(ValueError, 'Mask is invalid for model.'): self._masked_model(self._input, mask=invalid_mask) def test_shuffled_mask_invalid_model(self): """Tests shuffled mask with model containing no masked layers.""" with self.assertRaisesRegex( ValueError, 'Model does not support masking, i.e. 
no layers are '
        'wrapped by a MaskedModule.'):
      masked.shuffled_mask(self._unmasked_model, self._rng, 0.5)

  def test_shuffled_mask_invalid_sparsity(self):
    """Tests shuffled mask with invalid sparsity."""
    with self.subTest(name='sparsity_too_small'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, -0.5, is not in range \[0, 1\]'):
        masked.shuffled_mask(self._masked_model, self._rng, -0.5)

    with self.subTest(name='sparsity_too_large'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, 1.5, is not in range \[0, 1\]'):
        masked.shuffled_mask(self._masked_model, self._rng, 1.5)

  def test_shuffled_mask_sparsity_full(self):
    """Tests shuffled mask generation, for 100% sparsity."""
    mask = masked.shuffled_mask(self._masked_model, self._rng, 1.0)

    with self.subTest(name='shuffled_full_mask'):
      self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='shuffled_full_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())

    with self.subTest(name='shuffled_full_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])

    masked_output = self._masked_model(self._input, mask=mask)

    with self.subTest(name='shuffled_full_mask_dense_values'):
      self.assertTrue((masked_output == 0).all())

    with self.subTest(name='shuffled_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output.shape)

  def test_shuffled_mask_sparsity_empty(self):
    """Tests shuffled mask generation, for 0% sparsity."""
    mask = masked.shuffled_mask(self._masked_model, self._rng, 0.0)

    with self.subTest(name='shuffled_empty_mask'):
      self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='shuffled_empty_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())

    with self.subTest(name='shuffled_empty_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])

    masked_output = self._masked_model(self._input, mask=mask)

    with self.subTest(name='shuffled_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())

    with self.subTest(name='shuffled_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output.shape)

  def test_shuffled_mask_sparsity_half_full(self):
    """Tests shuffled mask generation, for a half-full mask."""
    mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)
    param_len = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].size

    with self.subTest(name='shuffled_mask_values'):
      self.assertEqual(
          jnp.sum(mask['MaskedModule_0']['kernel']), param_len // 2)

  def test_shuffled_mask_sparsity_full_twolayer(self):
    """Tests shuffled mask generation for two layers, and 100% sparsity."""
    mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 1.0)

    with self.subTest(name='shuffled_full_mask_layer1'):
      self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='shuffled_full_mask_values_layer1'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())

    with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])

    with self.subTest(name='shuffled_full_mask_layer2'):
      self.assertIn('MaskedModule_1', mask)

    with self.subTest(name='shuffled_full_mask_values_layer2'):
      self.assertTrue((mask['MaskedModule_1']['kernel'] == 0).all())

    with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):
      self.assertIsNone(mask['MaskedModule_1']['bias'])

    masked_output = self._masked_model_twolayer(self._input, mask=mask)

    with self.subTest(name='shuffled_full_mask_dense_values'):
      self.assertTrue((masked_output == 0).all())

    with self.subTest(name='shuffled_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output_twolayer.shape)
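  # An illustrative sketch (not a test) of how a mask pytree mirrors
  # Model.params for the two-layer model used above:
  #
  #   mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 0.5)
  #   # mask == {'MaskedModule_0': {'kernel': <0/1 array>, 'bias': None},
  #   #          'MaskedModule_1': {'kernel': <0/1 array>, 'bias': None}}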
  def test_shuffled_mask_sparsity_empty_twolayer(self):
    """Tests shuffled mask generation for two layers, for 0% sparsity."""
    mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 0.0)

    with self.subTest(name='shuffled_empty_mask_layer1'):
      self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='shuffled_empty_mask_values_layer1'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())

    with self.subTest(name='shuffled_empty_mask_layer2'):
      self.assertIn('MaskedModule_1', mask)

    with self.subTest(name='shuffled_empty_mask_values_layer2'):
      self.assertTrue((mask['MaskedModule_1']['kernel'] == 1).all())

    masked_output = self._masked_model_twolayer(self._input, mask=mask)

    with self.subTest(name='shuffled_empty_dense_values'):
      self.assertTrue(
          jnp.isclose(masked_output, self._unmasked_output_twolayer).all())

    with self.subTest(name='shuffled_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output_twolayer.shape)

  def test_random_invalid_model(self):
    """Tests random mask with model containing no masked layers."""
    with self.assertRaisesRegex(
        ValueError, 'Model does not support masking, i.e. no layers are '
        'wrapped by a MaskedModule.'):
      masked.random_mask(self._unmasked_model, self._rng, 0.5)

  def test_random_invalid_sparsity(self):
    """Tests random mask with invalid sparsity."""
    with self.subTest(name='random_sparsity_too_small'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, -0.5, is not in range \[0, 1\]'):
        masked.random_mask(self._masked_model, self._rng, -0.5)

    with self.subTest(name='random_sparsity_too_large'):
      with self.assertRaisesRegex(
          ValueError, r'Given sparsity, 1.5, is not in range \[0, 1\]'):
        masked.random_mask(self._masked_model, self._rng, 1.5)

  def test_random_mask_sparsity_full(self):
    """Tests random mask generation, for 100% sparsity."""
    mask = masked.random_mask(self._masked_model, self._rng, 1.)

    with self.subTest(name='random_full_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())

    masked_output = self._masked_model(self._input, mask=mask)

    with self.subTest(name='random_full_mask_dense_values'):
      self.assertTrue((masked_output == 0).all())

    with self.subTest(name='random_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output.shape)

  def test_random_mask_sparsity_empty(self):
    """Tests random mask generation, for 0% sparsity."""
    mask = masked.random_mask(self._masked_model, self._rng, 0.)
with self.subTest(name='random_empty_mask_values'): self.assertEqual( jnp.sum(mask['MaskedModule_0']['kernel']), mask['MaskedModule_0']['kernel'].size) masked_output = self._masked_model(self._input, mask=mask) with self.subTest(name='random_empty_dense_values'): self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all()) with self.subTest(name='random_empty_mask_dense_shape'): self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape) def test_random_mask_sparsity_half_full(self): """Tests random mask generation, for a half-full mask.""" mask = masked.random_mask(self._masked_model, self._rng, 0.5) param_len = self._masked_model.params['MaskedModule_0']['unmasked'][ 'kernel'].size half_full = param_len / 2 with self.subTest(name='random_mask_values'): self.assertBetween( jnp.sum(mask['MaskedModule_0']['kernel']), 0.66 * half_full, 1.33 * half_full) def test_simple_mask_one_layer(self): """Tests generation of a simple mask.""" mask = { 'MaskedModule_0': { 'kernel': jnp.zeros(self._masked_model.params['MaskedModule_0'] ['unmasked']['kernel'].shape), 'bias': None, } } gen_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel']) result, _ = jax.tree_flatten( jax.tree_util.tree_map(lambda x, *xs: (x == xs[0]).all(), mask, gen_mask)) self.assertTrue(all(result)) def test_simple_mask_two_layer(self): """Tests generation of a simple mask.""" mask = { 'MaskedModule_0': { 'kernel': jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0'] ['unmasked']['kernel'].shape), 'bias': None, }, 'MaskedModule_1': { 'kernel': jnp.zeros(self._masked_model_twolayer.params['MaskedModule_1'] ['unmasked']['kernel'].shape), 'bias': None, }, } gen_mask = masked.simple_mask(self._masked_model_twolayer, jnp.zeros, ['kernel']) result, _ = jax.tree_flatten( jax.tree_util.tree_map(lambda x, *xs: (x == xs[0]).all(), mask, gen_mask)) self.assertTrue(all(result)) def test_shuffled_mask_neuron_mask_sparsity_empty(self): """Tests shuffled neuron mask generation, for 0% sparsity.""" mask = masked.shuffled_neuron_mask(self._masked_model, self._rng, 0.0) with self.subTest(name='shuffled_neuron_empty_mask'): self.assertIn('MaskedModule_0', mask) with self.subTest(name='shuffled_neuron_empty_mask_values'): self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all()) with self.subTest(name='shuffled_neuron_empty_mask_not_masked_values'): self.assertIsNone(mask['MaskedModule_0']['bias']) masked_output = self._masked_model(self._input, mask=mask) with self.subTest(name='shuffled_neuron_empty_dense_values'): self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all()) with self.subTest(name='shuffled_neuron_empty_mask_dense_shape'): self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape) def test_shuffled_mask_neuron_mask_sparsity_half_full(self): """Tests shuffled mask generation, for a half-full mask.""" mask = masked.shuffled_neuron_mask(self._masked_model, self._rng, 0.5) param_len = len( self._masked_model.params['MaskedModule_0']['unmasked']['kernel'][:, 0]) mask_sum = jnp.sum(mask['MaskedModule_0']['kernel'][:, 0]) with self.subTest(name='shuffled_mask_values'): # Check that single neuron has the correct sparsity. self.assertEqual(mask_sum, param_len // 2) with self.subTest(name='shuffled_mask_rows_different'): # Check that two rows are different. 
      self.assertFalse(
          jnp.isclose(mask['MaskedModule_0']['kernel'][:, 0],
                      mask['MaskedModule_0']['kernel'][:, 1]).all())

  def test_symmetric_mask_sparsity_empty(self):
    """Tests symmetric mask generation, for 0% sparsity."""
    mask = masked.symmetric_mask(self._masked_model, self._rng, 0.0)

    with self.subTest(name='symmetric_empty_mask'):
      self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='symmetric_empty_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())

    with self.subTest(name='symmetric_empty_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])

    masked_output = self._masked_model(self._input, mask=mask)

    with self.subTest(name='symmetric_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())

    with self.subTest(name='symmetric_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output.shape)

  def test_symmetric_mask_sparsity_half_full(self):
    """Tests symmetric mask generation, for a half-full mask."""
    mask = masked.symmetric_mask(self._masked_model, self._rng, 0.5)
    param_len = len(
        self._masked_model.params['MaskedModule_0']['unmasked']['kernel'][:,
                                                                          0])
    mask_sum = jnp.sum(mask['MaskedModule_0']['kernel'][:, 0])

    with self.subTest(name='symmetric_mask_values'):
      # Check that a single neuron has the correct sparsity.
      self.assertEqual(mask_sum, param_len // 2)

    with self.subTest(name='symmetric_mask_rows_different'):
      # Check that two rows are the same.
      self.assertTrue(
          jnp.isclose(mask['MaskedModule_0']['kernel'][:, 0],
                      mask['MaskedModule_0']['kernel'][:, 1]).all())

  def test_propagate_masks_ablated_neurons_one_layer(self):
    """Tests mask propagation on a single layer model."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jax.random.normal(
                    self._rng,
                    self._masked_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape,
                    dtype=jnp.float32),
            'bias': None,
        },
    }

    refined_mask = masked.propagate_masks(mask)

    # Since this is a single layer, should not affect mask at all.
    self.assertTrue((mask['MaskedModule_0']['kernel'] ==
                     refined_mask['MaskedModule_0']['kernel']).all())

  def test_propagate_masks_ablated_neurons_two_layers(self):
    """Tests mask propagation on a two-layer model."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0']
                          ['unmasked']['kernel'].shape),
            'bias': None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.ones(self._masked_model_twolayer.params['MaskedModule_1']
                         ['unmasked']['kernel'].shape),
            'bias': None,
        },
    }

    refined_mask = masked.propagate_masks(mask)

    with self.subTest(name='layer_1'):
      self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())

    # Since layer 1 is all zero, layer 2 is also effectively zero.
    with self.subTest(name='layer_2'):
      self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())

  def test_propagate_masks_ablated_neurons_two_layers_nonmasked(self):
    """Tests mask propagation where previous layer is not masked."""
    mask = {
        'Dense_0': {
            'kernel': None,
            'bias': None,
        },
        'MaskedModule_1': {
            'kernel':
                jax.random.normal(
                    self._rng,
                    self._masked_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape,
                    dtype=jnp.float32),
            'bias': None,
        },
    }

    refined_mask = masked.propagate_masks(mask)

    with self.subTest(name='layer_1'):
      self.assertIsNone(refined_mask['Dense_0']['kernel'])

    # Layer 1 is not masked here, so propagation should leave layer 2 as-is.
    with self.subTest(name='layer_2'):
      # Since this is a single masked layer, should not affect mask at all.
self.assertTrue((mask['MaskedModule_1']['kernel'] == refined_mask['MaskedModule_1']['kernel']).all()) def test_propagate_masks_ablated_neurons_one_conv_layer(self): """Tests mask propagation on a single layer model.""" mask = { 'MaskedModule_0': { 'kernel': jax.random.normal( self._rng, self._masked_conv_model.params['MaskedModule_0']['unmasked'] ['kernel'].shape, dtype=jnp.float32), 'bias': None, }, } refined_mask = masked.propagate_masks(mask) # Since this is a single layer, should not affect mask at all. self.assertTrue((mask['MaskedModule_0']['kernel'] == refined_mask['MaskedModule_0']['kernel']).all()) def test_propagate_masks_ablated_neurons_two_conv_layers(self): """Tests mask propagation on a two-layer convolutional model.""" mask = { 'MaskedModule_0': { 'kernel': jnp.zeros( self._masked_conv_model_twolayer.params['MaskedModule_0'] ['unmasked']['kernel'].shape), 'bias': None, }, 'MaskedModule_1': { 'kernel': jnp.ones( self._masked_conv_model_twolayer.params['MaskedModule_1'] ['unmasked']['kernel'].shape), 'bias': None, }, } refined_mask = masked.propagate_masks(mask) with self.subTest(name='layer_1'): self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all()) # Since layer 1 is all zero, layer 2 is also effectively zero. with self.subTest(name='layer_2'): self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all()) def test_propagate_masks_ablated_neurons_three_conv_fc_layers(self): """Tests mask propagation on a two-layer convolutional model with dense.""" mask = { 'MaskedModule_0': { 'kernel': jnp.zeros(self._masked_conv_fc_model_threelayer .params['MaskedModule_0']['unmasked']['kernel'].shape ), 'bias': None, }, 'MaskedModule_1': { 'kernel': jnp.ones(self._masked_conv_fc_model_threelayer .params['MaskedModule_1']['unmasked']['kernel'].shape), 'bias': None, }, 'MaskedModule_2': { 'kernel': jnp.ones(self._masked_conv_fc_model_threelayer .params['MaskedModule_2']['unmasked']['kernel'].shape), 'bias': None, }, } refined_mask = masked.propagate_masks(mask) with self.subTest(name='layer_1'): self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all()) # Since layer 1 is all zero, layer 2 is also effectively zero. with self.subTest(name='layer_2'): self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all()) # Since layer 2 is all zero, layer 3 is also effectively zero. with self.subTest(name='layer_3'): self.assertTrue((refined_mask['MaskedModule_2']['kernel'] == 0).all()) def test_propagate_masks_ablated_neurons_mixed_conv_dense_layers(self): """Tests mask propagation on a two-layer convolutional/dense model.""" mask = { 'MaskedModule_0': { 'kernel': jnp.zeros( self._masked_mixed_model_twolayer.params['MaskedModule_0'] ['unmasked']['kernel'].shape), 'bias': None, }, 'MaskedModule_1': { 'kernel': jnp.ones( self._masked_mixed_model_twolayer.params['MaskedModule_1'] ['unmasked']['kernel'].shape), 'bias': None, }, } with self.assertRaisesRegex( ValueError, 'propagate_masks requires knowledge of the spatial ' 'dimensions of the previous layer. Use a functionally equivalent ' 'conv. layer in place of a dense layer in a model with a mixed ' 'conv/dense setting.'): masked.propagate_masks(mask) def test_mask_layer_sparsity_zero_mask(self): """Tests mask calculation with a zeroed mask.""" zero_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel']) self.assertEqual( masked.mask_layer_sparsity(zero_mask['MaskedModule_0']), 0.) 
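  # A hand-worked sketch of the mask_layer_sparsity convention exercised by
  # the surrounding tests (only 'kernel' entries are counted; 0 == masked):
  #
  #   layer = {'kernel': jnp.array([1, 0, 0, 0]), 'bias': None}
  #   masked.mask_layer_sparsity(layer)  # == 1 - 1/4 == 0.75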
  def test_mask_layer_sparsity_half_mask(self):
    """Tests mask calculation with a half-filled mask."""
    half_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)

    self.assertAlmostEqual(
        masked.mask_layer_sparsity(half_mask['MaskedModule_0']), 0.5)

  def test_mask_layer_sparsity_ones_mask(self):
    """Tests mask calculation with a mask full of ones."""
    one_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])

    self.assertEqual(
        masked.mask_layer_sparsity(one_mask['MaskedModule_0']), 1.)

  def test_mask_sparsity_zero_mask(self):
    """Tests mask calculation with a zeroed mask."""
    zero_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])

    self.assertEqual(masked.mask_sparsity(zero_mask), 0.)

  def test_mask_sparsity_ones_mask(self):
    """Tests mask calculation with a mask full of ones."""
    one_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])

    self.assertEqual(masked.mask_sparsity(one_mask), 1.)

  def test_mask_sparsity_mixed_mask(self):
    """Tests mask calculation with different sparsities across masked layers."""
    mask = {
        'MaskedModule_0': {
            'kernel':
                jnp.zeros(
                    self._masked_conv_model_twolayer.params['MaskedModule_0']
                    ['unmasked']['kernel'].shape),
            'bias': None,
        },
        'MaskedModule_1': {
            'kernel':
                jnp.ones(
                    self._masked_conv_model_twolayer.params['MaskedModule_1']
                    ['unmasked']['kernel'].shape),
            'bias': None,
        },
    }

    mask_sparsity = masked.mask_sparsity(mask)
    true_sparsity = self._masked_conv_model_twolayer.params['MaskedModule_1'][
        'unmasked']['kernel'].size / (
            self._masked_conv_model_twolayer.params['MaskedModule_0']
            ['unmasked']['kernel'].size + self._masked_conv_model_twolayer
            .params['MaskedModule_1']['unmasked']['kernel'].size)

    self.assertAlmostEqual(mask_sparsity, 1.0 - true_sparsity)

  @parameterized.parameters(
      # Simple masked 1-layer model.
      (1,),
      # Simple masked 2-layer model.
      (2,),
      # Simple masked 10-layer model.
      (10,),
  )
  def test_generate_model_masks_depth_only(self, depth):
    mask = masked.generate_model_masks(depth)

    with self.subTest(name='test_model_mask_length'):
      self.assertLen(mask, depth)

    for i in range(depth):
      with self.subTest(name=f'test_model_mask_value_layer_{i}'):
        self.assertIsNone(mask[f'MaskedModule_{i}'])

  @parameterized.parameters(
      # Simple masked 1-layer model, no masked indices.
      (1, []),
      # Simple masked 2-layer model, second layer masked.
      (2, (1,)),
      # Simple masked 10-layer model, 4 layers masked.
      (10, (1, 2, 3, 9)),
  )
  def test_generate_model_masks_indices(self, depth, indices):
    mask = masked.generate_model_masks(depth, None, indices)

    with self.subTest(name='test_model_mask_length'):
      self.assertLen(mask, len(indices))

    for i in indices:
      with self.subTest(name=f'test_model_mask_value_layer_{i}'):
        self.assertIsNone(mask[f'MaskedModule_{i}'])

  @parameterized.parameters(
      # Existing 1-layer mask.
      (1, {'MaskedModule_0': np.ones(1)}, None),
      (2, {'MaskedModule_0': np.ones(1), 'MaskedModule_1': np.ones(1)}, None),
      # Existing 2-layer mask, only using one due to mask indices.
      (2, {'MaskedModule_0': np.ones(1), 'MaskedModule_1': np.ones(1),}, (1,)),
  )
  def test_generate_model_masks_existing_mask(self, depth, existing_mask,
                                              indices):
    mask = masked.generate_model_masks(depth, existing_mask, indices)

    # Need to differentiate from empty sequence by explicitly checking is None.
if indices is None: indices = range(depth) with self.subTest(name='test_model_mask_length'): self.assertLen(mask, len(indices)) for i in indices: with self.subTest(name=f'test_model_mask_value_layer_{i}'): self.assertIsNotNone(mask[f'MaskedModule_{i}']) # Ensure existing mask layers that aren't in indices aren't in output. for i in range(depth): if i not in indices: with self.subTest( name=f'test_model_mask_only_allowed_indices_layer_{i}'): self.assertNotIn(f'MaskedModule_{i}', mask) def test_generate_model_masks_invalid_depth_zero(self): with self.assertRaisesWithLiteralMatch(ValueError, 'Invalid model depth: 0'): masked.generate_model_masks(0) def test_generate_model_masks_invalid_index_toohigh(self): with self.assertRaisesWithLiteralMatch( ValueError, 'Invalid indices for given depth (2): (1, 2)'): masked.generate_model_masks(2, None, (1, 2)) def test_generate_model_masks_invalid_index_negative(self): with self.assertRaisesWithLiteralMatch( ValueError, 'Invalid indices for given depth (2): (-1, 2)'): masked.generate_model_masks(2, None, (-1, 2)) def test_shuffled_neuron_no_input_ablation_mask_invalid_model(self): """Tests shuffled mask with model containing no masked layers.""" with self.assertRaisesRegex( ValueError, 'Model does not support masking, i.e. no layers are ' 'wrapped by a MaskedModule.'): masked.shuffled_neuron_no_input_ablation_mask(self._unmasked_model, self._rng, 0.5) def test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity(self): """Tests shuffled mask with invalid sparsity.""" with self.subTest(name='sparsity_too_small'): with self.assertRaisesRegex( ValueError, r'Given sparsity, -0.5, is not in range \[0, 1\]'): masked.shuffled_neuron_no_input_ablation_mask(self._masked_model, self._rng, -0.5) with self.subTest(name='sparsity_too_large'): with self.assertRaisesRegex( ValueError, r'Given sparsity, 1.5, is not in range \[0, 1\]'): masked.shuffled_neuron_no_input_ablation_mask(self._masked_model, self._rng, 1.5) def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self): """Tests shuffled mask generation, for 100% sparsity.""" mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model, self._rng, 1.0) with self.subTest(name='shuffled_full_mask'): self.assertIn('MaskedModule_0', mask) with self.subTest(name='shuffled_full_mask_values'): self.assertEqual(jnp.count_nonzero(mask['MaskedModule_0']['kernel']), jnp.prod(jnp.array(self._input_dimensions))) with self.subTest(name='shuffled_full_no_input_ablation'): # Check no row (neurons are columns) is completely ablated. 
      self.assertTrue((jnp.count_nonzero(
          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

    with self.subTest(name='shuffled_full_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])

    masked_output = self._masked_model(self._input, mask=mask)

    with self.subTest(name='shuffled_full_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output.shape)

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty(self):
    """Tests shuffled mask generation, for 0% sparsity."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                         self._rng, 0.0)

    with self.subTest(name='shuffled_empty_mask'):
      self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='shuffled_empty_mask_values'):
      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())

    with self.subTest(name='shuffled_empty_mask_not_masked_values'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])

    masked_output = self._masked_model(self._input, mask=mask)

    with self.subTest(name='shuffled_empty_dense_values'):
      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())

    with self.subTest(name='shuffled_empty_mask_dense_shape'):
      self.assertSequenceEqual(masked_output.shape,
                               self._unmasked_output.shape)

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self):
    """Tests shuffled mask generation, for a half-full mask."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                         self._rng, 0.5)
    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].shape

    with self.subTest(name='shuffled_mask_values'):
      self.assertEqual(
          jnp.sum(mask['MaskedModule_0']['kernel']),
          param_shape[0] // 2 * param_shape[1])

    with self.subTest(name='shuffled_half_no_input_ablation'):
      # Check no row (neurons are columns) is completely ablated.
      self.assertTrue((jnp.count_nonzero(
          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_quarter_full(self):
    """Tests shuffled mask generation, for a quarter-sparse mask."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,
                                                         self._rng, 0.25)
    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][
        'kernel'].shape

    with self.subTest(name='shuffled_mask_values'):
      self.assertEqual(
          jnp.sum(mask['MaskedModule_0']['kernel']),
          0.75 * param_shape[0] * param_shape[1])

    with self.subTest(name='shuffled_quarter_no_input_ablation'):
      # Check no row (neurons are columns) is completely ablated.
      self.assertTrue((jnp.count_nonzero(
          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())

  def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer(self):
    """Tests shuffled mask generation for two layers, and 100% sparsity."""
    mask = masked.shuffled_neuron_no_input_ablation_mask(
        self._masked_model_twolayer, self._rng, 1.0)

    with self.subTest(name='shuffled_full_mask_layer1'):
      self.assertIn('MaskedModule_0', mask)

    with self.subTest(name='shuffled_full_mask_values_layer1'):
      self.assertEqual(jnp.count_nonzero(mask['MaskedModule_0']['kernel']),
                       jnp.prod(jnp.array(self._input_dimensions)))

    with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):
      self.assertIsNone(mask['MaskedModule_0']['bias'])

    with self.subTest(name='shuffled_full_no_input_ablation_layer1'):
      # Check no row (neurons are columns) is completely ablated.
self.assertTrue((jnp.count_nonzero( mask['MaskedModule_0']['kernel'], axis=0) != 0).all()) with self.subTest(name='shuffled_full_mask_layer2'): self.assertIn('MaskedModule_1', mask) with self.subTest(name='shuffled_full_mask_values_layer2'): self.assertEqual(jnp.count_nonzero(mask['MaskedModule_1']['kernel']), jnp.prod(MaskedTwoLayerDense.NUM_FEATURES[0])) with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'): self.assertIsNone(mask['MaskedModule_1']['bias']) with self.subTest(name='shuffled_full_no_input_ablation_layer2'): # Note: check no *inputs* are ablated, and inputs < num_neurons. self.assertEqual( jnp.sum(jnp.count_nonzero(mask['MaskedModule_1']['kernel'], axis=0)), MaskedTwoLayerDense.NUM_FEATURES[0]) masked_output = self._masked_model_twolayer(self._input, mask=mask) with self.subTest(name='shuffled_full_mask_dense_shape'): self.assertSequenceEqual(masked_output.shape, self._unmasked_output_twolayer.shape) def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer(self): """Tests shuffled mask generation for two layers, for 0% sparsity.""" mask = masked.shuffled_neuron_no_input_ablation_mask( self._masked_model_twolayer, self._rng, 0.0) with self.subTest(name='shuffled_empty_mask_layer1'): self.assertIn('MaskedModule_0', mask) with self.subTest(name='shuffled_empty_mask_values_layer1'): self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all()) with self.subTest(name='shuffled_empty_mask_layer2'): self.assertIn('MaskedModule_1', mask) with self.subTest(name='shuffled_empty_mask_values_layer2'): self.assertTrue((mask['MaskedModule_1']['kernel'] == 1).all()) masked_output = self._masked_model_twolayer(self._input, mask=mask) with self.subTest(name='shuffled_empty_dense_values'): self.assertTrue( jnp.isclose(masked_output, self._unmasked_output_twolayer).all()) with self.subTest(name='shuffled_empty_mask_dense_shape'): self.assertSequenceEqual(masked_output.shape, self._unmasked_output_twolayer.shape) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/pruning/pruning.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Functions for pruning FLAX masked models.""" from collections import abc from typing import Any, Callable, Mapping, Optional, Union import flax import jax.numpy as jnp from rigl.experimental.jax.pruning import masked def weight_magnitude(weights): """Creates weight magnitude-based saliencies, given a weight matrix.""" return jnp.absolute(weights) def prune( model, pruning_rate, saliency_fn = weight_magnitude, mask = None, compare_fn = jnp.greater): """Returns a mask for a model where the params in each layer are pruned using a saliency function. Args: model: The model to create a pruning mask for. pruning_rate: The fraction of lowest magnitude saliency weights that are pruned. 
If a float, the same rate is used for all layers, otherwise if it is a mapping, it must contain a rate for all masked layers in the model. saliency_fn: A function that returns a float number used to rank the importance of individual weights in the layer. mask: If the model has an existing mask, the mask will be applied before pruning the model. compare_fn: A pairwise operator to compare saliency with threshold, and return True if the saliency indicates the value should not be masked. Returns: A pruned mask for the given model. """ if not mask: mask = masked.simple_mask(model, jnp.ones, masked.WEIGHT_PARAM_NAMES) if not isinstance(pruning_rate, abc.Mapping): pruning_rate_dict = {} for param_name, _ in masked.iterate_mask(mask): # Get the layer name from the parameter's full name/path. layer_name = param_name.split('/')[-2] pruning_rate_dict[layer_name] = pruning_rate pruning_rate = pruning_rate_dict for param_path, param_mask in masked.iterate_mask(mask): split_param_path = param_path.split('/') layer_name = split_param_path[-2] param_name = split_param_path[-1] # If we don't have a pruning rate for the given layer, don't mask it. if layer_name in pruning_rate and mask[layer_name][param_name] is not None: param_value = model.params[layer_name][ masked.MaskedModule.UNMASKED][param_name] # Here any existing mask is first applied to weight matrix. # Note: need to check explicitly is not None for np array. if param_mask is not None: saliencies = saliency_fn(param_mask * param_value) else: saliencies = saliency_fn(param_value) # TODO: Use partition here (partial sort) instead of sort, # since it's O(N), not O(N log N), however JAX doesn't support it. sorted_param = jnp.sort(jnp.abs(saliencies.flatten())) # Figure out the weight magnitude threshold. threshold_index = jnp.round(pruning_rate[layer_name] * sorted_param.size).astype(jnp.int32) threshold = sorted_param[threshold_index] mask[layer_name][param_name] = jnp.array( compare_fn(saliencies, threshold), dtype=jnp.int32) return mask ================================================ FILE: rigl/experimental/jax/pruning/pruning_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for weight_symmetry.pruning.pruning.""" from typing import Mapping, Optional, Sequence from absl.testing import absltest import flax import jax import jax.numpy as jnp from rigl.experimental.jax.pruning import masked from rigl.experimental.jax.pruning import pruning class MaskedDense(flax.deprecated.nn.Module): """Single-layer Dense Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_0'] if mask else None) class MaskedTwoLayerDense(flax.deprecated.nn.Module): """Two-layer Dense Masked Network.""" NUM_FEATURES: Sequence[int] = (32, 64) def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[0], wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_0'] if mask else None) return masked.MaskedModule( inputs, features=self.NUM_FEATURES[1], wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_1'] if mask else None) class MaskedConv(flax.deprecated.nn.Module): """Single-layer Conv Masked Network.""" NUM_FEATURES: int = 32 def apply(self, inputs, mask = None): return masked.MaskedModule( inputs, features=self.NUM_FEATURES, kernel_size=(3, 3), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_0'] if mask is not None else None) class MaskedTwoLayerConv(flax.deprecated.nn.Module): """Two-layer Conv Masked Network.""" NUM_FEATURES: Sequence[int] = (16, 32) def apply(self, inputs, mask = None): inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[0], kernel_size=(5, 5), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_0'] if mask is not None else None) return masked.MaskedModule( inputs, features=self.NUM_FEATURES[1], kernel_size=(3, 3), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_1'] if mask is not None else None) class PruningTest(absltest.TestCase): """Tests the flax layer pruning module.""" def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._batch_size = 2 self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32) self._input = jnp.ones(*self._input_shape) _, initial_params = MaskedDense.init_by_shape(self._rng, (self._input_shape,)) self._masked_model = flax.deprecated.nn.Model(MaskedDense, initial_params) _, initial_params = MaskedTwoLayerDense.init_by_shape( self._rng, (self._input_shape,)) self._masked_model_twolayer = flax.deprecated.nn.Model( MaskedTwoLayerDense, initial_params) _, initial_params = MaskedConv.init_by_shape(self._rng, (self._input_shape,)) self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv, initial_params) _, initial_params = MaskedTwoLayerConv.init_by_shape( self._rng, (self._input_shape,)) self._masked_conv_model_twolayer = flax.deprecated.nn.Model( MaskedTwoLayerConv, initial_params) def test_prune_single_layer_dense_no_mask(self): """Tests pruning of single dense layer without an existing mask.""" pruned_mask = pruning.prune(self._masked_model, 0.5) mask_sparsity = masked.mask_sparsity(pruned_mask) with self.subTest(name='test_mask_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel']) with self.subTest(name='test_mask_sparsity'): self.assertAlmostEqual(mask_sparsity, 0.5, places=3) def test_prune_single_layer_local_pruning(self): """Test pruning of model with a single layer, and local pruning schedule.""" pruned_mask = 
pruning.prune(self._masked_model, { 'MaskedModule_0': 0.5, }) mask_sparsity = masked.mask_sparsity(pruned_mask) with self.subTest(name='test_mask_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel']) with self.subTest(name='test_mask_sparsity'): self.assertAlmostEqual(mask_sparsity, 0.5, places=3) def test_prune_single_layer_dense_with_mask(self): """Tests pruning of single dense layer with an existing mask.""" pruned_mask = pruning.prune( self._masked_model, 0.5, mask=masked.shuffled_mask(self._masked_model, self._rng, 0.95)) mask_sparsity = masked.mask_sparsity(pruned_mask) with self.subTest(name='test_mask_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel']) with self.subTest(name='test_mask_sparsity'): self.assertAlmostEqual(mask_sparsity, 0.95, places=3) def test_prune_two_layers_dense_no_mask(self): """Tests pruning of model with two dense layers without an existing mask.""" pruned_mask = pruning.prune(self._masked_model_twolayer, 0.5) mask_sparsity = masked.mask_sparsity(pruned_mask) with self.subTest(name='test_mask_layer1_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel']) with self.subTest(name='test_mask_layer2_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel']) with self.subTest(name='test_mask_sparsity'): self.assertAlmostEqual(mask_sparsity, 0.5, places=3) def test_prune_two_layer_local_pruning_rate(self): """Test pruning of model with two layers, and a local pruning schedule.""" pruned_mask = pruning.prune(self._masked_model_twolayer, { 'MaskedModule_1': 0.5, }) mask_layer_0_sparsity = masked.mask_sparsity(pruned_mask['MaskedModule_0']) mask_layer_1_sparsity = masked.mask_sparsity(pruned_mask['MaskedModule_1']) with self.subTest(name='test_mask_layer1_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel']) with self.subTest(name='test_mask_layer2_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel']) with self.subTest(name='test_mask_layer_0_sparsity'): self.assertEqual(mask_layer_0_sparsity, 0.) with self.subTest(name='test_mask_layer_1_sparsity'): self.assertAlmostEqual(mask_layer_1_sparsity, 0.5, places=3) def test_prune_one_layer_conv_no_mask(self): """Tests pruning of model with one conv. layer without an existing mask.""" pruned_mask = pruning.prune(self._masked_conv_model, 0.5) mask_sparsity = masked.mask_sparsity(pruned_mask) with self.subTest(name='test_mask_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel']) with self.subTest(name='test_mask_sparsity'): self.assertAlmostEqual(mask_sparsity, 0.5, places=1) def test_prune_one_layer_conv_with_mask(self): """Tests pruning of model with one conv. layer with an existing mask.""" pruned_mask = pruning.prune( self._masked_conv_model, 0.5, mask=masked.shuffled_mask(self._masked_model, self._rng, 0.95)) mask_sparsity = masked.mask_sparsity(pruned_mask) with self.subTest(name='test_mask_param_not_none'): self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel']) with self.subTest(name='test_mask_sparsity'): self.assertAlmostEqual(mask_sparsity, 0.95, places=3) def test_prune_two_layer_conv_no_mask(self): """Tests pruning of model with two conv. 
layers without an existing mask."""
    pruned_mask = pruning.prune(self._masked_conv_model_twolayer, 0.5)
    mask_sparsity = masked.mask_sparsity(pruned_mask)

    with self.subTest(name='test_mask_layer1_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])

    with self.subTest(name='test_mask_layer2_param_not_none'):
      self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel'])

    with self.subTest(name='test_mask_sparsity'):
      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)


if __name__ == '__main__':
  absltest.main()


================================================
FILE: rigl/experimental/jax/pruning/symmetry.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code for analyzing symmetries in NN."""

import functools
import math
import operator
from typing import Dict, Optional, Union

import jax.numpy as jnp
import numpy as np

from rigl.experimental.jax.pruning import masked
from rigl.experimental.jax.utils import utils


def count_permutations_mask_layer(mask_layer, next_mask_layer=None,
                                  parameter_key='kernel'):
  """Calculates the number of permutations for a layer, given binary masks.

  Args:
    mask_layer: The binary weight mask of a dense/conv layer, where the last
      dimension is the number of neurons/filters.
    next_mask_layer: The binary weight mask of the following dense/conv layer,
      or None if this is the last layer.
    parameter_key: The name of the parameter to count the permutations of in
      each layer.

  Returns:
    A dictionary with stats on the permutation structure of a mask, including
    the number of symmetric permutations of the mask, number of unique mask
    columns, count of the zeroed out (structurally pruned) neurons, and total
    number of neurons/filters.
  """
  # Have to check 'is None' since mask_layer[parameter_key] is jnp.array.
  if not mask_layer or parameter_key not in mask_layer or mask_layer[
      parameter_key] is None:
    return {
        'permutations': 1,
        'zeroed_neurons': 0,
        'total_neurons': 0,
        'unique_neurons': 0,
    }

  mask = mask_layer[parameter_key]
  num_neurons = mask.shape[-1]

  # Initialize with stats for an empty mask.
  mask_stats = {
      'permutations': 0,
      'zeroed_neurons': num_neurons,
      'total_neurons': num_neurons,
      'unique_neurons': 0,
  }

  # Re-shape masks as 1D, in case they are 2D (e.g. convolutional).
  connection_mask = mask.reshape(-1, num_neurons)

  # Only consider non-zero columns (in JAX neurons/filters are last index).
  non_zero_neurons = ~jnp.all(connection_mask == 0, axis=0)

  # Count only zeroed neurons in the current layer.
  zeroed_count = num_neurons - jnp.count_nonzero(non_zero_neurons)

  # Special case where all neurons in current layer are ablated.
  if zeroed_count == num_neurons:
    return mask_stats

  # Have to check is None since next_mask_layer[parameter_key] is jnp.array.
  if next_mask_layer and parameter_key in next_mask_layer and next_mask_layer[
      parameter_key] is not None:
    next_mask = next_mask_layer[parameter_key]
    # Re-shape masks as 1D, in case they are 2D (e.g. convolutional).
next_connection_mask = next_mask.T.reshape(-1, num_neurons) # Update with neurons that are non-zero in outgoing connections too. non_zero_neurons &= ~jnp.all(next_connection_mask == 0, axis=0) # Remove columns corresponding to neurons that are ablated. next_connection_mask = next_connection_mask[:, non_zero_neurons] connection_mask = connection_mask[:, non_zero_neurons] # Combine the outgoing and incoming masks in one vector per-neuron. connection_mask = jnp.concatenate( (connection_mask, next_connection_mask), axis=0) else: connection_mask = connection_mask[:, non_zero_neurons] # Effectively no connections between these two layers. if not connection_mask.size: return mask_stats # Note: np.unique was not yet implemented in JAX numpy when this was written. _, unique_counts = np.unique(connection_mask, axis=-1, return_counts=True) # Convert from device array. mask_stats['zeroed_neurons'] = int(zeroed_count) mask_stats['permutations'] = functools.reduce( operator.mul, (math.factorial(t) for t in unique_counts)) mask_stats['unique_neurons'] = len(unique_counts) return mask_stats def count_permutations_mask(mask): """Calculates the number of permutations for a given model mask. Args: mask: Model masks to check, similar to Model.params. Returns: A dictionary with stats on the permutation structure of a mask, including the number of symmetric permutations of the mask, the number of unique mask columns, the count of zeroed-out (structurally pruned) neurons, and the total number of neurons/filters. """ sum_keys = ('total_neurons', 'unique_neurons', 'zeroed_neurons') product_keys = ('permutations',) # Count permutation stats for each pairwise set of layers. # Note: I tried doing this with more_itertools.pairwise/itertools.chain, but # there is a type conflict in passing iterators of different types to # itertools.chain. counts = [ count_permutations_mask_layer(layer, next_layer) for layer, next_layer in utils.pairwise_longest(mask.values()) ] sum_stats = {} for key in sum_keys: sum_stats[key] = functools.reduce(operator.add, (z[key] for z in counts)) product_stats = {} for key in product_keys: product_stats[key] = functools.reduce(operator.mul, (z[key] for z in counts)) return {**sum_stats, **product_stats} def get_mask_stats(mask): """Calculates a dictionary of mask statistics. Args: mask: A model mask to calculate the statistics of. Returns: A dictionary, containing a set of mask statistics. """ mask_stats = count_permutations_mask(mask) mask_stats.update({ 'sparsity': masked.mask_sparsity(mask), 'permutation_num_digits': len(str(mask_stats['permutations'])), 'permutation_log10': math.log10(mask_stats['permutations'] + 1), }) return mask_stats ================================================ FILE: rigl/experimental/jax/pruning/symmetry_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
"""Tests for weight_symmetry.pruning.symmetry.""" import functools import math import operator from typing import Mapping, Optional, Sequence from absl.testing import absltest from absl.testing import parameterized import flax import jax import jax.numpy as jnp import numpy as np from rigl.experimental.jax.pruning import masked from rigl.experimental.jax.pruning import symmetry class MaskedDense(flax.deprecated.nn.Module): """Single-layer Dense Masked Network. Attributes: NUM_FEATURES: The number of neurons in the single dense layer. """ NUM_FEATURES: int = 16 def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) return masked.MaskedModule( inputs, features=self.NUM_FEATURES, wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_0'] if mask is not None else None) class MaskedConv(flax.deprecated.nn.Module): """Single-layer Conv Masked Network. Attributes: NUM_FEATURES: The number of filters in the single conv layer. """ NUM_FEATURES: int = 16 def apply(self, inputs, mask = None): return masked.MaskedModule( inputs, features=self.NUM_FEATURES, kernel_size=(3, 3), wrapped_module=flax.deprecated.nn.Conv, mask=mask['MaskedModule_0'] if mask is not None else None) class MaskedTwoLayerDense(flax.deprecated.nn.Module): """Two-layer Dense Masked Network. Attributes: NUM_FEATURES: The number of neurons in the dense layers. """ NUM_FEATURES: Sequence[int] = (16, 32) def apply(self, inputs, mask = None): inputs = inputs.reshape(inputs.shape[0], -1) inputs = masked.MaskedModule( inputs, features=self.NUM_FEATURES[0], wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_0'] if mask is not None else None) inputs = flax.deprecated.nn.relu(inputs) return masked.MaskedModule( inputs, features=self.NUM_FEATURES[1], wrapped_module=flax.deprecated.nn.Dense, mask=mask['MaskedModule_1'] if mask is not None else None) class SymmetryTest(parameterized.TestCase): """Tests symmetry analysis methods.""" def setUp(self): super().setUp() self._rng = jax.random.PRNGKey(42) self._batch_size = 2 self._input_shape = ((self._batch_size, 2, 2, 1), jnp.float32) self._flat_input_shape = ((self._batch_size, 2 * 2 * 1), jnp.float32) _, initial_params = MaskedDense.init_by_shape(self._rng, (self._flat_input_shape,)) self._masked_model = flax.deprecated.nn.Model(MaskedDense, initial_params) _, initial_params = MaskedConv.init_by_shape(self._rng, (self._input_shape,)) self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv, initial_params) _, initial_params = MaskedTwoLayerDense.init_by_shape( self._rng, (self._flat_input_shape,)) self._masked_two_layer_model = flax.deprecated.nn.Model( MaskedTwoLayerDense, initial_params) def test_count_permutations_layer_mask_full(self): """Tests count of weight permutations in a full mask.""" mask_layer = { 'kernel': jnp.ones(self._masked_model.params['MaskedModule_0']['unmasked'] ['kernel'].shape), } stats = symmetry.count_permutations_mask_layer(mask_layer) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 1) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], math.factorial(MaskedDense.NUM_FEATURES)) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES) def test_count_permutations_layer_mask_empty(self): """Tests count of weight permutations in an empty mask.""" mask_layer = { 
'kernel': jnp.zeros(self._masked_model.params['MaskedModule_0']['unmasked'] ['kernel'].shape), } stats = symmetry.count_permutations_mask_layer(mask_layer) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 0) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], 0) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES) def test_count_permutations_conv_layer_mask_full(self): """Tests count of weight permutations in a full mask for a conv. layer.""" mask_layer = { 'kernel': jnp.ones(self._masked_conv_model.params['MaskedModule_0'] ['unmasked']['kernel'].shape), } stats = symmetry.count_permutations_mask_layer(mask_layer) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 1) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], math.factorial(MaskedConv.NUM_FEATURES)) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES) def test_count_permutations_conv_layer_mask_empty(self): """Tests count of weight permutations in an empty mask for a conv. layer.""" mask_layer = { 'kernel': jnp.zeros(self._masked_conv_model.params['MaskedModule_0'] ['unmasked']['kernel'].shape), } stats = symmetry.count_permutations_mask_layer(mask_layer) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 0) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], 0) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], MaskedConv.NUM_FEATURES) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES) def test_count_permutations_layer_mask_known_perm(self): """Tests count of weight permutations in a mask with known # permutations.""" param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][ 'kernel'].shape # Create two unique random mask rows. row_type_one = jax.random.bernoulli( self._rng, p=0.3, shape=(param_shape[0],)).astype(jnp.int32) row_type_two = jax.random.bernoulli( self._rng, p=0.9, shape=(param_shape[0],)).astype(jnp.int32) # Create mask by repeating the two unique rows. 
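# Since neurons sharing an identical mask column are interchangeable, a mask built from repeat_one copies of one column and repeat_two copies of the other should have repeat_one! * repeat_two! permutations (asserted below).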
repeat_one = param_shape[-1] // 3 repeat_two = param_shape[-1] - repeat_one mask_layer = {'kernel': jnp.concatenate( (jnp.repeat(row_type_one[:, jnp.newaxis], repeat_one, axis=-1), jnp.repeat(row_type_two[:, jnp.newaxis], repeat_two, axis=-1)), axis=-1)} stats = symmetry.count_permutations_mask_layer(mask_layer) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 2) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], math.factorial(repeat_one) * math.factorial(repeat_two)) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], param_shape[-1]) def test_count_permutations_layer_mask_known_perm_zeros(self): """Tests count of weight permutations in a mask with zeroed neurons.""" param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][ 'kernel'].shape # Create two unique random mask rows. row_type_one = jax.random.bernoulli( self._rng, p=0.3, shape=(param_shape[0],)).astype(jnp.int32) row_type_two = jnp.zeros(shape=(param_shape[0],), dtype=jnp.int32) # Create mask by repeating the two unique rows. repeat_one = param_shape[-1] // 3 repeat_two = param_shape[-1] - repeat_one mask_layer = {'kernel': jnp.concatenate( (jnp.repeat(row_type_one[:, jnp.newaxis], repeat_one, axis=-1), jnp.repeat(row_type_two[:, jnp.newaxis], repeat_two, axis=-1)), axis=-1)} stats = symmetry.count_permutations_mask_layer(mask_layer) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 1) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], math.factorial(repeat_one)) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], repeat_two) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], param_shape[-1]) def test_count_permutations_shuffled_full_mask(self): """Tests count of weight permutations for a shuffled mask with sparsity 1 (all weights masked).""" mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=1) stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 0) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], 0) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES) def test_count_permutations_shuffled_empty_mask(self): """Tests count of weight permutations for a shuffled mask with sparsity 0 (no weights masked).""" mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=0) stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 1) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], math.factorial(MaskedDense.NUM_FEATURES)) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES) def test_count_permutations_mask_layer_twolayer_known_symmetric(self): """Tests count of permutations in a known mask with 2 permutations.""" mask =
{ 'MaskedModule_0': { 'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T, }, 'MaskedModule_1': { 'kernel': jnp.array(((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T, }, } stats = symmetry.count_permutations_mask_layer(mask['MaskedModule_0'], mask['MaskedModule_1']) with self.subTest(name='count_permutations_unique'): self.assertEqual(stats['unique_neurons'], 2) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], 2) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], mask['MaskedModule_0']['kernel'].shape[-1]) # Note: Can't pass jnp.array here since global, InitGoogle() not called yet. @parameterized.parameters( # Tests mask with 1 permutation only if both layers are considered. ({ 'MaskedModule_0': { 'kernel': np.array(((1, 0), (1, 0), (0, 1))).T, }, 'MaskedModule_1': { 'kernel': np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T, }, }, 3, 1, 0, 3), # Tests mask zero count with an ablated neuron in first layer. ({ 'MaskedModule_0': { 'kernel': np.array(((1, 0), (1, 0), (0, 0))).T, }, 'MaskedModule_1': { 'kernel': np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T, }, }, 2, 1, 1, 3), # Tests mask zero count with first layer completely ablated. ({ 'MaskedModule_0': { 'kernel': np.array(((0, 0), (0, 0), (0, 0))).T, }, 'MaskedModule_1': { 'kernel': np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T, }, }, 0, 0, 3, 3), # Tests mask zero count with second layer completely ablated. ({ 'MaskedModule_0': { 'kernel': np.array(((1, 0), (1, 0), (0, 1))).T, }, 'MaskedModule_1': { 'kernel': np.array(((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))).T, }, }, 0, 0, 3, 3), # Tests layer 1 permutation matrix mask, having only 1 permutation. ({ 'MaskedModule_0': { 'kernel': np.array(((1, 0, 0), (0, 1, 0), (0, 0, 1))).T, }, 'MaskedModule_1': { 'kernel': np.array(((1, 1, 1), (0, 1, 1), (1, 1, 1), (1, 1, 1))).T, }, }, 3, 1, 0, 3), ) def test_count_permutations_mask_layer_twolayer(self, mask, unique, permutations, zeroed, total): """Tests mask permutations when both layers are considered.""" stats = symmetry.count_permutations_mask_layer(mask['MaskedModule_0'], mask['MaskedModule_1']) with self.subTest(name='count_permutations_unique'): self.assertEqual(stats['unique_neurons'], unique) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], permutations) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], zeroed) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], total) def test_count_permutations_mask_full(self): """Tests count of weight permutations in a full mask.""" mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel']) stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 1) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], math.factorial(MaskedDense.NUM_FEATURES)) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES) def test_count_permutations_mask_bn_layer_full(self): """Tests count of permutations on a mask for a model with non-masked layers.""" mask =
masked.simple_mask(self._masked_model, jnp.ones, ['kernel']) stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 1) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], math.factorial(MaskedDense.NUM_FEATURES)) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES) def test_count_permutations_mask_empty(self): """Tests count of weight permutations in an empty mask.""" mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel']) stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 0) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], 0) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES) def test_count_permutations_mask_twolayer_full(self): """Tests count of weight permutations in a full mask for 2 layers.""" mask = masked.simple_mask(self._masked_two_layer_model, jnp.ones, ['kernel']) stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 2) with self.subTest(name='count_permutations_permutations'): self.assertEqual( stats['permutations'], functools.reduce( operator.mul, [math.factorial(x) for x in MaskedTwoLayerDense.NUM_FEATURES])) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 0) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], sum(MaskedTwoLayerDense.NUM_FEATURES)) def test_count_permutations_mask_twolayers_empty(self): """Tests count of weight permutations in an empty mask for 2 layers.""" mask = masked.simple_mask(self._masked_two_layer_model, jnp.zeros, ['kernel']) stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_mask_unique'): self.assertEqual(stats['unique_neurons'], 0) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], 0) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], sum(MaskedTwoLayerDense.NUM_FEATURES)) with self.subTest(name='count_permutations_total'): self.assertEqual(stats['total_neurons'], sum(MaskedTwoLayerDense.NUM_FEATURES)) def test_count_permutations_mask_twolayer_known_symmetric(self): """Tests count of permutations in a known mask with 4 permutations.""" mask = { 'MaskedModule_0': { 'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T }, 'MaskedModule_1': { 'kernel': jnp.array(((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T } } stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_full_mask_unique'): self.assertEqual(stats['unique_neurons'], 4) with self.subTest(name='count_permutations_full_mask_permutations'): self.assertEqual(stats['permutations'], 4) with self.subTest(name='count_permutations_full_mask_zeroed'): self.assertEqual(stats['zeroed_neurons'], 1) with self.subTest(name='count_permutations_full_mask_total'): self.assertEqual( stats['total_neurons'],
mask['MaskedModule_0']['kernel'].shape[-1] + mask['MaskedModule_1']['kernel'].shape[-1]) def test_count_permutations_mask_twolayer_known_non_symmetric(self): """Tests mask with 1 permutation only if both layers are considered.""" mask = { 'MaskedModule_0': { 'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T }, 'MaskedModule_1': { 'kernel': jnp.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T } } stats = symmetry.count_permutations_mask(mask) with self.subTest(name='count_permutations_unique'): self.assertEqual(stats['unique_neurons'], 6) with self.subTest(name='count_permutations_permutations'): self.assertEqual(stats['permutations'], 1) with self.subTest(name='count_permutations_zeroed'): self.assertEqual(stats['zeroed_neurons'], 1) with self.subTest(name='count_permutations_total'): self.assertEqual( stats['total_neurons'], mask['MaskedModule_0']['kernel'].shape[-1] + mask['MaskedModule_1']['kernel'].shape[-1]) def test_get_mask_stats_keys_values(self): """Tests the returned dict has the required keys, and value types/ranges.""" mask = { 'MaskedModule_0': { 'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T }, 'MaskedModule_1': { 'kernel': jnp.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T } } mask_stats = symmetry.get_mask_stats(mask) with self.subTest(name='sparsity_exists'): self.assertIn('sparsity', mask_stats) with self.subTest(name='sparsity_value'): self.assertBetween(mask_stats['sparsity'], 0.0, 1.0) with self.subTest(name='permutation_num_digits_exists'): self.assertIn('permutation_num_digits', mask_stats) with self.subTest(name='permutation_num_digits_value'): self.assertGreaterEqual(mask_stats['permutation_num_digits'], 0.0) with self.subTest(name='permutation_log10_exists'): self.assertIn('permutation_log10', mask_stats) with self.subTest(name='permutation_log10_value'): self.assertGreaterEqual(mask_stats['permutation_log10'], 0.0) with self.subTest(name='unique_neurons_exists'): self.assertIn('unique_neurons', mask_stats) with self.subTest(name='unique_neurons_value'): self.assertEqual(mask_stats['unique_neurons'], 6) with self.subTest(name='permutations_exists'): self.assertIn('permutations', mask_stats) with self.subTest(name='permutations_value'): self.assertEqual(mask_stats['permutations'], 1) with self.subTest(name='zeroed_neurons_exists'): self.assertIn('zeroed_neurons', mask_stats) with self.subTest(name='zeroed_neurons_value'): self.assertEqual(mask_stats['zeroed_neurons'], 1) with self.subTest(name='total_neurons_exists'): self.assertIn('total_neurons', mask_stats) with self.subTest(name='total_neurons_value'): self.assertEqual(mask_stats['total_neurons'], mask['MaskedModule_0']['kernel'].shape[-1] + mask['MaskedModule_1']['kernel'].shape[-1]) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/random_mask.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Weight Symmetry: Train model with randomly sampled sparse mask.""" import ast from os import path from typing import List, Sequence import uuid from absl import app from absl import flags from absl import logging import flax from flax.metrics import tensorboard from flax.training import lr_schedule import jax import jax.numpy as jnp from rigl.experimental.jax.datasets import dataset_factory from rigl.experimental.jax.models import model_factory from rigl.experimental.jax.pruning import mask_factory from rigl.experimental.jax.pruning import masked from rigl.experimental.jax.pruning import symmetry from rigl.experimental.jax.training import training from rigl.experimental.jax.utils import utils experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id)) logging.info('Saving experimental results to %s', experiment_dir) host_count = jax.host_count() local_device_count = jax.local_device_count() logging.info('Device count: %d, host count: %d, local device count: %d', jax.device_count(), host_count, local_device_count) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(experiment_dir) dataset = dataset_factory.create_dataset( FLAGS.dataset, FLAGS.batch_size, FLAGS.batch_size_test, shuffle_buffer_size=FLAGS.shuffle_buffer_size) logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset) rng = jax.random.PRNGKey(FLAGS.random_seed) input_shape = (1,) + dataset.shape base_model, _ = model_factory.create_model( FLAGS.model, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes, masked_layer_indices=FLAGS.masked_layer_indices) logging.info('Generating random mask based on model') # Re-initialize the RNG to maintain same training pattern (as in prune code). mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed) mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng, FLAGS.mask_sparsity) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. except (OverflowError, ValueError): summary_writer.text(f'mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label) if FLAGS.dump_json: mask_stats['permutations'] = str(mask_stats['permutations']) utils.dump_dict_json( mask_stats, path.join(experiment_dir, 'mask_stats.json')) mask = masked.propagate_masks(mask) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Propagated mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'propagated_mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. 
except (OverflowError, ValueError): summary_writer.text(f'propagated_mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label) if FLAGS.dump_json: mask_stats['permutations'] = str(mask_stats['permutations']) utils.dump_dict_json( mask_stats, path.join(experiment_dir, 'propagated_mask_stats.json')) model, initial_state = model_factory.create_model( FLAGS.model, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes, masks=mask) if FLAGS.optimizer == 'Adam': optimizer = flax.optim.Adam( learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif FLAGS.optimizer == 'Momentum': optimizer = flax.optim.Momentum( learning_rate=FLAGS.lr, beta=FLAGS.momentum, weight_decay=FLAGS.weight_decay, nesterov=False) steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size if FLAGS.lr_schedule == 'constant': lr_fn = lr_schedule.create_constant_learning_rate_schedule( FLAGS.lr, steps_per_epoch) elif FLAGS.lr_schedule == 'stepped': lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps) lr_fn = lr_schedule.create_stepped_learning_rate_schedule( FLAGS.lr, steps_per_epoch, lr_schedule_steps) elif FLAGS.lr_schedule == 'cosine': lr_fn = lr_schedule.create_cosine_learning_rate_schedule( FLAGS.lr, steps_per_epoch, FLAGS.epochs) else: raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}') if jax.host_id() == 0: trainer = training.Trainer( optimizer, model, initial_state, dataset, rng, summary_writer=summary_writer, ) else: trainer = training.Trainer(optimizer, model, initial_state, dataset, rng) _, best_metrics = trainer.train( FLAGS.epochs, lr_fn=lr_fn, update_iter=FLAGS.update_iterations, update_epoch=FLAGS.update_epoch, ) logging.info('Best metrics: %s', str(best_metrics)) if jax.host_id() == 0: if FLAGS.dump_json: utils.dump_dict_json(best_metrics, path.join(experiment_dir, 'best_metrics.json')) for label, value in best_metrics.items(): summary_writer.scalar(f'best/{label}', value, FLAGS.epochs * steps_per_epoch) summary_writer.close() def main(argv: List[str]): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') run_training() if __name__ == '__main__': app.run(main) ================================================ FILE: rigl/experimental/jax/random_mask_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for weight_symmetry.random_mask.""" import glob from os import path import tempfile from absl.testing import absltest from absl.testing import flagsaver from rigl.experimental.jax import random_mask class RandomMaskTest(absltest.TestCase): def test_run_fc(self): """Test random mask driver with fully-connected model.""" experiment_dir = tempfile.mkdtemp() self._eval_flags = dict( epochs=1, experiment_dir=experiment_dir, model='MNIST_FC', ) with flagsaver.flagsaver(**self._eval_flags): random_mask.main([]) with self.subTest(name='tf_summary_file_exists'): outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_conv(self): """Test random mask driver with CNN model.""" experiment_dir = tempfile.mkdtemp() self._eval_flags = dict( epochs=1, experiment_dir=experiment_dir, model='MNIST_CNN', ) with flagsaver.flagsaver(**self._eval_flags): random_mask.main([]) with self.subTest(name='tf_summary_file_exists'): outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_random(self): """Test random mask driver with per-neuron sparsity.""" experiment_dir = tempfile.mkdtemp() self._eval_flags = dict( epochs=1, experiment_dir=experiment_dir, mask_type='random', ) with flagsaver.flagsaver(**self._eval_flags): random_mask.main([]) with self.subTest(name='tf_summary_file_exists'): outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_per_neuron(self): """Test random mask driver with per-neuron sparsity.""" experiment_dir = tempfile.mkdtemp() self._eval_flags = dict( epochs=1, experiment_dir=experiment_dir, mask_type='per_neuron', ) with flagsaver.flagsaver(**self._eval_flags): random_mask.main([]) with self.subTest(name='tf_summary_file_exists'): outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_symmetric(self): """Test random mask driver with per-neuron sparsity.""" experiment_dir = tempfile.mkdtemp() self._eval_flags = dict( epochs=1, experiment_dir=experiment_dir, mask_type='symmetric', ) with flagsaver.flagsaver(**self._eval_flags): random_mask.main([]) with self.subTest(name='tf_summary_file_exists'): outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/requirements.txt ================================================ absl-py>=0.10.0 flax>=0.2.2 jax>=0.2.0 jaxlib>=0.1.55 tensorboard>=2.3.0 tensorflow>=2.3.1 tensorflow_datasets>=3.2.1 ================================================ FILE: rigl/experimental/jax/run.sh ================================================ # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. #!/bin/bash set -e set -x virtualenv -p python3 . source ./bin/activate pip install -r weight_symmetry/requirements.txt TEST_NAMES='training.training_test train_test fixed_param_test shuffled_mask_test models.model_factory_test models.cifar10_cnn_test models.mnist_cnn_test models.mnist_fc_test utils.utils_test prune_test random_mask_test pruning.mask_factory_test pruning.init_test pruning.symmetry_test pruning.pruning_test pruning.masked_test datasets.dataset_factory_test datasets.dataset_base_test datasets.cifar10_test datasets.mnist_test' IFS=$'\n' readarray -t tests <<<$TEST_NAMES for test in ${tests[@]}; do python3 -m "weight_symmetry.${test}" done ================================================ FILE: rigl/experimental/jax/shuffled_mask.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Weight Symmetry: Train model with randomly shuffled sparse mask.""" # TODO: Refactor drivers to separate logic from flags/IO. import ast from os import path from typing import List, Sequence import uuid from absl import app from absl import flags from absl import logging import flax from flax.metrics import tensorboard from flax.training import lr_schedule import jax import jax.numpy as jnp from rigl.experimental.jax.datasets import dataset_factory from rigl.experimental.jax.models import model_factory from rigl.experimental.jax.pruning import mask_factory from rigl.experimental.jax.pruning import masked from rigl.experimental.jax.pruning import symmetry from rigl.experimental.jax.training import training from rigl.experimental.jax.utils import utils experiment_dir = '{}/{}/'.format(FLAGS.experiment_dir, work_unit_id) logging.info('Saving experimental results to %s', experiment_dir) host_count = jax.host_count() local_device_count = jax.local_device_count() logging.info('Device count: %d, host count: %d, local device count: %d', jax.device_count(), host_count, local_device_count) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(experiment_dir) dataset = dataset_factory.create_dataset( FLAGS.dataset, FLAGS.batch_size, FLAGS.batch_size_test, shuffle_buffer_size=FLAGS.shuffle_buffer_size) logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset) rng = jax.random.PRNGKey(FLAGS.random_seed) input_shape = (1,) + dataset.shape base_model, _ = model_factory.create_model( FLAGS.model, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes) logging.info('Generating random mask based on model') # Re-initialize the RNG to maintain same training pattern (as in prune code). 
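# Deriving the mask from its own seed (FLAGS.mask_randomseed) keeps the training RNG stream, seeded by FLAGS.random_seed, bit-identical across runs that differ only in the mask.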
mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed) mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng, FLAGS.mask_sparsity) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. except (OverflowError, ValueError): summary_writer.text(f'mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label) if FLAGS.dump_json: mask_stats['permutations'] = str(mask_stats['permutations']) utils.dump_dict_json( mask_stats, path.join(experiment_dir, 'mask_stats.json')) mask = masked.propagate_masks(mask) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Propagated mask stats: %s', str(mask_stats)) for label, value in mask_stats.items(): try: summary_writer.scalar(f'propagated_mask/{label}', value, 0) # This is needed because permutations (long int) can't be cast to float32. except (OverflowError, ValueError): summary_writer.text(f'propagated_mask/{label}', str(value), 0) logging.error('Could not write mask/%s to tensorflow summary as float32' ', writing as string instead.', label) if FLAGS.dump_json: mask_stats['permutations'] = str(mask_stats['permutations']) utils.dump_dict_json( mask_stats, path.join(experiment_dir, 'propagated_mask_stats.json')) model, initial_state = model_factory.create_model( FLAGS.model, rng, ((input_shape, jnp.float32),), num_classes=dataset.num_classes, masks=mask) if FLAGS.optimizer == 'Adam': optimizer = flax.optim.Adam( learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif FLAGS.optimizer == 'Momentum': optimizer = flax.optim.Momentum( learning_rate=FLAGS.lr, beta=FLAGS.momentum, weight_decay=FLAGS.weight_decay, nesterov=False) steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size if FLAGS.lr_schedule == 'constant': lr_fn = lr_schedule.create_constant_learning_rate_schedule( FLAGS.lr, steps_per_epoch) elif FLAGS.lr_schedule == 'stepped': lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps) lr_fn = lr_schedule.create_stepped_learning_rate_schedule( FLAGS.lr, steps_per_epoch, lr_schedule_steps) elif FLAGS.lr_schedule == 'cosine': lr_fn = lr_schedule.create_cosine_learning_rate_schedule( FLAGS.lr, steps_per_epoch, FLAGS.epochs) else: raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule)) if jax.host_id() == 0: trainer = training.Trainer( optimizer, model, initial_state, dataset, rng, summary_writer=summary_writer, ) else: trainer = training.Trainer(optimizer, model, initial_state, dataset, rng) _, best_metrics = trainer.train( FLAGS.epochs, lr_fn=lr_fn, update_iter=FLAGS.update_iterations, update_epoch=FLAGS.update_epoch, ) logging.info('Best metrics: %s', str(best_metrics)) if jax.host_id() == 0: if FLAGS.dump_json: utils.dump_dict_json(best_metrics, path.join(experiment_dir, 'best_metrics.json')) for label, value in best_metrics.items(): summary_writer.scalar('best/{}'.format(label), value, FLAGS.epochs * steps_per_epoch) summary_writer.close() def main(argv: List[str]): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') run_training() if __name__ == '__main__': app.run(main) ================================================ FILE: rigl/experimental/jax/shuffled_mask_test.py ================================================ # coding=utf-8 # Copyright 2022 
RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.shuffled_mask.""" import glob from os import path import tempfile from absl.testing import absltest from absl.testing import flagsaver from rigl.experimental.jax import shuffled_mask class ShuffledMaskTest(absltest.TestCase): def test_run_fc(self): """Tests if the driver for shuffled training runs correctly with FC NN.""" experiment_dir = tempfile.mkdtemp() eval_flags = dict( epochs=1, experiment_dir=experiment_dir, model='MNIST_FC', ) with flagsaver.flagsaver(**eval_flags): shuffled_mask.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_conv(self): """Tests if the driver for shuffled training runs correctly with CNN.""" experiment_dir = tempfile.mkdtemp() eval_flags = dict( epochs=1, experiment_dir=experiment_dir, model='MNIST_CNN', ) with flagsaver.flagsaver(**eval_flags): shuffled_mask.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_random(self): """Tests the shuffled mask driver with the 'random' mask type.""" experiment_dir = tempfile.mkdtemp() eval_flags = dict( epochs=1, experiment_dir=experiment_dir, mask_type='random', ) with flagsaver.flagsaver(**eval_flags): shuffled_mask.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_per_neuron(self): """Tests the shuffled mask driver with the 'per_neuron' mask type.""" experiment_dir = tempfile.mkdtemp() eval_flags = dict( epochs=1, experiment_dir=experiment_dir, mask_type='per_neuron', ) with flagsaver.flagsaver(**eval_flags): shuffled_mask.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) def test_run_symmetric(self): """Tests the shuffled mask driver with the 'symmetric' mask type.""" experiment_dir = tempfile.mkdtemp() eval_flags = dict( epochs=1, experiment_dir=experiment_dir, mask_type='symmetric', ) with flagsaver.flagsaver(**eval_flags): shuffled_mask.main([]) outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/train.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Weight Symmetry: Train Model. Trains a model from scratch, saving the relevant early weight snapshots. """ import ast from os import path from typing import List, Sequence import uuid from absl import app from absl import flags from absl import logging import flax from flax.metrics import tensorboard from flax.training import lr_schedule import jax import jax.numpy as np from rigl.experimental.jax.datasets import dataset_factory from rigl.experimental.jax.models import model_factory from rigl.experimental.jax.training import training FLAGS = flags.FLAGS MODEL_LIST: Sequence[str] = tuple(model_factory.MODELS.keys()) DATASET_LIST: Sequence[str] = tuple(dataset_factory.DATASETS.keys()) flags.DEFINE_enum('model', MODEL_LIST[0], MODEL_LIST, 'Model to train.') flags.DEFINE_enum('dataset', DATASET_LIST[0], DATASET_LIST, 'Dataset to train on.') flags.DEFINE_enum('optimizer', 'Adam', ['Momentum', 'Adam'], 'Optimizer to use.') flags.DEFINE_float('learning_rate', 0.01, 'Learning rate.', short_name='lr') flags.DEFINE_float('weight_decay', 1e-5, 'Weight decay penalty.', short_name='wd') flags.DEFINE_float('momentum', 0.9, 'Momentum weighting.') flags.DEFINE_string( 'lr_schedule', default='stepped', help=('Learning rate schedule type; constant, stepped or cosine.')) flags.DEFINE_string( 'lr_schedule_steps', default='[[50, 0.01], [70, 0.001], [90, 0.0001]]', help=('Learning rate schedule steps as a Python list; ' '[[step1_epoch, step1_lr_scale], ' '[step2_epoch, step2_lr_scale], ...]')) flags.DEFINE_integer( 'batch_size', 128, 'Training minibatch size.', lower_bound=1) flags.DEFINE_integer( 'batch_size_test', 128, 'Test minibatch size.', lower_bound=1) flags.DEFINE_integer( 'epochs', 100, 'Number of epochs to train over.', lower_bound=1) flags.DEFINE_integer('random_seed', 42, 'Random seed.') flags.DEFINE_integer('shuffle_buffer_size', 1024, 'Dataset shuffle buffer size.') flags.DEFINE_string( 'experiment_dir', '/tmp/experiments', 'Path to store experiment output in, i.e. 
models, snapshots.') flags.DEFINE_integer( 'update_iterations', 1000, 'Iteration interval at which to log per-batch training metrics.', lower_bound=1) flags.DEFINE_integer( 'update_epoch', 10, 'Epoch interval at which to evaluate full training/test metrics.', lower_bound=1) def run_training(): """Trains a model.""" print('Logging to {}'.format(FLAGS.log_dir)) work_unit_id = uuid.uuid4() experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id)) logging.info('Saving experimental results to %s', experiment_dir) host_count = jax.host_count() local_device_count = jax.local_device_count() logging.info('Device count: %d, host count: %d, local device count: %d', jax.device_count(), host_count, local_device_count) if jax.host_id() == 0: summary_writer = tensorboard.SummaryWriter(experiment_dir) dataset = dataset_factory.create_dataset( FLAGS.dataset, FLAGS.batch_size, FLAGS.batch_size_test, shuffle_buffer_size=FLAGS.shuffle_buffer_size) logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset) rng = jax.random.PRNGKey(FLAGS.random_seed) input_shape = (1,) + dataset.shape model, initial_state = model_factory.create_model( FLAGS.model, rng, ((input_shape, np.float32),), num_classes=dataset.num_classes) if FLAGS.optimizer == 'Adam': optimizer = flax.optim.Adam( learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay) elif FLAGS.optimizer == 'Momentum': optimizer = flax.optim.Momentum( learning_rate=FLAGS.lr, beta=FLAGS.momentum, weight_decay=FLAGS.weight_decay, nesterov=False) steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size if FLAGS.lr_schedule == 'constant': lr_fn = lr_schedule.create_constant_learning_rate_schedule( FLAGS.lr, steps_per_epoch) elif FLAGS.lr_schedule == 'stepped': lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps) lr_fn = lr_schedule.create_stepped_learning_rate_schedule( FLAGS.lr, steps_per_epoch, lr_schedule_steps) elif FLAGS.lr_schedule == 'cosine': lr_fn = lr_schedule.create_cosine_learning_rate_schedule( FLAGS.lr, steps_per_epoch, FLAGS.epochs) else: raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule)) if jax.host_id() == 0: trainer = training.Trainer( optimizer, model, initial_state, dataset, rng, summary_writer=summary_writer, ) else: trainer = training.Trainer(optimizer, model, initial_state, dataset, rng) _, best_metrics = trainer.train( FLAGS.epochs, lr_fn=lr_fn, update_iter=FLAGS.update_iterations, update_epoch=FLAGS.update_epoch) logging.info('Best metrics: %s', str(best_metrics)) if jax.host_id() == 0: for label, value in best_metrics.items(): summary_writer.scalar('best/{}'.format(label), value, FLAGS.epochs * steps_per_epoch) summary_writer.close() def main(argv): if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') run_training() if __name__ == '__main__': app.run(main) ================================================ FILE: rigl/experimental/jax/train_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
"""Tests for weight_symmetry.train.""" import glob from os import path import tempfile from absl.testing import absltest from absl.testing import flagsaver from rigl.experimental.jax import train class TrainTest(absltest.TestCase): def test_train_driver_run(self): """Tests that the training driver runs, and outputs a TF summary.""" experiment_dir = tempfile.mkdtemp() eval_flags = dict( epochs=1, experiment_dir=experiment_dir, ) with flagsaver.flagsaver(**eval_flags): train.main([]) with self.subTest(name='tf_summary_file_exists'): outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*') files = glob.glob(outfile) self.assertTrue(len(files) == 1 and path.exists(files[0])) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/training/__init__.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: rigl/experimental/jax/training/training.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Common training code. This module contains utility functions for training NN. Attributes: LABELKEY: The key used to retrieve a label from the batch dictionary. DATAKEY: The key used to retrieve an input image from the batch dictionary. PruningRateFnType: Typing alias for a valid pruning rate function. 
""" from collections import abc import functools import time from typing import Callable, Dict, Mapping, Optional, Tuple, Union from absl import logging import flax from flax import jax_utils from flax.training import common_utils import jax import jax.numpy as jnp from rigl.experimental.jax.datasets import dataset_base from rigl.experimental.jax.models import model_factory from rigl.experimental.jax.pruning import masked from rigl.experimental.jax.pruning import pruning from rigl.experimental.jax.pruning import symmetry from rigl.experimental.jax.utils import utils import tensorflow.compat.v2 as tf LABELKEY = dataset_base.ImageDataset.LABELKEY DATAKEY = dataset_base.ImageDataset.DATAKEY PruningRateFnType = Union[Mapping[str, Callable[[int], float]], Callable[[int], float]] def _shard_batch(xs): """Shards a batch for a pmap, based on the number of devices.""" local_device_count = jax.local_device_count() def _prepare(x): return x.reshape((local_device_count, -1) + x.shape[1:]) return jax.tree_map(_prepare, xs) def train_step( optimizer: flax.optim.Optimizer, batch: Mapping[str, jnp.array], # pytype: disable=module-attr rng: Callable[[int], jnp.array], state: flax.deprecated.nn.Collection, learning_rate_fn: Callable[[int], float] ) -> Tuple[flax.optim.Optimizer, flax.deprecated.nn.Collection, float, float]: # pytype: disable=module-attr """Performs training for one minibatch. Args: optimizer: Optimizer to use. batch: Minibatch to train with. rng: Random number generator, i.e. jax.random.PRNGKey, to use for model training, e.g. dropout. state: Model state. learning_rate_fn: A function that returns the learning rate given the step. Returns: A tuple consisting of the new optimizer, new state, mini-batch loss, and gradient norm. """ def loss_fn( model: flax.deprecated.nn.Model ) -> Tuple[float, Tuple[flax.deprecated.nn.Collection, jnp.array]]: """Evaluates the loss function. Args: model: The model with which to evaluate the loss. Returns: Tuple of the loss for the mini-batch, and model state. """ with flax.deprecated.nn.stateful(state) as new_state: with flax.deprecated.nn.stochastic(rng): logits = model(batch[DATAKEY]) loss = utils.cross_entropy_loss(logits, batch[LABELKEY]) return loss, new_state lr = learning_rate_fn(optimizer.state.step) grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, new_state), grad = grad_fn(optimizer.target) grad = jax.lax.pmean(grad, 'batch') new_opt = optimizer.apply_gradient(grad, learning_rate=lr) grad_norm = jnp.linalg.norm(utils.param_as_array(grad)) return new_opt, new_state, loss, grad_norm class Trainer: """Training class with the state and methods for training a neural network. Attributes: optimizer: Optimizer used for training, None if training hasn't begun. state: Model state used for training. """ def __init__( self, optimizer_def: flax.optim.OptimizerDef, # pytype: disable=module-attr initial_model: flax.deprecated.nn.Model, initial_state: flax.deprecated.nn.Collection, dataset: jnp.array, rng: Callable[[int], jnp.array] = None, summary_writer: Optional[tf.summary.SummaryWriter] = None, ): """Creates a Trainer object. Args: optimizer_def: The flax optimizer def (i.e. not instantiated with a model using .create) to use for training. initial_model: The initial model to train. initial_state: The initial state of the model. dataset: The training dataset. rng: Random number generator, i.e. jax.random.PRNGKey, to use for model training, e.g. dropout. 
summary_writer: An optional tensorboard summary writer for logging training metrics. """ self._optimizer_def = optimizer_def self._initial_model = initial_model self.state = initial_state self._dataset = dataset self._summary_writer = summary_writer self.optimizer = None self._rng = rng if self._rng is None: self._rng = jax.random.PRNGKey(42) def _update_optimizer(self, model: flax.deprecated.nn.Model): """Updates the optimizer based on the given model.""" self.optimizer = jax_utils.replicate( self._optimizer_def.create(model)) def train( self, num_epochs: int, lr_fn: Optional[Callable[[int], float]] = None, pruning_rate_fn: Optional[PruningRateFnType] = None, update_iter: int = 100, update_epoch: int = 10 ) -> Tuple[flax.deprecated.nn.Model, Mapping[str, Union[int, float, Mapping[ str, float]]]]: """Trains the model over the given number of epochs. Args: num_epochs: The total number of epochs to train over. lr_fn: The learning rate function, takes the current iteration/step as an argument and returns the current learning rate, uses constant learning rate if no function is provided. pruning_rate_fn: The pruning rate function, takes the current epoch as an argument and returns the current pruning rate, no further pruning is performed during training if not set. Can be a dictionary, containing the pruning rate schedule functions for each layer, or a single function for all layers. update_iter: Period of iterations in which to log/update per-batch metrics. update_epoch: Period of epochs in which to log/update full training/test metrics. Returns: Tuple consisting of the best model found during training, and metrics. Raises: ValueError: If the batch size of the data set is not evenly divisible by the number of devices, or the model batch size is not the training data batch size/number of jax devices. """ best_test_acc = 0 best_train_loss = None best_iter = None if lr_fn is None: lr_fn = lambda _: self.optimizer.optimizer_def.hyper_params.learning_rate host_count = jax.host_count() device_count = jax.device_count() local_device_count = jax.local_device_count() logging.info('JAX hosts %d, devices: %d, local devices: %d', host_count, device_count, local_device_count) # TODO: Implement multi-host training. if host_count > 1: raise NotImplementedError('Multi-host training is not supported yet, ' 'see b/155550457.') if self._dataset.batch_size % device_count > 0: raise ValueError( 'Train batch size ({}) must be divisible by the number of devices ' '({})'.format(self._dataset.batch_size, device_count)) if self._dataset.batch_size_test % device_count > 0: raise ValueError( 'Test batch size ({}) must be divisible by the number of devices ' '({})'.format(self._dataset.batch_size_test, device_count)) # Required to use state and optimizer with jax.pmap. state = jax_utils.replicate(self.state) self._update_optimizer(self._initial_model) p_train_step = jax.pmap( functools.partial(train_step, learning_rate_fn=lr_fn), axis_name='batch') # Function to sync the batch statistics across replicas. p_synchronized_batch_stats = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x') p_cosine_similarity = functools.partial(utils.cosine_similarity_model, self._initial_model) p_vector_difference_norm = functools.partial( utils.vector_difference_norm_model, self._initial_model) pruning_rate = None mask = None cumulative_grad_norm = 0 start_time = time.time() # Main training loop. for epoch in range(num_epochs): if epoch % update_epoch == 0 or epoch == num_epochs - 1: epoch_start_time = time.time() # If we get different schedules for different layers.
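# For example (illustrative schedules only): a single callable such as pruning_rate_fn=lambda epoch: min(0.9, 0.01 * epoch) applies one schedule to every masked layer, while a mapping such as pruning_rate_fn={'MaskedModule_0': lambda epoch: 0.5} prunes only the layers it names.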
if isinstance(pruning_rate_fn, abc.Mapping): next_pruning_rate = { layer: layer_fn(epoch) for layer, layer_fn in pruning_rate_fn.items() } elif pruning_rate_fn: next_pruning_rate = pruning_rate_fn(epoch) # If pruning rate has changed/is first epoch, we need to update mask. # Note: pruning_rate could be zero, so must explicitly check it's None. if pruning_rate_fn and (pruning_rate is None or pruning_rate != next_pruning_rate): pruning_rate = next_pruning_rate logging.info('[%d] Pruning Rate: %s', epoch, str(pruning_rate)) # Unreplicate optimizer/current model, and mask. self.optimizer = jax_utils.unreplicate(self.optimizer) mask = jax_utils.unreplicate(mask) if mask else None # Performs pruning to get updated mask. mask = pruning.prune(self.optimizer.target, pruning_rate, mask=mask) logging.info('[%d] Mask Sparsity: %0.3f', epoch, masked.mask_sparsity(mask)) for layer, layer_mask in sorted(mask.items()): if layer_mask: logging.info('[%d] Layer: %s, Mask Sparsity: %0.3f', epoch, layer, masked.mask_layer_sparsity(layer_mask)) if jax.host_id() == 0: mask_stats = symmetry.get_mask_stats(mask) logging.info('Mask stats: %s', str(mask_stats)) if self._summary_writer: for label, value in mask_stats.items(): try: self._summary_writer.scalar(f'mask_{epoch}/{label}', value, 0) # Needed when permutations (long int) can't be cast to float32. except (OverflowError, ValueError): self._summary_writer.text(f'mask_{epoch}/{label}', str(value), 0) logging.error( 'Could not write mask_%d/%s to tensorflow summary as float32' ', writing as string instead.', epoch, label) # Creates a new optimizer, based on a new model with new mask. self._update_optimizer( model_factory.update_model(self.optimizer.target, masks=mask)) # Begins epoch. for batch in self._dataset.get_train(): # Note: Because of replicate, step has # device identical vals. step = jax_utils.unreplicate(self.optimizer.state.step) if step % update_iter == 0: batch_start_time = time.time() # These are required for pmap call. self._rng, step_key = jax.random.split(self._rng) batch = _shard_batch(batch) sharded_keys = common_utils.shard_prng_key(step_key) (self.optimizer, state, opt_loss, grad_norm) = p_train_step(self.optimizer, batch, sharded_keys, state) if state.state: state = p_synchronized_batch_stats(state) grad_norm = jax_utils.unreplicate(grad_norm) cumulative_grad_norm += grad_norm # Per-iteration status/metrics update. if jax.host_id() == 0 and step % update_iter == 0: batch_time = time.time() - batch_start_time if self._summary_writer is not None: self._summary_writer.scalar('training/train_batch_loss', jnp.mean(opt_loss), step) self._summary_writer.scalar('training/gradient_norm', grad_norm, step) logging.info('[epoch %d] %d, loss %0.5f, lr %0.3f, %0.3f sec', epoch, step, jnp.mean(opt_loss), lr_fn(step), batch_time) # Per-epoch status/metrics update. 
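      # The metrics below are measured against the initial model:
      # cosine_distance holds the cosine similarity between the flattened
      # initial and current parameter vectors, and vector_difference_norm the
      # L2 norm of their difference (see utils.cosine_similarity_model and
      # utils.vector_difference_norm_model).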
      if (jax.host_id() == 0 and
          (epoch % update_epoch == 0 or epoch == num_epochs - 1)):
        epoch_time = time.time() - epoch_start_time

        cosine_distance = p_cosine_similarity(
            jax_utils.unreplicate(self.optimizer.target))
        vector_difference_norm = p_vector_difference_norm(
            jax_utils.unreplicate(self.optimizer.target))

        train_metrics = eval_model(self.optimizer.target, state,
                                   self._dataset.get_train())
        test_metrics = eval_model(self.optimizer.target, state,
                                  self._dataset.get_test())
        train_loss = train_metrics['loss']
        train_acc = train_metrics['accuracy']
        test_loss = test_metrics['loss']
        test_acc = test_metrics['accuracy']

        if jax.host_id() == 0:
          metrics = {
              'wallclock_time': float(epoch_time),
              'train_accuracy': float(train_acc),
              'train_avg_loss': float(train_loss),
              'test_accuracy': float(test_acc),
              'test_avg_loss': float(test_loss),
              'lr': float(lr_fn(step)),
              'cosine_distance': float(cosine_distance),
              'cumulative_gradient_norm': float(cumulative_grad_norm),
              'vector_difference_norm': float(vector_difference_norm),
          }
          if self._summary_writer is not None:
            for label, value in metrics.items():
              self._summary_writer.scalar('training/{}'.format(label), value,
                                          step)

        if test_acc >= best_test_acc:
          best_model = self.optimizer.target
          best_test_acc = test_acc
          best_test_metrics = {
              'train_avg_loss': float(train_loss),
              'train_accuracy': float(train_acc),
              'test_avg_loss': float(test_loss),
              'test_accuracy': float(test_acc),
              'step': int(step),
              'cosine_distance': float(cosine_distance),
              'cumulative_gradient_norm': float(cumulative_grad_norm),
              'vector_difference_norm': float(vector_difference_norm),
          }
          best_iter = step

        if best_train_loss is None or train_loss <= best_train_loss:
          best_train_loss = train_loss
          best_train_metrics = {
              'train_avg_loss': float(train_loss),
              'train_accuracy': float(train_acc),
              'test_avg_loss': float(test_loss),
              'test_accuracy': float(test_acc),
              'step': int(step),
              'cosine_distance': float(cosine_distance),
              'cumulative_gradient_norm': float(cumulative_grad_norm),
              'vector_difference_norm': float(vector_difference_norm),
          }

        log_format_str = (
            '[epoch %d] train avg. loss %0.4f, train acc. %0.4f, test avg. '
            'loss %0.4f, test acc. %0.4f, %0.4f sec, cosine sim.: %0.3f, cum. '
            'grad. norm: %0.3f, vector diff: %0.3f')
        log_vars = [
            epoch, train_loss, train_acc, test_loss, test_acc, epoch_time,
            float(cosine_distance), float(cumulative_grad_norm),
            float(vector_difference_norm)
        ]
        logging.info(log_format_str, *log_vars)
      # End epoch.

    training_time = time.time() - start_time
    logging.info('Training finished, Total wallclock time: %0.2f sec',
                 training_time)

    if jax.host_id() == 0 and self._summary_writer is not None:
      for label, value in best_test_metrics.items():
        self._summary_writer.scalar('best_test_acc/{}'.format(label), value,
                                    best_iter)
    logging.info('Best Test Accuracy: iteration %d, test acc. %0.5f',
                 best_test_metrics['step'], best_test_acc)

    if jax.host_id() == 0 and self._summary_writer is not None:
      # Log the metrics of the best-train-loss checkpoint (not the
      # best-test-accuracy metrics) under the best_train_loss tag.
      for label, value in best_train_metrics.items():
        self._summary_writer.scalar(
            'best_train_loss/{}'.format(label), value,
            step=best_train_metrics['step'])
    logging.info('Best Train Loss: iteration %d, train loss %0.5f',
                 best_train_metrics['step'], best_train_loss)

    return (best_model, best_test_metrics)


def _eval_step(model: flax.deprecated.nn.Model,
               state: flax.deprecated.nn.Collection,
               batch: Mapping[str, jnp.array]) -> Dict[str, jnp.array]:
  """Evaluates a mini-batch of data.

  Args:
    model: The model to use for evaluation.
    state: Model state containing state for stateful flax.deprecated.nn
      functions, such as batch normalization.
    batch: Mini-batch of data to evaluate on.

  Returns:
    Dictionary consisting of the mini-batch loss and accuracy.
  """
  state = jax.lax.pmean(state, 'batch')
  with flax.deprecated.nn.stateful(state, mutable=False):
    logits = model(batch[DATAKEY], train=False)
  metrics = utils.compute_metrics(logits, batch[LABELKEY])
  return metrics


def eval_model(model: flax.deprecated.nn.Model,
               state: flax.deprecated.nn.Collection,
               eval_dataset: jnp.array) -> Dict[str, float]:
  """Evaluates the given model using the given dataset.

  Args:
    model: The model to evaluate.
    state: Model state containing state for stateful flax.deprecated.nn
      functions, such as batch normalization.
    eval_dataset: Dataset to evaluate the model over.

  Returns:
    Dictionary containing the average loss and accuracy of the model on the
    given dataset.
  """
  p_eval_step = jax.pmap(_eval_step, axis_name='batch')

  batch_sizes = []
  metrics = []
  for batch in eval_dataset:
    batch_size = len(batch[LABELKEY])
    # These are required for pmap call.
    batch = _shard_batch(batch)
    batch_metrics = p_eval_step(model, state, batch)
    batch_sizes.append(batch_size)
    metrics.append(batch_metrics)

  # Note: use weighted mean, since we do mean of means with potentially
  # different batch sizes otherwise.
  batch_sizes = jnp.array(batch_sizes)
  weights = batch_sizes / jnp.sum(batch_sizes)
  eval_metrics = common_utils.get_metrics(metrics)
  return jax.tree_map(lambda x: (weights * x).sum(), eval_metrics)

================================================ FILE: rigl/experimental/jax/training/training_test.py ================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for weight_symmetry.training.training."""
import functools
import math

from absl.testing import absltest
import flax
from flax import jax_utils
from flax.metrics import tensorboard
from flax.training import common_utils
import jax
import jax.numpy as jnp

from rigl.experimental.jax.datasets import dataset_factory
from rigl.experimental.jax.models import model_factory
from rigl.experimental.jax.training import training


class TrainingTest(absltest.TestCase):
  """Tests functions for training loop and training convenience functions."""

  def setUp(self):
    super().setUp()
    self._batch_size = 128  # Note: Tests are run on GPU/TPU.
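    # The loss-range assertions in the tests below rely on cross-entropy over
    # self._num_classes classes starting near log(num_classes) for an
    # untrained model; 2 * log(num_classes) is used as a loose upper bound
    # and machine epsilon as the lower bound.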
self._batch_size_test = 128 self._shuffle_buffer_size = 1024 self._rng = jax.random.PRNGKey(42) self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32) self._num_classes = 10 self._num_epochs = 1 self._learning_rate_fn = lambda _: 0.01 self._weight_decay = 0.0001 self._momentum = 0.9 self._rng = jax.random.PRNGKey(42) self._min_loss = jnp.finfo(float).eps self._max_loss = 2.0 * math.log(self._num_classes) self._dataset_name = 'MNIST' self._model_name = 'MNIST_CNN' self._summarywriter = tensorboard.SummaryWriter('/tmp/') self._dataset = dataset_factory.create_dataset( self._dataset_name, self._batch_size, self._batch_size_test, shuffle_buffer_size=self._shuffle_buffer_size) self._model, self._state = model_factory.create_model( self._model_name, self._rng, (self._input_shape,), num_classes=self._num_classes) self._optimizer = flax.optim.Momentum( # pytype: disable=module-attr learning_rate=self._learning_rate_fn(0), beta=self._momentum, weight_decay=self._weight_decay) def test_train_one_step(self): """Tests training loop over one step.""" iterator = self._dataset.get_train() batch = next(iterator) state = jax_utils.replicate(self._state) optimizer = jax_utils.replicate(self._optimizer.create(self._model)) self._rng, step_key = jax.random.split(self._rng) batch = training._shard_batch(batch) sharded_keys = common_utils.shard_prng_key(step_key) p_train_step = jax.pmap( functools.partial( training.train_step, learning_rate_fn=self._learning_rate_fn), axis_name='batch') _, _, loss, gradient_norm = p_train_step(optimizer, batch, sharded_keys, state) loss = jnp.mean(loss) gradient_norm = jax_utils.unreplicate(gradient_norm) with self.subTest(name='test_loss_range'): self.assertBetween(loss, self._min_loss, self._max_loss) with self.subTest(name='test_gradient_norm'): self.assertGreaterEqual(gradient_norm, 0) def test_train_one_epoch(self): """Tests training loop over one epoch.""" trainer = training.Trainer(self._optimizer, self._model, self._state, self._dataset) with self.subTest(name='trainer_instantiation'): self.assertIsInstance(trainer, training.Trainer) best_model, best_metrics = trainer.train(self._num_epochs) with self.subTest(name='best_model_type'): self.assertIsInstance(best_model, flax.deprecated.nn.Model) with self.subTest(name='train_accuracy'): self.assertBetween(best_metrics['train_accuracy'], 0., 1.) with self.subTest(name='train_avg_loss'): self.assertBetween(best_metrics['train_avg_loss'], self._min_loss, self._max_loss) with self.subTest(name='step'): self.assertGreater(best_metrics['step'], 0) with self.subTest(name='cosine_distance'): self.assertBetween(best_metrics['cosine_distance'], 0., 1.) with self.subTest(name='cumulative_gradient_norm'): self.assertGreater(best_metrics['cumulative_gradient_norm'], 0) with self.subTest(name='test_accuracy'): self.assertBetween(best_metrics['test_accuracy'], 0., 1.) 
with self.subTest(name='test_avg_loss'): self.assertBetween(best_metrics['test_avg_loss'], self._min_loss, self._max_loss) def test_train_one_epoch_tensorboard(self): """Tests training loop over one epoch, with tensorboard.""" trainer = training.Trainer( self._optimizer, self._model, self._state, self._dataset, summary_writer=self._summarywriter) with self.subTest(name='TrainerInstantiation'): self.assertIsInstance(trainer, training.Trainer) best_model, best_metrics = trainer.train(self._num_epochs) with self.subTest(name='best_model_type'): self.assertIsInstance(best_model, flax.deprecated.nn.Model) with self.subTest(name='train_accuracy'): self.assertBetween(best_metrics['train_accuracy'], 0., 1.) with self.subTest(name='train_avg_loss'): self.assertBetween(best_metrics['train_avg_loss'], self._min_loss, self._max_loss) with self.subTest(name='step'): self.assertGreater(best_metrics['step'], 0) with self.subTest(name='cosine_distance'): self.assertBetween(best_metrics['cosine_distance'], 0., 1.) with self.subTest(name='cumulative_gradient_norm'): self.assertGreater(best_metrics['cumulative_gradient_norm'], 0) with self.subTest(name='test_accuracy'): self.assertBetween(best_metrics['test_accuracy'], 0., 1.) with self.subTest(name='test_avg_loss'): self.assertBetween(best_metrics['test_avg_loss'], self._min_loss, self._max_loss) def test_train_one_epoch_pruning_global_schedule(self): """Tests training loop over one epoch with global pruning rate schedule.""" trainer = training.Trainer(self._optimizer, self._model, self._state, self._dataset) with self.subTest(name='trainer_instantiation'): self.assertIsInstance(trainer, training.Trainer) best_model, best_metrics = trainer.train(self._num_epochs, pruning_rate_fn=lambda _: 0.5) with self.subTest(name='best_model_type'): self.assertIsInstance(best_model, flax.deprecated.nn.Model) with self.subTest(name='train_accuracy'): self.assertBetween(best_metrics['train_accuracy'], 0., 1.) with self.subTest(name='train_avg_loss'): self.assertBetween(best_metrics['train_avg_loss'], self._min_loss, self._max_loss) with self.subTest(name='step'): self.assertGreater(best_metrics['step'], 0) with self.subTest(name='cosine_distance'): self.assertBetween(best_metrics['cosine_distance'], 0., 1.) with self.subTest(name='cumulative_gradient_norm'): self.assertGreater(best_metrics['cumulative_gradient_norm'], 0.) with self.subTest(name='test_accuracy'): self.assertBetween(best_metrics['test_accuracy'], 0., 1.) with self.subTest(name='test_avg_loss'): self.assertBetween(best_metrics['test_avg_loss'], self._min_loss, self._max_loss) def test_train_one_epoch_pruning_local_schedule(self): """Tests training loop over one epoch with local pruning rate schedule.""" trainer = training.Trainer(self._optimizer, self._model, self._state, self._dataset) with self.subTest(name='trainer_instantiation'): self.assertIsInstance(trainer, training.Trainer) best_model, best_metrics = trainer.train( self._num_epochs, pruning_rate_fn={'MaskedModule_0': lambda _: 0.5}) with self.subTest(name='best_model_type'): self.assertIsInstance(best_model, flax.deprecated.nn.Model) with self.subTest(name='train_accuracy'): self.assertBetween(best_metrics['train_accuracy'], 0., 1.) with self.subTest(name='train_avg_loss'): self.assertBetween(best_metrics['train_avg_loss'], self._min_loss, self._max_loss) with self.subTest(name='step'): self.assertGreater(best_metrics['step'], 0) with self.subTest(name='cosine_distance'): self.assertBetween(best_metrics['cosine_distance'], 0., 1.) 
with self.subTest(name='cumulative_gradient_norm'): self.assertGreater(best_metrics['cumulative_gradient_norm'], 0.) with self.subTest(name='test_accuracy'): self.assertBetween(best_metrics['test_accuracy'], 0., 1.) with self.subTest(name='test_avg_loss'): self.assertBetween(best_metrics['test_avg_loss'], self._min_loss, self._max_loss) def test_eval_batch(self): """Tests model per-batch evaluation function.""" state = jax_utils.replicate(self._state) optimizer = jax_utils.replicate(self._optimizer.create(self._model)) iterator = self._dataset.get_test() batch = next(iterator) batch = training._shard_batch(batch) metrics = jax.pmap(training._eval_step, axis_name='batch')( optimizer.target, state, batch) loss = jnp.mean(metrics['loss']) accuracy = jnp.mean(metrics['accuracy']) with self.subTest(name='test_eval_batch_loss'): self.assertBetween(loss, self._min_loss, self._max_loss) with self.subTest(name='test_eval_batch_accuracy'): self.assertBetween(accuracy, 0., 1.) def test_eval(self): """Tests model evaluation function.""" state = jax_utils.replicate(self._state) optimizer = jax_utils.replicate(self._optimizer.create(self._model)) metrics = training.eval_model(optimizer.target, state, self._dataset.get_test()) loss = metrics['loss'] accuracy = metrics['accuracy'] with self.subTest(name='test_eval_loss'): self.assertBetween(loss, 0., 2.0*math.log(self._num_classes)) with self.subTest(name='test_eval_accuracy'): self.assertBetween(accuracy, 0., 1.) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/experimental/jax/utils/__init__.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: rigl/experimental/jax/utils/utils.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Convenience Functions for NN training. Misc. common functions used in training NN models. """ import functools import itertools import json import operator from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, TypeVar import flax from flax.training import common_utils import jax import jax.numpy as jnp import numpy as np def cross_entropy_loss(log_softmax_logits, labels): """Returns the cross-entropy classification loss. Args: log_softmax_logits: The log of the softmax of the logits for the mini-batch, e.g. 
as output by jax.nn.log_softmax(logits). labels: The labels for the mini-batch. """ num_classes = log_softmax_logits.shape[-1] one_hot_labels = common_utils.onehot(labels, num_classes) return -jnp.sum(one_hot_labels * log_softmax_logits) / labels.size def compute_metrics(logits, labels): """Computes the classification loss and accuracy for a mini-batch. Args: logits: NN model's logit outputs for the mini-batch. labels: The classification labels for the mini-batch. Returns: Metrics dictionary where 'loss' the mini-batch loss and 'accuracy' is the classification accuracy. Raises: ValueError: If the given logits array is not of the correct shape. """ if len(logits.shape) != 2: raise ValueError( 'Expected an array of (BATCHSIZE, NUM_CLASSES), but got {}'.format( logits.shape)) metrics = { 'loss': cross_entropy_loss(logits, labels), 'accuracy': jnp.mean(jnp.argmax(logits, -1) == labels) } return jax.lax.pmean(metrics, 'batch') def _np_converter(obj): """Explicitly cast Numpy types not recognized by JSON serializer.""" if isinstance(obj, jnp.integer) or isinstance(obj, np.integer): return int(obj) elif isinstance(obj, jnp.floating) or isinstance(obj, np.floating): return float(obj) elif isinstance(obj, jnp.ndarray) or isinstance(obj, np.ndarray): return obj.tolist() def dump_dict_json(data_dict, path): """Dumps a dictionary to a JSON file, ensuring Numpy types are cast correctly. Args: data_dict: A metrics dictionary. path: Path of the JSON file to save. Raises: """ with open(path, 'w') as json_file: json.dump(data_dict, json_file, default=_np_converter) def count_param(model, param_names): """Counts the number of parameters in the given model. Args: model: The model for which to count the parameters. param_names: The parameters in each layer which should be accounted for. Returns: The total number of parameters of the given names in the model. """ param_traversal = flax.optim.ModelParamTraversal( # pytype: disable=module-attr lambda path, _: any(param_name in path for param_name in param_names)) return functools.reduce( operator.add, [param.size for param in param_traversal.iterate(model)], 0) @jax.jit def cosine_similarity(a, b): """Calculates the cosine similarity between two tensors of same shape.""" a = a.flatten() b = b.flatten() return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b)) def param_as_array(params): """Returns a Flax parameter pytree as a single numpy weight vector.""" params_flat = jax.tree_util.tree_leaves(params) return jnp.concatenate([param.flatten() for param in params_flat]) def cosine_similarity_model(initial_model, current_model): """Calculates the cosine similarity between two model's parameters.""" initial_params = param_as_array(initial_model.params) params = param_as_array(current_model.params) return cosine_similarity(initial_params, params) def vector_difference_norm_model(initial_model, current_model): """Calculates norm of the difference between two model's parameter vectors.""" initial_params = param_as_array(initial_model.params) params = param_as_array(current_model.params) return jnp.linalg.norm(params - initial_params) # Use typevar to hint that we expect unspecified types to match. T = TypeVar('T') def pairwise_longest(iterable): """Creates a meta-iterator to iterate over current/next values concurrently. This is different from itertools pairwise recipe in that it returns the final element as (final, None). Args: iterable: An Iterable of any type. 
Returns: An iterable which returns the current and next items in the iterable, or None if there is no next. For example, for an iterator over the list (1, 2, 3, 4), this would return an iterator as ((1, 2), (2, 3), (3, 4), (4, None)). """ # From itertools example documentation. a, b = itertools.tee(iterable) next(b, None) return itertools.zip_longest(a, b) ================================================ FILE: rigl/experimental/jax/utils/utils_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for weight_symmetry.nn.nn_functions.""" import functools import json import operator import tempfile from typing import Optional, Sequence, TypeVar from absl.testing import absltest from absl.testing import parameterized import flax import jax import jax.numpy as jnp import numpy as np from rigl.experimental.jax.training import training from rigl.experimental.jax.utils import utils class TwoLayerDense(flax.deprecated.nn.Module): """Two-layer Dense Network.""" NUM_FEATURES: Sequence[int] = (32, 64) def apply(self, inputs): # If inputs are in image dimensions, flatten image. inputs = inputs.reshape(inputs.shape[0], -1) inputs = flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[0]) return flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[1]) class UtilsTest(parameterized.TestCase): """Test functions for NN convenience functions.""" def setUp(self): """Common setup for test cases.""" super().setUp() self._batch_size = 2 self._num_classes = 10 self._true_logit = 0.5 self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32) self._input = jnp.ones(*self._input_shape) self._rng = jax.random.PRNGKey(42) _, initial_params = TwoLayerDense.init_by_shape(self._rng, (self._input_shape,)) self._model = flax.deprecated.nn.Model(TwoLayerDense, initial_params) _, initial_params = TwoLayerDense.init_by_shape(self._rng, (self._input_shape,)) self._model_diff_init = flax.deprecated.nn.Model(TwoLayerDense, initial_params) def _create_logits_labels(self, correct): """Creates a set of logits/labels resulting from correct classification. Args: correct: If true, creates labels for a correct classifiction, otherwise creates labels for an incorrect classification. Returns: A tuple of logits, labels. """ logits = np.full((self._batch_size, self._num_classes), (1.0 - self._true_logit) / self._num_classes, dtype=np.float32) # Diagonal over batch will be true. for i in range(self._batch_size): logits[i, i % self._num_classes] = self._true_logit labels = np.zeros(self._batch_size, dtype=jnp.int32) # Diagonal over batch will be true. 
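    # For the incorrect case, each label is shifted by one class so that the
    # argmax of the logits (at i % num_classes) never matches the label.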
for i in range(self._batch_size): labels[i] = (i if correct else i + 1) % self._num_classes return jnp.array(logits), jnp.array(labels) def test_compute_metrics_correct(self): """Tests output when logit outputs indicate correct classification.""" logits, labels_correct = self._create_logits_labels(True) logits = training._shard_batch(logits) labels_correct = training._shard_batch(labels_correct) p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch') metrics = p_compute_metrics(logits, labels_correct) loss = metrics['loss'] accuracy = metrics['accuracy'] with self.subTest(name='loss_type'): self.assertIsInstance(loss, jnp.ndarray) with self.subTest(name='loss_len'): self.assertEqual(loss.size, 1) with self.subTest(name='loss_values'): self.assertGreaterEqual(loss.all(), 0) with self.subTest(name='accuracy_type'): self.assertIsInstance(accuracy, jnp.ndarray) with self.subTest(name='accuracy_Len'): self.assertEqual(accuracy.size, 1) with self.subTest(name='accuracy_values'): self.assertAlmostEqual(accuracy.all(), 1.0) def test_compute_metrics_incorrect(self): """Tests output when logit outputs indicate incorrect classification.""" logits, labels_incorrect = self._create_logits_labels(False) logits = training._shard_batch(logits) labels_incorrect = training._shard_batch(labels_incorrect) p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch') metrics = p_compute_metrics(logits, labels_incorrect) loss = metrics['loss'] accuracy = metrics['accuracy'] with self.subTest(name='loss_type'): self.assertIsInstance(loss, jnp.ndarray) with self.subTest(name='loss_len'): self.assertEqual(loss.size, 1) with self.subTest(name='loss_values'): self.assertGreaterEqual(loss.all(), 0) with self.subTest(name='accuracy_type'): self.assertIsInstance(accuracy, jnp.ndarray) with self.subTest(name='accuracy_len'): self.assertEqual(accuracy.size, 1) with self.subTest(name='accuracy_values'): self.assertAlmostEqual(accuracy.all(), 0.0) def test_compute_metrics_equal_logits(self): """Tests output when the logit outputs are equal for all classes.""" logits, labels_correct = self._create_logits_labels(True) logits = training._shard_batch(logits) labels_correct = training._shard_batch(labels_correct) p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch') metrics = p_compute_metrics(logits, labels_correct) loss = metrics['loss'] accuracy = metrics['accuracy'] with self.subTest(name='loss_type'): self.assertIsInstance(loss, jnp.ndarray) with self.subTest(name='loss_len'): self.assertEqual(loss.size, 1) with self.subTest(name='loss_values'): self.assertGreaterEqual(loss.all(), 0) with self.subTest(name='accuracy_type'): self.assertIsInstance(accuracy, jnp.ndarray) with self.subTest(name='accuracy_len'): self.assertEqual(accuracy.size, 1) with self.subTest(name='accuracy_values'): self.assertAlmostEqual(accuracy.all(), 1.0) def test_dump_dict_json(self): """Tests JSON dumping function.""" data_dict = { 'np_float': np.dtype('float32').type(1.0), 'jnp_float': jnp.dtype('float32').type(1.0), 'np_int': np.dtype('int32').type(1), 'jnp_int': jnp.dtype('int32').type(1), 'np_array': np.array(1.0, dtype=np.float32), 'jnp_array': jnp.array(1.0, dtype=jnp.float32), } converted_dict = { key: utils._np_converter(value) for key, value in data_dict.items() } json_path = tempfile.NamedTemporaryFile() utils.dump_dict_json(data_dict, json_path.name) with open(json_path.name, 'r') as input_file: loaded_dict = json.load(input_file) self.assertDictEqual(loaded_dict, converted_dict) def 
test_count_param_two_layer_dense(self): """Tests model parameter counting on small FC model.""" count = utils.count_param(self._model, ('kernel',)) self.assertEqual( count, self._input.size / self._batch_size * TwoLayerDense.NUM_FEATURES[0] + TwoLayerDense.NUM_FEATURES[0] * TwoLayerDense.NUM_FEATURES[1]) def test_count_invalid_param(self): """Tests model parameter counting for a non-existent parameter name.""" count = utils.count_param(self._model, ('not_kernel',)) self.assertEqual(count, 0) def test_model_param_as_array(self): """Tests method for returning single parameter vector for model.""" param_array = utils.param_as_array(self._model.params) with self.subTest(name='test_param_is_vector'): self.assertLen(param_array.shape, 1) param_sizes = [param.size for param in jax.tree_leaves(self._model.params)] model_size = functools.reduce(operator.add, param_sizes) with self.subTest(name='test_param_size'): self.assertEqual(param_array.size, model_size) def test_cosine_similarity_random(self): """Tests cosine similarity for two random weight matrices.""" a = jax.random.normal(self._rng, (3, 4)) b = jax.random.normal(self._rng, (3, 4)) cosine_similarity = utils.cosine_similarity(a, b) with self.subTest(name='test_cosine_distance_range'): self.assertBetween(cosine_similarity, 0., 1.) def test_cosine_similarity_same(self): """Tests cosine similarity for the same weight matrix.""" a = jax.random.normal(self._rng, (3, 4)) cosine_similarity = utils.cosine_similarity(a, a) with self.subTest(name='test_cosine_distance_range'): self.assertAlmostEqual(cosine_similarity, 1., places=5) def test_cosine_similarity_same_model(self): """Tests cosine similarity for the same model.""" cosine_dist = utils.cosine_similarity_model(self._model, self._model) self.assertAlmostEqual(cosine_dist, 1., places=5) def test_vector_difference_norm_diff_model(self): """Tests vector difference norm for different models.""" vector_diff_norm = utils.vector_difference_norm_model( self._model, self._model_diff_init) self.assertGreaterEqual(vector_diff_norm, 0.) def test_vector_difference_norm_same_model(self): """Tests vector difference norm for the same model.""" vector_diff_norm = utils.vector_difference_norm_model( self._model, self._model) self.assertAlmostEqual(vector_diff_norm, 0., places=5) T = TypeVar('T') @parameterized.parameters( # Tests pairwise longest iterator convenience function with list. ((1, 2, 3, 4), ((1, 2), (2, 3), (3, 4), (4, None))), # Tests pairwise longest iterator with empty input iterator. (iter(()), ()), # Tests pairwise longest iterator with single element iterator. 
((1,), ((1, None),)) ) def test_pairwise_longest_list_iterator( self, input_sequence, output_sequence): """Tests pairwise longest iterator with list iterators.""" output = list(utils.pairwise_longest(iter(input_sequence))) self.assertSequenceEqual(output, output_sequence) if __name__ == '__main__': absltest.main() ================================================ FILE: rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "e5O1UdsY202_" }, "source": [ "##### Copyright 2020 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Wtx39-f76KsC" }, "outputs": [], "source": [ "# Download necessary libraries.\n", "%%bash \n", "test -d rigl || git clone https://github.com/google-research/rigl rigl_repo \u0026\u0026 mv rigl_repo/rigl ./ \n", "test -d gresearch || git clone https://github.com/google-research/google-research google_research" ] }, { "cell_type": "markdown", "metadata": { "id": "i25HTaVl6LAI" }, "source": [ "## Parameter and FLOPs Counting for MobileNetv1 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gAkFMbjrNCww" }, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from google_research.micronet_challenge import counting\n", "from rigl import sparse_utils\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 34 }, "executionInfo": { "elapsed": 2458, "status": "ok", "timestamp": 1593006846761, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "dYm9k-Q47PXe", "outputId": "db7fc195-6e0b-4c04-b695-5670128503d7" }, "outputs": [ { "data": { "text/plain": [ "\u003ctf.Tensor 'mobilenet_1.00_224/act_softmax/Softmax:0' shape=(2, 1000) dtype=float32\u003e" ] }, "execution_count": 2, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "tf.compat.v1.reset_default_graph()\n", "model=tf.keras.applications.MobileNet(input_shape=(224,224,3), weights=None)\n", "model(tf.ones((2,224,224,3)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RNS1s5Wm7U8-" }, "outputs": [], "source": [ "masked_layers = []\n", "dw_layers = []\n", "for layer in model.layers:\n", " if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense, tf.keras.layers.DepthwiseConv2D)): \n", " masked_layers.append(layer)\n", " if 'conv_dw' in layer.name:\n", " dw_layers.append(layer)\n", " # print(layer.name, sparse_utils._get_kernel(layer).shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QtD03TrBSDzV" }, "outputs": [], "source": [ "PARAM_SIZE=32\n", "import functools\n", "\n", "get_stats = functools.partial(\n", " sparse_utils.get_stats, first_layer_name='conv1',\n", " last_layer_name='conv_preds', param_size=PARAM_SIZE)\n", "\n", "def print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi',\n", " custom_sparsities=None, is_debug=False, width=1.):\n", " print('Method: %s, Sparsity: %f' % (method, default_sparsity))\n", " total_flops, total_param_bits, sparsity = get_stats(\n", " masked_layers, default_sparsity=default_sparsity, method=method,\n", " custom_sparsities=custom_sparsities, is_debug=is_debug, width=width)\n", " print('Total Flops: %.3f MFlops' % (total_flops/1e6))\n", " print('Total Size: %.3f 
Mbytes' % (total_param_bits/8e6))\n", " print('Real Sparsity: %.3f' % (sparsity))" ] }, { "cell_type": "markdown", "metadata": { "id": "FvqtfXePpgdb" }, "source": [ "### Printing sparse network stats" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 218 }, "executionInfo": { "elapsed": 548, "status": "ok", "timestamp": 1593006940695, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "qupDcQOlTxDk", "outputId": "f59b39d2-eedb-4e45-db93-f52958f24a45" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Method: erdos_renyi_kernel, Sparsity: 0.750000\n", "Total Flops: 599.144 MFlops\n", "Total Size: 4.888 Mbytes\n", "Real Sparsity: 0.742\n", "Method: random, Sparsity: 0.750000\n", "Total Flops: 330.769 MFlops\n", "Total Size: 4.894 Mbytes\n", "Real Sparsity: 0.742\n", "Method: random, Sparsity: 0.000000\n", "Total Flops: 1141.544 MFlops\n", "Total Size: 16.864 Mbytes\n", "Real Sparsity: 0.000\n" ] } ], "source": [ "c_sparsities = {'%s/depthwise_kernel:0' % l.name: 0. for l in dw_layers}\n", "c_sparsities_uniform = c_sparsities.copy()\n", "c_sparsities_uniform['conv1/kernel:0'] = 0.\n", "# c_sparsities_uniform['conv_preds/kernel:0'] = 0.\n", "# First layer has sparsity 0 by default.\n", "print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n", "print_stats(masked_layers, 0.75, 'random', c_sparsities_uniform, is_debug=False)\n", "print_stats(masked_layers, 0, 'random', is_debug=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 151 }, "executionInfo": { "elapsed": 529, "status": "ok", "timestamp": 1593028091210, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "qvagZCnX31yP", "outputId": "542832bb-7b59-4f43-d216-73260a9a3a56" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Method: erdos_renyi_kernel, Sparsity: 0.850000\n", "Total Flops: 439.152 MFlops\n", "Total Size: 3.224 Mbytes\n", "Real Sparsity: 0.841\n", "Method: random, Sparsity: 0.850000\n", "Total Flops: 222.666 MFlops\n", "Total Size: 3.229 Mbytes\n", "Real Sparsity: 0.841\n" ] } ], "source": [ "print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n", "print_stats(masked_layers, 0.85, 'random', c_sparsities_uniform, is_debug=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 151 }, "executionInfo": { "elapsed": 840, "status": "ok", "timestamp": 1593006957962, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "t3L8WlYJOhku", "outputId": "e5d4709b-984e-4e6d-ded4-8bdd81071267" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Method: erdos_renyi_kernel, Sparsity: 0.900000\n", "Total Flops: 334.134 MFlops\n", "Total Size: 2.392 Mbytes\n", "Real Sparsity: 0.890\n", "Method: random, Sparsity: 0.900000\n", "Total Flops: 168.614 MFlops\n", "Total Size: 2.396 Mbytes\n", "Real Sparsity: 0.890\n" ] } ], "source": [ "print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n", "print_stats(masked_layers, 0.9, 'random', c_sparsities_uniform, is_debug=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 153 }, "executionInfo": { "elapsed": 567, "status": "ok", "timestamp": 1582843606223, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 480 }, "id": "Ge1Ct0YjUME1", "outputId": 
"7144ccdc-eae9-47d8-8a5c-b74aad94187c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Method: erdos_renyi_kernel, Sparsity: 0.950000\n", "Total Flops: 205.281 MFlops\n", "Total Size: 1.560 Mbytes\n", "Real Sparsity: 0.940\n", "Method: random, Sparsity: 0.950000\n", "Total Flops: 114.563 MFlops\n", "Total Size: 1.563 Mbytes\n", "Real Sparsity: 0.940\n" ] } ], "source": [ "print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\n", "print_stats(masked_layers, 0.95, 'random', c_sparsities_uniform, is_debug=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "2RnZ9BCDVJ2P" }, "source": [ "## Finding the width Multiplier for small dense model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 173 }, "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1569942238017, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "-qQMOoNqURfs", "outputId": "4edf8c57-c3ab-45a1-f19d-13be5da23368" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9933069386323201\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total Flops: 266.539 MFlops\n", "Total Size: 4.789 Mbytes\n", "Real Sparsity: 0.000\n", "Method: erdos_renyi_kernel, Sparsity: 0.750000\n", "Total Flops: 588.355 MFlops\n", "Total Size: 4.757 Mbytes\n", "Real Sparsity: 0.750\n" ] } ], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.47)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.47)\n", "print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 173 }, "executionInfo": { "elapsed": 536, "status": "ok", "timestamp": 1569942242149, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "P5mS-6h3ZChX", "outputId": "b722e40b-2797-454e-a2bb-91cdaef4a79d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9998127484496482\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total Flops: 154.770 MFlops\n", "Total Size: 3.076 Mbytes\n", "Real Sparsity: 0.000\n", "Method: erdos_renyi_kernel, Sparsity: 0.850000\n", "Total Flops: 422.419 MFlops\n", "Total Size: 3.075 Mbytes\n", "Real Sparsity: 0.850\n" ] } ], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.353)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.353)\n", "print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 168 }, "executionInfo": { "elapsed": 656, "status": "ok", "timestamp": 1569028742267, "user": { "displayName": "Utku Evci", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64", "userId": "01088181649958641579" }, "user_tz": 240 }, "id": "wY2Uc8RlVkRb", "outputId": "03535606-8b6f-4eb9-ca48-ef235d69994f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9996546850118981\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total 
Flops: 103.825 MFlops\n", "Total Size: 2.236 Mbytes\n", "Real Sparsity: 0.000\n", "Method: erdos_renyi_kernel, Sparsity: 0.900000\n", "Total Flops: 312.956 MFlops\n", "Total Size: 2.235 Mbytes\n", "Real Sparsity: 0.900\n" ] } ], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.285)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.285)\n", "print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 168 }, "executionInfo": { "elapsed": 574, "status": "ok", "timestamp": 1569089855290, "user": { "displayName": "Utku Evci", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64", "userId": "01088181649958641579" }, "user_tz": 240 }, "id": "TUfPAjO5Cryq", "outputId": "c528942a-f531-48df-a46e-d94d5dae0a89" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9982463429660301\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total Flops: 56.617 MFlops\n", "Total Size: 1.396 Mbytes\n", "Real Sparsity: 0.000\n", "Method: erdos_renyi_kernel, Sparsity: 0.950000\n", "Total Flops: 180.359 MFlops\n", "Total Size: 1.393 Mbytes\n", "Real Sparsity: 0.950\n" ] } ], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.204)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.204)\n", "print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "f8sqZWZYpoqa" }, "source": [ "### Big-Sparse Networks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 242 }, "executionInfo": { "elapsed": 631, "status": "ok", "timestamp": 1569285091631, "user": { "displayName": "Utku Evci", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64", "userId": "01088181649958641579" }, "user_tz": 240 }, "id": "f-eD8zoFY_-U", "outputId": "0341ebde-cff6-497e-afaf-65e4a39ac438" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0084815029856933\n", "Method: erdos_renyi_kernel, Sparsity: 0.750000\n", "Total Flops: 2180.140 MFlops\n", "Total Size: 16.723 Mbytes\n", "Real Sparsity: 0.742\n", "Method: random, Sparsity: 0.750000\n", "Total Flops: 1122.572 MFlops\n", "Total Size: 15.863 Mbytes\n", "Real Sparsity: 0.757\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total Flops: 1141.544 MFlops\n", "Total Size: 16.864 Mbytes\n", "Real Sparsity: 0.000\n" ] } ], "source": [ "# BIGGER\n", "_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=1.98)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1.98)\n", "print_stats(masked_layers, 0.75, 'random', {'conv_preds/kernel:0':0.8, 'conv1/kernel:0':0.}, is_debug=False, width=1.98)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "colab": { "height": 168 }, "executionInfo": { "elapsed": 581, "status": "ok", "timestamp": 1569029822060, "user": { "displayName": "Utku Evci", "photoUrl": "https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64", "userId": "01088181649958641579" }, "user_tz": 240 }, "id": "z_rW4hO0ZwIG", "outputId": "efe0e3cd-4ed1-49eb-db6b-d673b01cc020" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0032864697591513\n", "Method: erdos_renyi_kernel, Sparsity: 0.850000\n", "Total Flops: 2442.726 MFlops\n", "Total Size: 16.809 Mbytes\n", "Real Sparsity: 0.846\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total Flops: 1141.544 MFlops\n", "Total Size: 16.864 Mbytes\n", "Real Sparsity: 0.000\n" ] } ], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=2.52)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=2.52)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 242 }, "executionInfo": { "elapsed": 558, "status": "ok", "timestamp": 1569939161351, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "MHhuiXGlaQEi", "outputId": "74db692f-bc1d-4f42-acc9-3848f4b2d21c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0120353164650686\n", "Method: erdos_renyi_kernel, Sparsity: 0.900000\n", "Total Flops: 2452.785 MFlops\n", "Total Size: 16.664 Mbytes\n", "Real Sparsity: 0.899\n", "Method: random, Sparsity: 0.900000\n", "Total Flops: 1058.478 MFlops\n", "Total Size: 17.833 Mbytes\n", "Real Sparsity: 0.890\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total Flops: 1141.544 MFlops\n", "Total Size: 16.864 Mbytes\n", "Real Sparsity: 0.000\n" ] } ], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=3.)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=3.)\n", "print_stats(masked_layers, 0.9, 'random', {'conv_preds/kernel:0':0.8, 'conv1/kernel:0':0.}, is_debug=False, width=3.)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "height": 173 }, "executionInfo": { "elapsed": 523, "status": "ok", "timestamp": 1569939157037, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": 240 }, "id": "wENtmNUGaXwj", "outputId": "dab1f1c2-b647-4a67-b486-5ec5dcfcf4af" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0031304863290271\n", "Method: erdos_renyi_kernel, Sparsity: 0.950000\n", "Total Flops: 2132.954 MFlops\n", "Total Size: 16.812 Mbytes\n", "Real Sparsity: 0.954\n", "Method: erdos_renyi_kernel, Sparsity: 0.000000\n", "Total Flops: 1141.544 MFlops\n", "Total Size: 16.864 Mbytes\n", "Real Sparsity: 0.000\n" ] } ], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=3.98)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0.95, 
'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=3.98)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "klQNdBJIqm3E" }, "outputs": [], "source": [ "" ] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "//learning/brain/python/client:colab_notebook", "kind": "private" }, "name": "MobileNet v1: Param/Flops Counting [OPEN_SOURCE].ipynb" }, "kernelspec": { "display_name": "Python 2", "name": "python2" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "e5O1UdsY202_" }, "source": [ "##### Copyright 2020 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P5p1fkA3rgL_" }, "outputs": [], "source": [ "# Download the official ResNet50 implementation and other libraries.\n", "# the ResNet50 module s.t. we can use the model builders for our counting.\n", "%%bash \n", "test -d tpu || git clone https://github.com/tensorflow/tpu tpu \u0026\u0026 mv tpu/models/experimental/resnet50_keras ./ \n", "test -d rigl || git clone https://github.com/google-research/rigl rigl_repo \u0026\u0026 mv rigl_repo/rigl ./ \n", "test -d gresearch || git clone https://github.com/google-research/google-research google_research" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tmr3djWe1rKj" }, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from micronet_challenge import counting\n", "from resnet50_keras import resnet_model as resnet_keras\n", "from rigl import sparse_utils\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dYm9k-Q47PXe" }, "outputs": [], "source": [ "tf.compat.v1.reset_default_graph()\n", "model = resnet_keras.ResNet50(1000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RNS1s5Wm7U8-" }, "outputs": [], "source": [ "masked_layers = []\n", "for layer in model.layers:\n", " if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):\n", " masked_layers.append(layer)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QtD03TrBSDzV" }, "outputs": [], "source": [ "PARAM_SIZE=32 # bits\n", "import functools\n", "get_stats = functools.partial(\n", " sparse_utils.get_stats, first_layer_name='conv1', last_layer_name='fc1000',\n", " param_size=PARAM_SIZE)\n", "def print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi',\n", " custom_sparsities={}, is_debug=False, width=1., **kwargs):\n", " print('Method: %s, Sparsity: %f' % (method, default_sparsity))\n", " total_flops, total_param_bits, sparsity = get_stats(\n", " masked_layers, default_sparsity=default_sparsity, method=method,\n", " custom_sparsities=custom_sparsities, is_debug=is_debug, width=width, **kwargs)\n", " print('Total Flops: %.3f MFlops' % (total_flops/1e6))\n", " print('Total Size: %.3f Mbytes' % (total_param_bits/8e6))\n", " print('Real Sparsity: %.3f' % (sparsity))" ] }, { "cell_type": "markdown", "metadata": { "id": "C_2kH9dsrUqu" }, "source": [ "# Pruning FLOPs\n", "We calculate theoratical FLOPs for pruning, 
which means we will start counting sparse FLOPs when the pruning starts." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yHmbXdMyT2c8" }, "outputs": [], "source": [ "p_start, p_end, p_freq = 10000,25000,1000\n", "target_sparsity = 0.8\n", "total_flops = []\n", "for i in range(0,32001,1000):\n", " if i \u003c p_start:\n", " sparsity = 0.\n", " elif p_end \u003c i:\n", " sparsity = target_sparsity\n", " else:\n", " sparsity = (1-(1-(i-p_start)/float(p_end-p_start))**3)*target_sparsity\n", " # print(i, sparsity)\n", " c_flops, _, _ = get_stats(\n", " masked_layers, default_sparsity=sparsity, method='random', custom_sparsities={'conv1/kernel:0':0, 'fc1000/kernel:0':0.8})\n", " # print(i, c_flops, sparsity)\n", " total_flops.append(c_flops)\n", "avg_flops = sum(total_flops) / len(total_flops)\n", "print('Average Flops: ', avg_flops, avg_flops/total_flops[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "xUU10hxxsZX-" }, "source": [ "### Printing sparse network stats." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qupDcQOlTxDk" }, "outputs": [], "source": [ "print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=True, erk_power_scale=0.2)\n", "print_stats(masked_layers, 0.8, 'erdos_renyi')\n", "print_stats(masked_layers, 0.8, 'random', {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8}, is_debug=False)\n", "print_stats(masked_layers, 0, 'random', is_debug=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AI1HIlLrzuED" }, "outputs": [], "source": [ "print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False)\n", "print_stats(masked_layers, 0.9, 'erdos_renyi')\n", "print_stats(masked_layers, 0.9, 'random', {'conv1/kernel:0':0., 'fc1000/kernel:0':0.9}, is_debug=False)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oX5klsS4_vy-" }, "outputs": [], "source": [ "print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', is_debug=False)\n", "print_stats(masked_layers, 0.95, 'erdos_renyi')\n", "print_stats(masked_layers, 0.95, 'random', {'conv1/kernel:0':0}, is_debug=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fe2FHmPfzS7S" }, "outputs": [], "source": [ "print_stats(masked_layers, 0.965, 'erdos_renyi_kernel', {'conv1/kernel:0':0}, is_debug=False)\n", "print_stats(masked_layers, 0.965, 'erdos_renyi')\n", "print_stats(masked_layers, 0.965, 'random', {'conv1/kernel:0':0}, is_debug=False)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "Yc2EeP_YWUfA" }, "source": [ "## Finding the width Multiplier for small dense model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "p8NJFEo9Se2S" }, "outputs": [], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.8, 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.465)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.465)\n", "print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Gjk8Z2g2TOKq" }, "outputs": [], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.34)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.34)\n", "print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', 
is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sa1zoC-bT-Qk" }, "outputs": [], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.26)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.26)\n", "print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "f_IugJP5URFa" }, "outputs": [], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0.965, 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.231)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.231)\n", "print_stats(masked_layers, 0.965, 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "fXd4Mx90sc9Q" }, "source": [ "### Printing the Big-Sparse Results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BtpJ3LvKYCNn" }, "outputs": [], "source": [ "# BIGGER\n", "_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0.8, 'erdos_renyi_kernel', width=2.1)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=False, width=2.1)\n", "print_stats(masked_layers, 0.8, 'random', {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8},\n", " is_debug=False, width=2.1)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=2.1)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kRcOlrf4YG7K" }, "outputs": [], "source": [ "_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\n", "_, bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', width=2.8)\n", "print(sparse_bits/bits)\n", "print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False, width=2.8)\n", "print_stats(masked_layers, 0.9, 'random', {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8}, is_debug=False, width=2.8)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=2.8)\n", "print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "BN8qfasQWva2" }, "source": [ "## [BONUS] DSR FLOPs\n", "Obtained from figure https://arxiv.org/abs/1902.05967; exact values are probably slightly different.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RwI5aRe-SH0n" }, "outputs": [], "source": [ "resnet_layers=['conv1/kernel:0',\n", "'res2a_branch2a/kernel:0',\n", "'res2a_branch2b/kernel:0',\n", "'res2a_branch2c/kernel:0',\n", "'res2a_branch1/kernel:0',\n", "'res2b_branch2a/kernel:0',\n", "'res2b_branch2b/kernel:0',\n", "'res2b_branch2c/kernel:0',\n", "'res2c_branch2a/kernel:0',\n", "'res2c_branch2b/kernel:0',\n", "'res2c_branch2c/kernel:0',\n", "'res3a_branch2a/kernel:0',\n", "'res3a_branch2b/kernel:0',\n", "'res3a_branch2c/kernel:0',\n", "'res3a_branch1/kernel:0',\n", "'res3b_branch2a/kernel:0',\n", "'res3b_branch2b/kernel:0',\n", "'res3b_branch2c/kernel:0',\n", "'res3c_branch2a/kernel:0',\n", "'res3c_branch2b/kernel:0',\n", "'res3c_branch2c/kernel:0',\n", "'res3d_branch2a/kernel:0',\n", "'res3d_branch2b/kernel:0',\n", 
"'res3d_branch2c/kernel:0',\n", "'res4a_branch2a/kernel:0',\n", "'res4a_branch2b/kernel:0',\n", "'res4a_branch2c/kernel:0',\n", "'res4a_branch1/kernel:0',\n", "'res4b_branch2a/kernel:0',\n", "'res4b_branch2b/kernel:0',\n", "'res4b_branch2c/kernel:0',\n", "'res4c_branch2a/kernel:0',\n", "'res4c_branch2b/kernel:0',\n", "'res4c_branch2c/kernel:0',\n", "'res4d_branch2a/kernel:0',\n", "'res4d_branch2b/kernel:0',\n", "'res4d_branch2c/kernel:0',\n", "'res4e_branch2a/kernel:0',\n", "'res4e_branch2b/kernel:0',\n", "'res4e_branch2c/kernel:0',\n", "'res4f_branch2a/kernel:0',\n", "'res4f_branch2b/kernel:0',\n", "'res4f_branch2c/kernel:0',\n", "'res5a_branch2a/kernel:0',\n", "'res5a_branch2b/kernel:0',\n", "'res5a_branch2c/kernel:0',\n", "'res5a_branch1/kernel:0',\n", "'res5b_branch2a/kernel:0',\n", "'res5b_branch2b/kernel:0',\n", "'res5b_branch2c/kernel:0',\n", "'res5c_branch2a/kernel:0',\n", "'res5c_branch2b/kernel:0',\n", "'res5c_branch2c/kernel:0',\n", "'fc1000/kernel:0']\n", "dsr_sparsities8=[0,\n", " 0., .15, .5, .425, .575, .55, .425, .32, .44, .15,\n", " 0., .15, .55, .6, .8, .65, .75, .65, .65, .65, .55, .65, .7,\n", " 0., .35, .65, .85, .9, .8, .85, .85, .8, .85, .85, .85, .85, .8, .8, .9, .75, .8, .85,\n", " 0., .65, .85, .95, .85, .8, .9, .65, .9, .8,\n", " .8]\n", "dsr_sparsities9=[0,\n", " 0., .4, .6, .65, .65, .6, .6, .5, .6, .45,\n", " 0., .4, .7, .8, .9, .8, .85, .8, .75, .8, .7, .8, .8,\n", " 0., .6, .8, .95, .95, .9, .95, .9, .9, .95, .9, .9, .95, .9, .9, .95, .85, .85, .9,\n", " 0., 0.8, .95, .95, .9, .9, .95, .8, .95, .9,\n", " .9] " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P6i-jjz6OLBH" }, "outputs": [], "source": [ "dsr_map = dict(zip(resnet_layers, dsr_sparsities8))\n", "print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xeGqdHtYYlZT" }, "outputs": [], "source": [ "dsr_map = dict(zip(resnet_layers, dsr_sparsities9))\n", "print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "Pf3qqLKrG67e" }, "source": [ "# [BONUS] STR FLOPs\n", "Layerwise sparsities are obtained from the [STR paper](https://arxiv.org/abs/2002.03231)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MIwBmu0NHOuI" }, "outputs": [], "source": [ "str_sparsities = \"\"\"\n", "Layer 1 - conv1 9408 118013952 51.46 51.40 63.02 59.80 59.83 64.87 67.36 66.96 72.11 69.46 73.29 73.47 72.05 75.12 76.12 77.75\n", "Layer 2 - layer1.0.conv1 4096 12845056 69.36 73.24 87.57 83.28 85.18 89.60 91.41 91.11 92.38 91.75 94.46 94.51 94.60 95.95 96.53 96.51\n", "Layer 3 - layer1.0.conv2 36864 115605504 77.85 76.26 90.87 89.48 87.31 94.79 94.27 95.04 95.69 96.07 97.36 97.77 98.35 98.51 98.59 98.84\n", "Layer 4 - layer1.0.conv3 16384 51380224 74.81 74.65 86.52 85.80 85.25 91.85 92.78 93.67 94.13 94.69 96.61 97.03 97.37 98.04 98.21 98.47\n", "Layer 5 - layer1.0.downsample.0 16384 51380224 70.95 72.96 83.53 83.34 82.56 89.13 90.62 90.17 91.83 92.69 95.48 94.89 95.68 96.98 97.56 97.72\n", "Layer 6 - layer1.1.conv1 16384 51380224 80.27 79.58 89.82 89.89 88.51 94.56 96.64 95.78 95.81 96.81 98.79 98.90 98.98 99.13 99.62 99.47\n", "Layer 7 - layer1.1.conv2 36864 115605504 81.36 80.95 91.75 90.60 89.61 94.70 95.78 96.18 96.42 97.26 98.65 99.07 99.40 99.11 99.31 99.56\n", "Layer 8 - layer1.1.conv3 16384 51380224 84.45 80.11 91.22 91.70 90.21 95.17 97.05 95.81 96.34 97.23 98.68 98.76 98.90 99.16 99.57 99.46\n", "Layer 9 - layer1.2.conv1 16384 51380224 78.23 79.79 90.12 88.07 89.36 94.62 95.94 94.74 96.23 96.75 97.96 98.41 98.72 99.38 99.35 99.46\n", "Layer 10 - layer1.2.conv2 36864 115605504 76.01 81.53 91.06 87.03 88.27 93.90 95.63 94.26 96.24 96.11 97.54 98.27 98.44 99.32 99.19 99.39\n", "Layer 11 - layer1.2.conv3 16384 51380224 84.47 83.28 94.95 90.99 92.64 95.76 96.95 96.01 96.87 97.31 98.38 98.60 98.72 99.38 99.27 99.51\n", "Layer 12 - layer2.0.conv1 32768 102760448 73.74 73.96 86.78 85.95 85.90 92.32 94.79 93.86 94.62 95.64 97.19 98.22 98.52 98.48 98.84 98.92\n", "Layer 13 - layer2.0.conv2 147456 115605504 82.56 85.70 91.31 93.91 94.03 97.54 97.43 97.65 98.38 98.62 99.24 99.23 99.40 99.61 99.67 99.63\n", "Layer 14 - layer2.0.conv3 65536 51380224 84.70 83.55 93.04 93.13 92.13 96.61 97.37 97.21 97.59 98.14 98.80 98.95 99.18 99.29 99.47 99.43\n", "Layer 15 - layer2.0.downsample.0 131072 102760448 85.10 87.66 92.78 94.96 95.13 98.07 97.97 98.15 98.70 98.88 99.37 99.35 99.40 99.69 99.68 99.71\n", "Layer 16 - layer2.1.conv1 65536 51380224 85.42 85.79 94.04 95.31 94.94 97.92 98.53 98.21 98.84 99.06 99.46 99.53 99.72 99.78 99.81 99.80\n", "Layer 17 - layer2.1.conv2 147456 115605504 76.95 82.75 87.63 91.50 91.76 95.59 97.22 96.07 97.32 97.80 98.24 98.24 98.60 99.24 99.66 99.33\n", "Layer 18 - layer2.1.conv3 65536 51380224 84.76 84.71 93.10 93.66 93.23 97.00 98.18 97.35 98.06 98.41 98.96 99.21 99.32 99.55 99.58 99.59\n", "Layer 19 - layer2.2.conv1 65536 51380224 84.30 85.34 92.70 94.61 94.76 97.72 97.91 98.21 98.54 98.98 99.24 99.35 99.50 99.62 99.63 99.77\n", "Layer 20 - layer2.2.conv2 147456 115605504 84.28 85.43 92.99 94.86 94.90 97.52 97.21 98.11 98.19 99.04 99.28 99.37 99.46 99.63 99.59 99.72\n", "Layer 21 - layer2.2.conv3 65536 51380224 82.19 84.21 91.12 93.38 93.53 96.89 97.14 97.59 97.77 98.66 98.96 99.15 99.25 99.49 99.51 99.57\n", "Layer 22 - layer2.3.conv1 65536 51380224 83.37 84.41 90.46 93.26 93.50 96.71 97.89 96.99 98.14 98.36 99.10 99.23 99.33 99.53 99.75 99.60\n", "Layer 23 - layer2.3.conv2 147456 115605504 82.83 84.03 91.44 93.21 93.25 96.83 98.02 96.96 98.45 98.30 98.97 99.06 99.26 99.31 99.81 99.68\n", "Layer 24 - layer2.3.conv3 65536 51380224 82.93 85.65 91.02 94.14 93.56 97.20 97.97 97.04 98.16 98.36 98.88 98.97 99.20 99.32 99.67 
99.62\n", "Layer 25 - layer3.0.conv1 131072 102760448 76.63 77.98 85.99 88.85 88.60 94.26 95.07 94.97 96.21 96.59 97.75 98.04 98.30 98.72 99.11 99.06\n", "Layer 26 - layer3.0.conv2 589824 115605504 87.35 88.68 94.39 96.14 96.19 98.51 98.77 98.72 99.11 99.23 99.53 99.59 99.64 99.73 99.80 99.81\n", "Layer 27 - layer3.0.conv3 262144 51380224 81.22 83.22 90.58 93.19 93.05 96.82 97.38 97.32 97.98 98.28 98.88 99.03 99.16 99.39 99.55 99.53\n", "Layer 28 - layer3.0.downsample.0 524288 102760448 89.75 90.99 96.05 97.20 97.16 98.96 99.21 99.20 99.50 99.58 99.78 99.82 99.86 99.91 99.94 99.93\n", "Layer 29 - layer3.1.conv1 262144 51380224 85.88 87.35 93.43 95.36 96.12 98.64 98.77 98.87 99.22 99.33 99.64 99.67 99.72 99.82 99.88 99.84\n", "Layer 30 - layer3.1.conv2 589824 115605504 85.06 86.24 92.74 95.06 95.30 98.09 98.28 98.36 98.75 99.08 99.46 99.48 99.54 99.69 99.76 99.76\n", "Layer 31 - layer3.1.conv3 262144 51380224 84.34 86.79 92.15 94.84 94.90 97.75 98.15 98.11 98.56 98.94 99.30 99.36 99.45 99.65 99.79 99.70\n", "Layer 32 - layer3.2.conv1 262144 51380224 87.51 89.15 94.15 96.77 96.46 98.81 98.83 98.96 99.19 99.44 99.67 99.71 99.74 99.82 99.85 99.89\n", "Layer 33 - layer3.2.conv2 589824 115605504 87.15 88.67 94.09 95.59 96.14 98.86 98.69 98.91 99.21 99.20 99.64 99.72 99.76 99.85 99.84 99.90\n", "Layer 34 - layer3.2.conv3 262144 51380224 84.86 86.90 92.40 94.99 94.99 98.19 98.19 98.42 98.76 98.97 99.42 99.56 99.62 99.76 99.75 99.88\n", "Layer 35 - layer3.3.conv1 262144 51380224 86.62 89.46 94.06 96.08 95.88 98.70 98.71 98.77 99.01 99.27 99.58 99.66 99.69 99.83 99.87 99.87\n", "Layer 36 - layer3.3.conv2 589824 115605504 86.52 87.97 93.56 96.10 96.11 98.70 98.82 98.89 99.19 99.31 99.68 99.73 99.77 99.88 99.87 99.93\n", "Layer 37 - layer3.3.conv3 262144 51380224 84.19 86.81 92.32 94.94 94.91 98.20 98.37 98.43 98.82 99.00 99.51 99.57 99.64 99.81 99.81 99.87\n", "Layer 38 - layer3.4.conv1 262144 51380224 85.85 88.40 93.55 95.49 95.86 98.35 98.44 98.55 98.79 98.96 99.54 99.59 99.60 99.82 99.86 99.87\n", "Layer 39 - layer3.4.conv2 589824 115605504 85.96 87.38 93.27 95.66 95.63 98.41 98.58 98.56 99.19 99.26 99.64 99.69 99.67 99.87 99.90 99.92\n", "Layer 40 - layer3.4.conv3 262144 51380224 83.45 85.76 91.75 94.49 94.35 97.67 98.09 97.99 98.65 98.94 99.49 99.52 99.48 99.77 99.86 99.85\n", "Layer 41 - layer3.5.conv1 262144 51380224 83.33 85.77 91.79 95.09 94.24 97.46 97.89 97.92 98.71 98.90 99.35 99.52 99.58 99.76 99.79 99.83\n", "Layer 42 - layer3.5.conv2 589824 115605504 84.98 86.67 92.48 94.92 95.13 97.88 98.14 98.32 98.91 99.00 99.44 99.58 99.69 99.80 99.83 99.87\n", "Layer 43 - layer3.5.conv3 262144 51380224 79.78 82.23 89.39 93.14 92.76 96.59 97.04 97.30 98.10 98.41 99.03 99.25 99.44 99.61 99.71 99.75\n", "Layer 44 - layer4.0.conv1 524288 102760448 77.83 79.61 87.11 90.32 90.64 95.39 95.84 95.92 97.17 97.35 98.36 98.60 98.83 99.20 99.37 99.42\n", "Layer 45 - layer4.0.conv2 2359296 115605504 86.18 88.00 93.53 95.66 95.78 98.31 98.47 98.55 99.08 99.16 99.54 99.63 99.69 99.81 99.85 99.86\n", "Layer 46 - layer4.0.conv3 1048576 51380224 78.43 80.48 87.85 91.14 91.27 96.00 96.40 96.47 97.53 97.92 98.81 99.00 99.15 99.45 99.57 99.61\n", "Layer 47 - layer4.0.downsample.0 2097152 102760448 88.49 89.98 95.03 96.79 96.90 98.91 99.06 99.11 99.45 99.51 99.77 99.82 99.85 99.92 99.94 99.94\n", "Layer 48 - layer4.1.conv1 1048576 51380224 82.07 84.02 90.34 93.69 93.72 97.15 97.56 97.76 98.45 98.75 99.27 99.36 99.54 99.67 99.76 99.80\n", "Layer 49 - layer4.1.conv2 2359296 115605504 83.42 85.23 91.16 93.98 93.93 97.26 
97.58 97.71 98.36 98.67 99.25 99.34 99.50 99.68 99.76 99.80\n", "Layer 50 - layer4.1.conv3 1048576 51380224 78.08 79.96 86.66 90.48 90.22 95.22 95.76 95.89 96.88 97.65 98.70 98.85 99.13 99.45 99.58 99.66\n", "Layer 51 - layer4.2.conv1 1048576 51380224 76.34 77.93 84.98 87.57 88.47 93.90 93.87 94.16 95.55 95.91 97.66 97.97 98.15 98.88 99.08 99.22\n", "Layer 52 - layer4.2.conv2 2359296 115605504 73.57 74.97 82.32 84.37 86.01 91.92 91.66 92.22 94.02 94.16 96.65 97.13 97.29 98.44 98.74 99.00\n", "Layer 53 - layer4.2.conv3 1048576 51380224 68.78 70.38 78.11 80.29 81.73 89.64 89.43 89.65 91.40 92.65 96.02 96.72 96.93 98.47 98.83 99.15\n", "Layer 54 - fc 2048000 2048000 50.65 52.46 60.48 64.50 65.12 75.20 75.73 75.80 78.57 80.69 85.96 87.26 88.03 91.11 92.15 92.87\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gSFw1eH1G8zh" }, "outputs": [], "source": [ "resnet_layers=['conv1/kernel:0',\n", "'res2a_branch2a/kernel:0',\n", "'res2a_branch2b/kernel:0',\n", "'res2a_branch2c/kernel:0',\n", "'res2a_branch1/kernel:0',\n", "'res2b_branch2a/kernel:0',\n", "'res2b_branch2b/kernel:0',\n", "'res2b_branch2c/kernel:0',\n", "'res2c_branch2a/kernel:0',\n", "'res2c_branch2b/kernel:0',\n", "'res2c_branch2c/kernel:0',\n", "'res3a_branch2a/kernel:0',\n", "'res3a_branch2b/kernel:0',\n", "'res3a_branch2c/kernel:0',\n", "'res3a_branch1/kernel:0',\n", "'res3b_branch2a/kernel:0',\n", "'res3b_branch2b/kernel:0',\n", "'res3b_branch2c/kernel:0',\n", "'res3c_branch2a/kernel:0',\n", "'res3c_branch2b/kernel:0',\n", "'res3c_branch2c/kernel:0',\n", "'res3d_branch2a/kernel:0',\n", "'res3d_branch2b/kernel:0',\n", "'res3d_branch2c/kernel:0',\n", "'res4a_branch2a/kernel:0',\n", "'res4a_branch2b/kernel:0',\n", "'res4a_branch2c/kernel:0',\n", "'res4a_branch1/kernel:0',\n", "'res4b_branch2a/kernel:0',\n", "'res4b_branch2b/kernel:0',\n", "'res4b_branch2c/kernel:0',\n", "'res4c_branch2a/kernel:0',\n", "'res4c_branch2b/kernel:0',\n", "'res4c_branch2c/kernel:0',\n", "'res4d_branch2a/kernel:0',\n", "'res4d_branch2b/kernel:0',\n", "'res4d_branch2c/kernel:0',\n", "'res4e_branch2a/kernel:0',\n", "'res4e_branch2b/kernel:0',\n", "'res4e_branch2c/kernel:0',\n", "'res4f_branch2a/kernel:0',\n", "'res4f_branch2b/kernel:0',\n", "'res4f_branch2c/kernel:0',\n", "'res5a_branch2a/kernel:0',\n", "'res5a_branch2b/kernel:0',\n", "'res5a_branch2c/kernel:0',\n", "'res5a_branch1/kernel:0',\n", "'res5b_branch2a/kernel:0',\n", "'res5b_branch2b/kernel:0',\n", "'res5b_branch2c/kernel:0',\n", "'res5c_branch2a/kernel:0',\n", "'res5c_branch2b/kernel:0',\n", "'res5c_branch2c/kernel:0',\n", "'fc1000/kernel:0']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "31sg-lNhHN7D" }, "outputs": [], "source": [ "from collections import defaultdict\n", "str_sparsities_parsed = defaultdict(list)\n", "for j, l in enumerate(str_sparsities.strip().split('\\n')):\n", " l = l.split('-')[1].strip().split(' ')\n", " if l[0] == 'Overall':\n", " overall_sparsities = list(map(float, l[3:]))\n", " else:\n", " for i, ls in enumerate(l[3:]):\n", " s = overall_sparsities[i]\n", " # Sparsities in the table are percentages, so divide by 100.\n", " str_sparsities_parsed[s].append(float(ls) / 100.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xrjtum-4HgAT" }, "outputs": [], "source": [ "for k in str_sparsities_parsed:\n", " print(k)\n", " dsr_map = dict(zip(resnet_layers, str_sparsities_parsed[k]))\n", " print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)" ] } ], "metadata": { "colab": {
"collapsed_sections": [], "last_runtime": { "build_target": "//research/colab/notebook:notebook_backend", "kind": "private" }, "name": "Resnet-50: Param/Flops Counting [OpenSource].ipynb" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: rigl/imagenet_resnet/imagenet_train_eval.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""This script trains a ResNet model that implements various pruning methods. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools import os from absl import app from absl import flags from absl import logging from rigl import sparse_optimizers from rigl import sparse_utils from rigl.imagenet_resnet import mobilenetv1_model from rigl.imagenet_resnet import mobilenetv2_model from rigl.imagenet_resnet import resnet_model from rigl.imagenet_resnet import utils from rigl.imagenet_resnet import vgg from official.resnet import imagenet_input from tensorflow.contrib import estimator as contrib_estimator from tensorflow.contrib import tpu as contrib_tpu from tensorflow.contrib.model_pruning.python import pruning from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.contrib.training.python.training import evaluation from tensorflow_estimator.python.estimator import estimator DST_METHODS = [ 'set', 'momentum', 'rigl', 'static' ] ALL_METHODS = tuple(['scratch', 'baseline', 'snip', 'dnw'] + DST_METHODS) NO_MASK_INIT_METHODS = ('snip', 'dnw', 'baseline') flags.DEFINE_string( 'precision', default='float32', help=('Precision to use; one of: {bfloat16, float32}')) flags.DEFINE_integer('num_workers', 1, 'Number of training workers.') flags.DEFINE_float( 'base_learning_rate', default=0.1, help=('Base learning rate when train batch size is 256.')) flags.DEFINE_float( 'momentum', default=0.9, help=('Momentum parameter used in the MomentumOptimizer.')) flags.DEFINE_integer('ps_task', 0, 'Task id of the replica running the training.') flags.DEFINE_float( 'weight_decay', default=1e-4, help=('Weight decay coefficiant for l2 regularization.')) flags.DEFINE_string('master', '', 'Master job.') flags.DEFINE_string('tpu_job_name', None, 'For complicated TensorFlowFlock') flags.DEFINE_integer( 'steps_per_checkpoint', default=1000, help=('Controls how often checkpoints are generated. 
More steps per ' 'checkpoint = higher utilization of TPU and generally higher ' 'steps/sec')) flags.DEFINE_integer( 'keep_checkpoint_max', default=0, help=('Number of checkpoints to hold.')) flags.DEFINE_integer( 'seed', default=0, help=('Sets the random seed.')) flags.DEFINE_string( 'data_directory', None, 'The location of the sstable used for training.') flags.DEFINE_string('eval_once_ckpt_prefix', '', 'File name of the eval checkpoint used for evaluation.') flags.DEFINE_string( 'data_format', default='channels_last', help=('A flag to override the data format used in the model. The value' ' is either channels_first or channels_last. To run the network on' ' CPU or TPU, channels_last should be used. For GPU, channels_first' ' will improve performance.')) flags.DEFINE_bool( 'transpose_input', default=False, help='Use TPU double transpose optimization') flags.DEFINE_bool( 'log_mask_imgs_each_iteration', default=False, help='Use to log a few masks as images. Be careful when using. This is' ' very likely to slow down your training and create huge logs.') flags.DEFINE_string( 'mask_init_method', default='', help='If not empty string and mask is not loaded from a checkpoint, ' 'indicates the method used for mask initialization. One of the following: ' '`random`, `erdos_renyi`.') flags.DEFINE_integer( 'resnet_depth', default=50, help=('Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,' ' 200}. ResNet-18 and 34 use the pre-activation residual blocks' ' without bottleneck layers. The other models use pre-activation' ' bottleneck layers. Deeper models require more training time and' ' more memory and may require reducing --train_batch_size to prevent' ' running out of memory.')) flags.DEFINE_float('label_smoothing', 0.1, 'Relax confidence in the labels by (1-label_smoothing).') flags.DEFINE_float( 'erk_power_scale', 1.0, 'Softens the ERK distribution. Value 0 means uniform, ' '1 means regular ERK.') flags.DEFINE_integer( 'train_steps', default=2, help=('The number of steps to use for training. 112590 steps is' ' approximately 90 epochs at batch size 1024. This flag' ' should be adjusted according to the --train_batch_size flag.')) flags.DEFINE_integer( 'train_batch_size', default=1024, help='Batch size for training.') flags.DEFINE_integer( 'eval_batch_size', default=1000, help='Batch size for evaluation.') flags.DEFINE_integer( 'num_train_images', default=1281167, help='Size of training data set.') flags.DEFINE_integer( 'num_eval_images', default=50000, help='Size of evaluation data set.') flags.DEFINE_integer( 'num_label_classes', default=1000, help='Number of classes, at least 2') flags.DEFINE_integer( 'steps_per_eval', default=1251, help=('Controls how often evaluation is performed. Since evaluation is' ' fairly expensive, it is advised to evaluate as infrequently as' ' possible (i.e. up to --train_steps, which evaluates the model only' ' after finishing the entire training regime).')) flags.DEFINE_bool( 'use_tpu', default=False, help=('Use TPU to execute the model for training and evaluation. If' ' --use_tpu=false, will use whatever devices are available to' ' TensorFlow by default (e.g. CPU and GPU)')) flags.DEFINE_integer( 'iterations_per_loop', default=1251, help=('Number of steps to run on TPU before outfeeding metrics to the CPU.' ' If the number of iterations in the loop would exceed the number of' ' train steps, the loop will exit before reaching' ' --iterations_per_loop. 
The larger this value is, the higher the' ' utilization on the TPU.')) flags.DEFINE_integer( 'num_parallel_calls', default=64, help=('Number of parallel threads in CPU for the input pipeline')) flags.DEFINE_integer( 'num_cores', default=8, help=('Number of TPU cores. For a single TPU device, this is 8 because each' ' TPU has 4 chips each with 2 cores.')) flags.DEFINE_string('output_dir', '/tmp/imagenet/', 'Directory where to write event logs and checkpoint.') flags.DEFINE_bool('use_folder_stub', True, 'If True the output_dir is extended with some parameters.') flags.DEFINE_bool('use_batch_statistics', False, 'If True the forward pass is made in training mode. ') flags.DEFINE_bool('eval_on_train', False, 'If True the evaluation is made on training set.') flags.DEFINE_enum( 'mode', 'train', ('train_and_eval', 'train', 'eval', 'eval_once'), 'One of {"train_and_eval", "train", "eval", "eval_once"}.') flags.DEFINE_integer('export_model_freq', 2502, 'The rate at which estimator exports the model.') flags.DEFINE_enum( 'training_method', 'scratch', ALL_METHODS, 'Method used for training sparse network. `scratch` means initial mask is ' 'kept during training. `set` is for sparse evolutionary training and ' '`baseline` is for dense baseline.') flags.DEFINE_enum( 'init_method', 'baseline', ('baseline', 'sparse'), 'Method for initialization. If sparse and training_method=scratch, then ' 'use initializers that take into account starting sparsity.') # flags.DEFINE_enum( # 'mask_init_method', 'baseline', ('default'), # 'Method for initializating masks. If not default, end_sparsities are used' # ' to define the layer wise random sparse connectivity.') flags.DEFINE_bool( 'is_warm_up', default=True, help=('Boolean for whether to scale weight of regularizer.')) flags.DEFINE_float( 'width', -1., 'Multiplier for the number of channels in each layer') # first and last layer are somewhat special. First layer has almost no # parameters, but 3% of the total flops. Last layer has only .05% of the total # flops but 10% of the total parameters. Depending on whether the goal is max # compression or max acceleration, pruning goals will be different. flags.DEFINE_bool('use_adam', False, 'Whether to use Adam or not') flags.DEFINE_bool('use_sgdr', False, 'Whether to use SGDR for learning rate schedule.') flags.DEFINE_float('sgdr_decay_step', 5, 'Initial cycle length for SGDR.') flags.DEFINE_float('sgdr_t_mul', 1.5, 'Cycle length multiplier for SGDR') flags.DEFINE_float('sgdr_m_mul', .5, 'Learning rate drop at each restart cycle.') flags.DEFINE_float('end_sparsity', 0.9, 'Target sparsity desired by end of training.') flags.DEFINE_float('drop_fraction', 0.3, 'When changing mask dynamically, this fraction decides how ' 'many of the connections are dropped and regrown at each ' 'mask update.') flags.DEFINE_string('drop_fraction_anneal', 'constant', 'If not empty the drop fraction is annealed during sparse' ' training. One of the following: `constant`, `cosine` or ' '`exponential_(\\d*\\.?\\d*)$`. For example: ' '`exponential_3`, `exponential_.3`, `exponential_0.3`. ' 'The number after `exponential` defines the exponent.') flags.DEFINE_string('grow_init', 'zeros', 'Passed to the SparseInitializer, one of: zeros, ' 'initial_value, random_normal, random_uniform.') flags.DEFINE_float('s_momentum', 0.9, 'Momentum values for exponential moving average of ' 'gradients. 
Used when training_method="momentum".') flags.DEFINE_float('rigl_acc_scale', 0., 'Used to scale initial accumulated gradients for new ' 'connections.') flags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.') flags.DEFINE_integer('maskupdate_end_step', 25000, 'Step to end mask updates.') flags.DEFINE_integer('maskupdate_frequency', 100, 'Step interval between mask updates.') flags.DEFINE_float( 'first_layer_sparsity', 0., 'Sparsity to use for the first layer. Overrides default end_sparsity ' 'if greater than 0. If -1, default sparsity is applied. If 0, layer is not ' 'pruned or masked.') flags.DEFINE_float( 'last_layer_sparsity', -1, 'Sparsity to use for the last layer. Overrides default end_sparsity ' 'if greater than 0. If -1, default sparsity is applied. If 0, layer is not ' 'pruned or masked.') flags.DEFINE_string( 'load_mask_dir', '', 'Directory of a trained model from which to load only the mask') flags.DEFINE_string( 'initial_value_checkpoint', '', 'Directory of a model from which to load only the parameters') flags.DEFINE_string( 'model_architecture', 'resnet', 'Which architecture to use. Options: resnet, mobilenet_v1, mobilenet_v2, ' 'vgg_16, vgg_a, vgg_19.') flags.DEFINE_float('expansion_factor', 6., 'How much to expand filters before the depthwise conv.') flags.DEFINE_float('training_steps_multiplier', 1.0, 'Training schedule is shortened or extended with the ' 'multiplier, if it is not 1.') flags.DEFINE_integer('block_width', 1, 'width of block') flags.DEFINE_integer('block_height', 1, 'height of block') FLAGS = flags.FLAGS LR_SCHEDULE = [] PARAM_SUFFIXES = ('gamma', 'beta', 'weights', 'biases') MASK_SUFFIX = 'mask' # Learning rate schedule (multiplier, epoch to start) tuples def set_lr_schedule(): """Sets the learning schedule: LR_SCHEDULE for the training.""" global LR_SCHEDULE if FLAGS.model_architecture == 'mobilenet_v2' or FLAGS.model_architecture == 'mobilenet_v1': LR_SCHEDULE = [(1.0, 8), (0.1, 40), (0.01, 75), (0.001, 95), (.0003, 120)] elif (FLAGS.model_architecture == 'resnet' or FLAGS.model_architecture.startswith('vgg')): LR_SCHEDULE = [(1.0, 0), (0.1, 30), (0.01, 70), (0.001, 90), (.0001, 120)] else: raise ValueError('Unknown architecture ' + FLAGS.model_architecture) if FLAGS.training_steps_multiplier != 1.0: multiplier = FLAGS.training_steps_multiplier LR_SCHEDULE = [(x, y * multiplier) for x, y in LR_SCHEDULE] FLAGS.train_steps = int(FLAGS.train_steps * multiplier) FLAGS.maskupdate_begin_step = int(FLAGS.maskupdate_begin_step * multiplier) FLAGS.maskupdate_end_step = int(FLAGS.maskupdate_end_step * multiplier) tf.logging.info( 'Training schedule is updated with multiplier: %.2f' % multiplier) tf.logging.info('LR schedule: %s' % LR_SCHEDULE) tf.logging.info('Training Steps: %d' % FLAGS.train_steps) # The input tensor is in the range of [0, 255]; it is normalized below to zero # mean and unit variance using these ImageNet statistics. MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] CUSTOM_SPARSITY_MAP = {} def set_custom_sparsity_map(): if FLAGS.first_layer_sparsity > 0.: CUSTOM_SPARSITY_MAP[ 'resnet_model/initial_conv'] = FLAGS.first_layer_sparsity if FLAGS.last_layer_sparsity > 0.: CUSTOM_SPARSITY_MAP[ 'resnet_model/final_dense'] = FLAGS.last_layer_sparsity def lr_schedule(current_epoch): """Computes learning rate schedule.""" scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0) if FLAGS.use_sgdr: decay_rate = tf.train.cosine_decay_restarts( scaled_lr, current_epoch, FLAGS.sgdr_decay_step, t_mul=FLAGS.sgdr_t_mul, 
m_mul=FLAGS.sgdr_m_mul) else: decay_rate = ( scaled_lr * LR_SCHEDULE[0][0] * current_epoch / LR_SCHEDULE[0][1]) for mult, start_epoch in LR_SCHEDULE: decay_rate = tf.where(current_epoch < start_epoch, decay_rate, scaled_lr * mult) return decay_rate def train_function(training_method, loss, cross_loss, reg_loss, output_dir, use_tpu): """Training script for resnet model. Args: training_method: string indicating pruning method used to compress model. loss: tensor float32 of the cross entropy + regularization losses. cross_loss: tensor, only cross entropy loss, passed for logging. reg_loss: tensor, only regularization loss, passed for logging. output_dir: string tensor indicating the directory to save summaries. use_tpu: boolean indicating whether to run script on a tpu. Returns: host_call: summary tensors to be computed at each training step. train_op: the optimization term. """ global_step = tf.train.get_global_step() steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch) learning_rate = lr_schedule(current_epoch) if FLAGS.use_adam: # We don't use step decrease for the learning rate. learning_rate = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) else: optimizer = tf.train.MomentumOptimizer( learning_rate=learning_rate, momentum=FLAGS.momentum, use_nesterov=True) if use_tpu: # use CrossShardOptimizer when using TPU. optimizer = contrib_tpu.CrossShardOptimizer(optimizer) if training_method == 'set': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseSETOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal, stateless_seed_offset=FLAGS.seed) elif training_method == 'static': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseStaticOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal, stateless_seed_offset=FLAGS.seed) elif training_method == 'momentum': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseMomentumOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, grow_init=FLAGS.grow_init, stateless_seed_offset=FLAGS.seed, drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=use_tpu) elif training_method == 'rigl': # We override the train op to also update the mask. 
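# RigL (https://arxiv.org/abs/1911.11134) drops the lowest-magnitude
# connections at each mask update and regrows the same number where the
# dense gradients have the largest magnitude, keeping the sparsity fixed.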
optimizer = sparse_optimizers.SparseRigLOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, stateless_seed_offset=FLAGS.seed, drop_fraction_anneal=FLAGS.drop_fraction_anneal, initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=use_tpu) elif training_method == 'snip': optimizer = sparse_optimizers.SparseSnipOptimizer( optimizer, mask_init_method=FLAGS.mask_init_method, custom_sparsity_map=CUSTOM_SPARSITY_MAP, default_sparsity=FLAGS.end_sparsity, use_tpu=use_tpu) elif training_method == 'dnw': optimizer = sparse_optimizers.SparseDNWOptimizer( optimizer, mask_init_method=FLAGS.mask_init_method, custom_sparsity_map=CUSTOM_SPARSITY_MAP, default_sparsity=FLAGS.end_sparsity, use_tpu=use_tpu) elif training_method in ('scratch', 'baseline'): pass else: raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method) # UPDATE_OPS needs to be added as a dependency due to batch norm update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops), tf.name_scope('train'): grads_and_vars = optimizer.compute_gradients(loss) vars_with_grad = [v for g, v in grads_and_vars if g is not None] if not vars_with_grad: raise ValueError( 'No gradients provided for any variable, check your graph for ops' ' that do not support gradients, between variables %s and loss %s.' % ([str(v) for _, v in grads_and_vars], loss)) train_op = optimizer.apply_gradients( grads_and_vars, global_step=global_step) metrics = { 'global_step': tf.train.get_or_create_global_step(), 'loss': loss, 'cross_loss': cross_loss, 'reg_loss': reg_loss, 'learning_rate': learning_rate, 'current_epoch': current_epoch, } # Logging drop_fraction if dynamic sparse training. is_dst_method = training_method in DST_METHODS if is_dst_method: metrics['drop_fraction'] = optimizer.drop_fraction def flatten_list_of_vars(var_list): flat_vars = [tf.reshape(v, [-1]) for v in var_list] return tf.concat(flat_vars, axis=-1) if use_tpu: reduced_grads = [tf.tpu.cross_replica_sum(g) for g, _ in grads_and_vars] else: reduced_grads = [g for g, _ in grads_and_vars] metrics['grad_norm'] = tf.norm(flatten_list_of_vars(reduced_grads)) metrics['var_norm'] = tf.norm( flatten_list_of_vars([v for _, v in grads_and_vars])) # Let's log some statistics from a single parameter-mask couple. # This is useful for debugging. test_var = pruning.get_weights()[0] test_var_mask = pruning.get_masks()[0] metrics.update({ 'fw_nz_weight': tf.count_nonzero(test_var), 'fw_nz_mask': tf.count_nonzero(test_var_mask), 'fw_l1_weight': tf.reduce_sum(tf.abs(test_var)) }) masks = pruning.get_masks() global_sparsity = sparse_utils.calculate_sparsity(masks) metrics['global_sparsity'] = global_sparsity metrics.update( utils.mask_summaries(masks, with_img=FLAGS.log_mask_imgs_each_iteration)) host_call = (functools.partial(utils.host_call_fn, output_dir), utils.format_tensors(metrics)) return host_call, train_op def resnet_model_fn_w_pruning(features, labels, mode, params): """The model_fn for ResNet-50 with pruning. Args: features: A float32 batch of images. labels: A int32 batch of labels. mode: Specifies whether training or evaluation. params: Dictionary of parameters passed to the model. Returns: A TPUEstimatorSpec for the model """ width = 1. 
if FLAGS.width <= 0 else FLAGS.width if isinstance(features, dict): features = features['feature'] if FLAGS.data_format == 'channels_first': assert not FLAGS.transpose_input # channels_first only for GPU features = tf.transpose(features, [0, 3, 1, 2]) if FLAGS.transpose_input and mode != tf_estimator.ModeKeys.PREDICT: features = tf.transpose(features, [3, 0, 1, 2]) # HWCN to NHWC # Normalize the image to zero mean and unit variance. features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype) features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype) training_method = params['training_method'] use_tpu = params['use_tpu'] def build_network(): """Construct the network in the graph.""" if FLAGS.model_architecture == 'mobilenet_v2': network_func = functools.partial( mobilenetv2_model.mobilenet_v2, expansion_factor=FLAGS.expansion_factor) elif FLAGS.model_architecture == 'mobilenet_v1': network_func = functools.partial(mobilenetv1_model.mobilenet_v1) elif FLAGS.model_architecture == 'resnet': prune_first_layer = FLAGS.first_layer_sparsity != 0. network_func = functools.partial( resnet_model.resnet_v1_, resnet_depth=FLAGS.resnet_depth, init_method=FLAGS.init_method, end_sparsity=FLAGS.end_sparsity, prune_first_layer=prune_first_layer) elif FLAGS.model_architecture.startswith('vgg'): network_func = functools.partial( vgg.vgg, vgg_type=FLAGS.model_architecture, init_method=FLAGS.init_method, end_sparsity=FLAGS.end_sparsity) else: raise ValueError('Unknown architecture ' + FLAGS.model_architecture) prune_last_layer = FLAGS.last_layer_sparsity != 0. network = network_func( num_classes=FLAGS.num_label_classes, # TODO remove the pruning_method option. pruning_method='threshold', width=width, prune_last_layer=prune_last_layer, data_format=FLAGS.data_format, weight_decay=FLAGS.weight_decay) is_training = (mode == tf_estimator.ModeKeys.TRAIN) if FLAGS.use_batch_statistics: is_training = True return network(inputs=features, is_training=is_training) if FLAGS.precision == 'bfloat16': with contrib_tpu.bfloat16_scope(): logits = build_network() logits = tf.cast(logits, tf.float32) elif FLAGS.precision == 'float32': logits = build_network() if mode == tf_estimator.ModeKeys.PREDICT: predictions = { 'classes': tf.argmax(logits, axis=1), 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') } return tf_estimator.EstimatorSpec( mode=mode, predictions=predictions, export_outputs={ 'classify': tf_estimator.export.PredictOutput(predictions) }) output_dir = params['output_dir'] # Calculate loss, which includes softmax cross entropy and L2 regularization. one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes) # Make sure we reuse the same label smoothing parameter if we're doing # scratch / lottery ticket experiments.
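# NOTE: the parsing below assumes load_mask_dir follows the folder_stub
# layout produced in main(), where the label-smoothing value is one of the
# path components.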
label_smoothing = FLAGS.label_smoothing if FLAGS.training_method == 'scratch' and FLAGS.load_mask_dir: scratch_stripped = FLAGS.load_mask_dir.replace('/scratch', '') label_smoothing = float(scratch_stripped.split('/')[15]) tf.logging.info('LABEL SMOOTHING USED: %.2f' % label_smoothing) cross_loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=one_hot_labels, label_smoothing=label_smoothing) # Add regularization loss term reg_loss = tf.losses.get_regularization_loss() loss = cross_loss + reg_loss host_call = None if mode == tf_estimator.ModeKeys.TRAIN: host_call, train_op = train_function(training_method, loss, cross_loss, reg_loss, output_dir, use_tpu) else: train_op = None eval_metrics = None if mode == tf_estimator.ModeKeys.EVAL: def metric_fn(labels, logits, cross_loss, reg_loss): """Calculate eval metrics.""" logging.info('In metric function') eval_metrics = {} predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32) in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32) eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5) eval_metrics['cross_loss'] = tf.metrics.mean(cross_loss) eval_metrics['reg_loss'] = tf.metrics.mean(reg_loss) eval_metrics['eval_accuracy'] = tf.metrics.accuracy( labels=labels, predictions=predictions) # If evaluating once lets also calculate sparsities. if FLAGS.mode == 'eval_once': sparsity_summaries = utils.mask_summaries(pruning.get_masks()) # We call mean on a scalar to create tensor, update_op pairs. sparsity_summaries = {k: tf.metrics.mean(v) for k, v in sparsity_summaries.items()} eval_metrics.update(sparsity_summaries) return eval_metrics tensors = [labels, logits, tf.broadcast_to(cross_loss, tf.shape(labels)), tf.broadcast_to(reg_loss, tf.shape(labels))] eval_metrics = (metric_fn, tensors) if (FLAGS.load_mask_dir and FLAGS.training_method not in NO_MASK_INIT_METHODS): def scaffold_fn(): """For initialization, passed to the estimator.""" utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir, FLAGS.output_dir, MASK_SUFFIX) if FLAGS.initial_value_checkpoint: utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint, FLAGS.output_dir, PARAM_SUFFIXES) return tf.train.Scaffold() elif (FLAGS.mask_init_method and FLAGS.training_method not in NO_MASK_INIT_METHODS): def scaffold_fn(): """For initialization, passed to the estimator.""" if FLAGS.initial_value_checkpoint: utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint, FLAGS.output_dir, PARAM_SUFFIXES) all_masks = pruning.get_masks() assigner = sparse_utils.get_mask_init_fn( all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, CUSTOM_SPARSITY_MAP, erk_power_scale=FLAGS.erk_power_scale) def init_fn(scaffold, session): """A callable for restoring variable from a checkpoint.""" del scaffold # Unused. 
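# Running the assigner once here, after variables are initialized, sets the
# initial sparse connectivity according to mask_init_method.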
session.run(assigner) return tf.train.Scaffold(init_fn=init_fn) else: assert FLAGS.training_method in NO_MASK_INIT_METHODS scaffold_fn = None tf.logging.info('No mask is set, starting dense.') return contrib_tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, host_call=host_call, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn) class ExportModelHook(tf.train.SessionRunHook): """Train hooks called after each session run for exporting the model.""" def __init__(self, classifier, export_dir): self.classifier = classifier self.global_step = None self.export_dir = export_dir self.last_export = 0 self.supervised_input_receiver_fn = ( contrib_estimator.build_raw_supervised_input_receiver_fn( { 'feature': tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3]) }, tf.placeholder(dtype=tf.int32, shape=[None]))) def begin(self): self.global_step = tf.train.get_or_create_global_step() def after_run(self, run_context, run_values): # export saved model global_step = run_context.session.run(self.global_step) if global_step - self.last_export >= FLAGS.export_model_freq: tf.logging.info( 'Export model for prediction (step={}) ...'.format(global_step)) self.last_export = global_step contrib_estimator.export_all_saved_models( self.classifier, os.path.join(self.export_dir, str(global_step)), { tf_estimator.ModeKeys.EVAL: self.supervised_input_receiver_fn, tf_estimator.ModeKeys.PREDICT: imagenet_input.image_serving_input_fn }) def main(argv): del argv # Unused. tf.enable_resource_variables() tf.set_random_seed(FLAGS.seed) set_lr_schedule() set_custom_sparsity_map() folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity), str(FLAGS.maskupdate_begin_step), str(FLAGS.maskupdate_end_step), str(FLAGS.maskupdate_frequency), str(FLAGS.drop_fraction), str(FLAGS.label_smoothing), str(FLAGS.weight_decay)) output_dir = FLAGS.output_dir if FLAGS.use_folder_stub: output_dir = os.path.join(output_dir, folder_stub) export_dir = os.path.join(output_dir, 'export_dir') # we pass the updated eval and train string to the params dictionary. params = {} params['output_dir'] = output_dir params['training_method'] = FLAGS.training_method params['use_tpu'] = FLAGS.use_tpu dataset_func = functools.partial( imagenet_input.ImageNetInput, data_dir=FLAGS.data_directory, transpose_input=False, num_parallel_calls=FLAGS.num_parallel_calls, use_bfloat16=False) imagenet_train, imagenet_eval = [dataset_func(is_training=is_training) for is_training in [True, False]] run_config = tpu_config.RunConfig( master=FLAGS.master, model_dir=output_dir, save_checkpoints_steps=FLAGS.steps_per_checkpoint, keep_checkpoint_max=FLAGS.keep_checkpoint_max, session_config=tf.ConfigProto( allow_soft_placement=True, log_device_placement=False), tpu_config=tpu_config.TPUConfig( iterations_per_loop=FLAGS.iterations_per_loop, num_shards=FLAGS.num_cores, tpu_job_name=FLAGS.tpu_job_name)) classifier = tpu_estimator.TPUEstimator( use_tpu=FLAGS.use_tpu, model_fn=resnet_model_fn_w_pruning, params=params, config=run_config, train_batch_size=FLAGS.train_batch_size, eval_batch_size=FLAGS.eval_batch_size) cpu_classifier = tpu_estimator.TPUEstimator( use_tpu=FLAGS.use_tpu, model_fn=resnet_model_fn_w_pruning, params=params, config=run_config, train_batch_size=FLAGS.train_batch_size, export_to_tpu=False, eval_batch_size=FLAGS.eval_batch_size) if FLAGS.num_eval_images % FLAGS.eval_batch_size != 0: raise ValueError( 'eval_batch_size (%d) must evenly divide num_eval_images(%d)!' 
% (FLAGS.eval_batch_size, FLAGS.num_eval_images)) eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size if FLAGS.mode == 'eval_once': ckpt_path = os.path.join(output_dir, FLAGS.eval_once_ckpt_prefix) dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval classifier.evaluate( input_fn=dataset.input_fn, steps=eval_steps, checkpoint_path=ckpt_path, name='{0}'.format(FLAGS.eval_once_ckpt_prefix)) elif FLAGS.mode == 'eval': # Run evaluation when there's a new checkpoint for ckpt in evaluation.checkpoints_iterator(output_dir): tf.logging.info('Starting to evaluate.') try: dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval classifier.evaluate( input_fn=dataset.input_fn, steps=eval_steps, checkpoint_path=ckpt, name='eval') # Terminate eval job when final checkpoint is reached global_step = int(os.path.basename(ckpt).split('-')[1]) if global_step >= FLAGS.train_steps: tf.logging.info( 'Evaluation finished after training step %d' % global_step) break except tf.errors.NotFoundError: logging.info('Checkpoint no longer exists, skipping checkpoint.') else: global_step = estimator._load_global_step_from_checkpoint_dir(output_dir) # Session run hooks to export model for prediction export_hook = ExportModelHook(cpu_classifier, export_dir) hooks = [export_hook] if FLAGS.mode == 'train': tf.logging.info('start training...') classifier.train( input_fn=imagenet_train.input_fn, hooks=hooks, max_steps=FLAGS.train_steps) else: assert FLAGS.mode == 'train_and_eval' tf.logging.info('start training and eval...') while global_step < FLAGS.train_steps: next_checkpoint = min(global_step + FLAGS.steps_per_eval, FLAGS.train_steps) classifier.train( input_fn=imagenet_train.input_fn, max_steps=next_checkpoint) global_step = next_checkpoint logging.info('Completed training up to step: %d', global_step) classifier.evaluate(input_fn=imagenet_eval.input_fn, steps=eval_steps) if __name__ == '__main__': app.run(main) ================================================ FILE: rigl/imagenet_resnet/mobilenetv1_model.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Straightforward MobileNet v1 for inputs of size 224x224.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools from absl import flags from rigl.imagenet_resnet import resnet_model from rigl.imagenet_resnet.pruning_layers import sparse_conv2d from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected import tensorflow.compat.v1 as tf from tensorflow.contrib import layers as contrib_layers FLAGS = flags.FLAGS def _make_divisible(v, divisor=8, min_value=None): if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%.
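# Example: _make_divisible(476) == 480 and _make_divisible(30) == 32; values
# are rounded to the nearest multiple of `divisor`, then bumped up by one
# multiple if rounding removed more than 10% of `v`.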
if new_v < 0.9 * v: new_v += divisor return new_v def depthwise_conv2d_fixed_padding(inputs, kernel_size, stride, data_format='channels_first', name=None): """Depthwise Strided 2-D convolution with explicit padding. The padding is consistent and is based only on `kernel_size`, not on the dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. kernel_size: Int designating size of kernel to be used in the convolution. stride: Int specifying the stride. If stride >1, the input is downsampled. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. name: String that specifies name for model layer. Returns: The output activation tensor of size [batch, filters, height_out, width_out] Raises: ValueError: If the data_format provided is not a valid string. """ if stride > 1: inputs = resnet_model.fixed_padding( inputs, kernel_size, data_format=data_format) padding = 'SAME' if stride == 1 else 'VALID' if data_format == 'channels_last': data_format_channels = 'NHWC' elif data_format == 'channels_first': data_format_channels = 'NCHW' else: raise ValueError('Not a valid channel string:', data_format) return contrib_layers.separable_conv2d( inputs=inputs, num_outputs=None, kernel_size=kernel_size, stride=stride, padding=padding, data_format=data_format_channels, activation_fn=None, weights_regularizer=None, biases_initializer=None, biases_regularizer=None, scope=name) def conv2d_fixed_padding(inputs, filters, kernel_size, strides, pruning_method='baseline', data_format='channels_first', weight_decay=0., name=None): """Strided 2-D convolution with explicit padding. The padding is consistent and is based only on `kernel_size`, not on the dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. filters: Int specifying number of filters for the first two convolutions. kernel_size: Int designating size of kernel to be used in the convolution. strides: Int specifying the stride. If stride >1, the input is downsampled. pruning_method: String that specifies the pruning method used to identify which weights to remove. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. weight_decay: Weight for the l2 regularization loss. name: String that specifies name for model layer. Returns: The output activation tensor of size [batch, filters, height_out, width_out] Raises: ValueError: If the data_format provided is not a valid string. 
""" if strides > 1: inputs = resnet_model.fixed_padding( inputs, kernel_size, data_format=data_format) padding = 'VALID' else: padding = 'SAME' kernel_initializer = tf.variance_scaling_initializer() kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) return sparse_conv2d( x=inputs, units=filters, activation=None, kernel_size=[kernel_size, kernel_size], use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, bias_initializer=None, biases_regularizer=None, sparsity_technique=pruning_method, normalizer_fn=None, strides=[strides, strides], padding=padding, data_format=data_format, name=name) def mbv1_block_(inputs, filters, is_training, stride, width=1., block_id=0, pruning_method='baseline', data_format='channels_first', weight_decay=0.): """Standard building block for mobilenetv1 networks. Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. filters: Int specifying number of filters for the first two convolutions. is_training: Boolean specifying whether the model is training. stride: Int specifying the stride. If stride >1, the input is downsampled. width: multiplier for channel dimensions block_id: which block this is pruning_method: String that specifies the pruning method used to identify which weights to remove. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. weight_decay: Weight for the l2 regularization loss. Returns: The output activation tensor. """ # separable_conv_2d followed by contracting 1x1 conv. end_point = 'depthwise_nxn_%s' % block_id # Depthwise depthwise_out = depthwise_conv2d_fixed_padding( inputs=inputs, kernel_size=3, stride=stride, data_format=data_format, name=end_point) depthwise_out = resnet_model.batch_norm_relu( depthwise_out, is_training, relu=True, data_format=data_format) # Contraction end_point = 'contraction_1x1_%s' % block_id divisible_by = 8 if block_id == 0: divisible_by = 1 out_filters = _make_divisible(int(width * filters), divisor=divisible_by) contraction_out = conv2d_fixed_padding( inputs=depthwise_out, filters=out_filters, kernel_size=1, strides=1, pruning_method=pruning_method, data_format=data_format, weight_decay=weight_decay, name=end_point) contraction_out = resnet_model.batch_norm_relu( contraction_out, is_training, relu=True, data_format=data_format) output = contraction_out return output def mobilenet_v1_generator(num_classes=1000, pruning_method='baseline', width=1., prune_last_layer=False, data_format='channels_first', weight_decay=0., name=None): """Generator for mobilenet v2 models. Args: num_classes: Int number of possible classes for image classification. pruning_method: String that specifies the pruning method used to identify which weights to remove. width: Float that scales the number of filters in each layer. prune_last_layer: Whether or not to prune the last layer. data_format: String either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. weight_decay: Weight for the l2 regularization loss. name: String that specifies name for model layer. Returns: Model `function` that takes in `inputs` and `is_training` and returns the output `Tensor` of the ResNet model. 
""" def model(inputs, is_training): """Creation of the model graph.""" with tf.variable_scope(name, 'resnet_model'): inputs = resnet_model.fixed_padding( inputs, kernel_size=3, data_format=data_format) padding = 'VALID' kernel_initializer = tf.variance_scaling_initializer() kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) inputs = tf.layers.conv2d( inputs=inputs, filters=_make_divisible(32 * width), kernel_size=3, strides=2, padding=padding, use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, data_format=data_format, name='initial_conv') inputs = tf.identity(inputs, 'initial_conv') inputs = resnet_model.batch_norm_relu( inputs, is_training, data_format=data_format) mb_block = functools.partial( mbv1_block_, is_training=is_training, width=width, pruning_method=pruning_method, data_format=data_format, weight_decay=weight_decay) inputs = mb_block(inputs, filters=64, stride=1, block_id=0) inputs = mb_block(inputs, filters=128, stride=2, block_id=1) inputs = mb_block(inputs, filters=128, stride=1, block_id=2) inputs = mb_block(inputs, filters=256, stride=2, block_id=3) inputs = mb_block(inputs, filters=256, stride=1, block_id=4) inputs = mb_block(inputs, filters=512, stride=2, block_id=5) inputs = mb_block(inputs, filters=512, stride=1, block_id=6) inputs = mb_block(inputs, filters=512, stride=1, block_id=7) inputs = mb_block(inputs, filters=512, stride=1, block_id=8) inputs = mb_block(inputs, filters=512, stride=1, block_id=9) inputs = mb_block(inputs, filters=512, stride=1, block_id=10) inputs = mb_block(inputs, filters=1024, stride=2, block_id=11) inputs = mb_block(inputs, filters=1024, stride=1, block_id=12) last_block_filters = _make_divisible(int(1024 * width), 8) if data_format == 'channels_last': pool_size = (inputs.shape[1], inputs.shape[2]) elif data_format == 'channels_first': pool_size = (inputs.shape[2], inputs.shape[3]) inputs = tf.layers.average_pooling2d( inputs=inputs, pool_size=pool_size, strides=1, padding='VALID', data_format=data_format, name='final_avg_pool') inputs = tf.identity(inputs, 'final_avg_pool') inputs = tf.reshape(inputs, [-1, last_block_filters]) kernel_initializer = tf.variance_scaling_initializer() kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) if prune_last_layer: inputs = sparse_fully_connected( x=inputs, units=num_classes, sparsity_technique=pruning_method if prune_last_layer else 'baseline', kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name='final_dense') else: inputs = tf.layers.dense( inputs=inputs, units=num_classes, activation=None, use_bias=True, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name='final_dense') inputs = tf.identity(inputs, 'final_dense') return inputs model.default_image_size = 224 return model def mobilenet_v1(num_classes, pruning_method='baseline', width=1., prune_last_layer=True, data_format='channels_first', weight_decay=0.): """Returns the mobilenet_V1 model for a given size and number of output classes. Args: num_classes: Int number of possible classes for image classification. pruning_method: String that specifies the pruning method used to identify which weights to remove. width: Float multiplier of the number of filters in each layer. prune_last_layer: Whether or not to prune the last layer. data_format: String specifying either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. 
weight_decay: Weight for the l2 regularization loss. Returns: The MobileNet v1 model function. """ return mobilenet_v1_generator(num_classes, pruning_method, width, prune_last_layer, data_format, weight_decay) ================================================ FILE: rigl/imagenet_resnet/mobilenetv2_model.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Straightforward MobileNet v2 for inputs of size 224x224.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools from absl import flags from rigl.imagenet_resnet import resnet_model from rigl.imagenet_resnet.pruning_layers import sparse_conv2d from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected import tensorflow.compat.v1 as tf from tensorflow.contrib import layers as contrib_layers FLAGS = flags.FLAGS def _make_divisible(v, divisor=8, min_value=None): if min_value is None: min_value = divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. if new_v < 0.9 * v: new_v += divisor return new_v def depthwise_conv2d_fixed_padding(inputs, kernel_size, stride, data_format='channels_first', name=None): """Depthwise Strided 2-D convolution with explicit padding. The padding is consistent and is based only on `kernel_size`, not on the dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. kernel_size: Int designating size of kernel to be used in the convolution. stride: Int specifying the stride. If stride >1, the input is downsampled. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. name: String that specifies name for model layer. Returns: The output activation tensor of size [batch, filters, height_out, width_out] Raises: ValueError: If the data_format provided is not a valid string. """ if stride > 1: inputs = resnet_model.fixed_padding( inputs, kernel_size, data_format=data_format) padding = 'SAME' if stride == 1 else 'VALID' if data_format == 'channels_last': data_format_channels = 'NHWC' elif data_format == 'channels_first': data_format_channels = 'NCHW' else: raise ValueError('Not a valid channel string:', data_format) return contrib_layers.separable_conv2d( inputs=inputs, num_outputs=None, kernel_size=kernel_size, stride=stride, padding=padding, data_format=data_format_channels, activation_fn=None, weights_regularizer=None, biases_initializer=None, biases_regularizer=None, scope=name) def conv2d_fixed_padding(inputs, filters, kernel_size, strides, pruning_method='baseline', data_format='channels_first', weight_decay=0., name=None): """Strided 2-D convolution with explicit padding. 
The padding is consistent and is based only on `kernel_size`, not on the dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. filters: Int specifying number of filters for the first two convolutions. kernel_size: Int designating size of kernel to be used in the convolution. strides: Int specifying the stride. If stride >1, the input is downsampled. pruning_method: String that specifies the pruning method used to identify which weights to remove. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. weight_decay: Weight for the l2 regularization loss. name: String that specifies name for model layer. Returns: The output activation tensor of size [batch, filters, height_out, width_out] Raises: ValueError: If the data_format provided is not a valid string. """ if strides > 1: inputs = resnet_model.fixed_padding( inputs, kernel_size, data_format=data_format) padding = 'VALID' else: padding = 'SAME' kernel_initializer = tf.variance_scaling_initializer() kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) return sparse_conv2d( x=inputs, units=filters, activation=None, kernel_size=[kernel_size, kernel_size], use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, bias_initializer=None, biases_regularizer=None, sparsity_technique=pruning_method, normalizer_fn=None, strides=[strides, strides], padding=padding, data_format=data_format, name=name) def inverted_res_block_(inputs, filters, is_training, stride, width=1., expansion_factor=6., block_id=0, pruning_method='baseline', data_format='channels_first', weight_decay=0.,): """Standard building block for mobilenetv2 networks. Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. filters: Int specifying number of filters for the first two convolutions. is_training: Boolean specifying whether the model is training. stride: Int specifying the stride. If stride >1, the input is downsampled. width: multiplier for channel dimensions expansion_factor: How much to increase the filters before the depthwise conv. block_id: which block this is pruning_method: String that specifies the pruning method used to identify which weights to remove. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. weight_decay: Weight for the l2 regularization loss. Returns: The output activation tensor. """ # 1x1 expanded conv, followed by separable_conv_2d followed by # contracting 1x1 conv. 
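  # Illustrative shape sketch (hypothetical numbers, not from the repository):
  # for block_id=1 with prev_depth=16, filters=24, width=1., stride=2,
  # expansion_factor=6. and NCHW inputs of [N, 16, 112, 112]:
  #   expand_1x1_1:      1x1 conv, 16 -> int(6 * 16) = 96 channels
  #   depthwise_nxn_1:   3x3 depthwise, stride 2 -> [N, 96, 56, 56]
  #   contraction_1x1_1: 1x1 conv, 96 -> _make_divisible(24, 8) = 24 channels
  # The shortcut is added only when prev_depth == out_filters and stride == 1,
  # which does not hold in this example.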
shortcut = inputs if data_format == 'channels_first': prev_depth = inputs.get_shape().as_list()[1] elif data_format == 'channels_last': prev_depth = inputs.get_shape().as_list()[3] else: raise ValueError('Unknown data_format ' + data_format) # Expand multiplier = expansion_factor if block_id > 0 else 1 # skip the expansion if this is the first block if block_id: end_point = 'expand_1x1_%s' % block_id inputs = conv2d_fixed_padding( inputs=inputs, filters=int(multiplier * prev_depth), kernel_size=1, strides=1, pruning_method=pruning_method, data_format=data_format, weight_decay=weight_decay, name=end_point) inputs = resnet_model.batch_norm_relu( inputs, is_training, relu=True, data_format=data_format) end_point = 'depthwise_nxn_%s' % block_id # Depthwise depthwise_out = depthwise_conv2d_fixed_padding( inputs=inputs, kernel_size=3, stride=stride, data_format=data_format, name=end_point) depthwise_out = resnet_model.batch_norm_relu( depthwise_out, is_training, relu=True, data_format=data_format) # Contraction end_point = 'contraction_1x1_%s' % block_id divisible_by = 8 if block_id == 0: divisible_by = 1 out_filters = _make_divisible(int(width * filters), divisor=divisible_by) contraction_out = conv2d_fixed_padding( inputs=depthwise_out, filters=out_filters, kernel_size=1, strides=1, pruning_method=pruning_method, data_format=data_format, weight_decay=weight_decay, name=end_point) contraction_out = resnet_model.batch_norm_relu( contraction_out, is_training, relu=False, data_format=data_format) output = contraction_out if prev_depth == out_filters and stride == 1: output += shortcut return output def mobilenet_v2_generator(num_classes=1000, pruning_method='baseline', width=1., expansion_factor=6., prune_last_layer=False, data_format='channels_first', weight_decay=0., name=None): """Generator for mobilenet v2 models. Args: num_classes: Int number of possible classes for image classification. pruning_method: String that specifies the pruning method used to identify which weights to remove. width: Float that scales the number of filters in each layer. expansion_factor: How much to expand the input filters for the depthwise conv. prune_last_layer: Whether or not to prune the last layer. data_format: String either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. weight_decay: Weight for the l2 regularization loss. name: String that specifies name for model layer. Returns: Model `function` that takes in `inputs` and `is_training` and returns the output `Tensor` of the ResNet model. 
""" def model(inputs, is_training): """Creation of the model graph.""" with tf.variable_scope(name, 'resnet_model'): inputs = resnet_model.fixed_padding( inputs, kernel_size=3, data_format=data_format) padding = 'VALID' kernel_initializer = tf.variance_scaling_initializer() kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) inputs = tf.layers.conv2d( inputs=inputs, filters=_make_divisible(32 * width), kernel_size=3, strides=2, padding=padding, use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, data_format=data_format, name='initial_conv') inputs = tf.identity(inputs, 'initial_conv') inputs = resnet_model.batch_norm_relu( inputs, is_training, data_format=data_format) inverted_res_block = functools.partial( inverted_res_block_, is_training=is_training, width=width, expansion_factor=expansion_factor, pruning_method=pruning_method, data_format=data_format, weight_decay=weight_decay) inputs = inverted_res_block(inputs, filters=16, stride=1, block_id=0) inputs = inverted_res_block(inputs, filters=24, stride=2, block_id=1) inputs = inverted_res_block(inputs, filters=24, stride=1, block_id=2) inputs = inverted_res_block(inputs, filters=32, stride=2, block_id=3) inputs = inverted_res_block(inputs, filters=32, stride=1, block_id=4) inputs = inverted_res_block(inputs, filters=32, stride=1, block_id=5) inputs = inverted_res_block(inputs, filters=64, stride=2, block_id=6) inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=7) inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=8) inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=9) inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=10) inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=11) inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=12) inputs = inverted_res_block(inputs, filters=160, stride=2, block_id=13) inputs = inverted_res_block(inputs, filters=160, stride=1, block_id=14) inputs = inverted_res_block(inputs, filters=160, stride=1, block_id=15) inputs = inverted_res_block(inputs, filters=320, stride=1, block_id=16) last_block_filters = max(1280, _make_divisible(1280 * width, 8)) inputs = conv2d_fixed_padding( inputs=inputs, filters=last_block_filters, kernel_size=1, strides=1, pruning_method=pruning_method, data_format=data_format, weight_decay=weight_decay, name='final_1x1_conv') inputs = resnet_model.batch_norm_relu( inputs, is_training, data_format=data_format) if data_format == 'channels_last': pool_size = (inputs.shape[1], inputs.shape[2]) elif data_format == 'channels_first': pool_size = (inputs.shape[2], inputs.shape[3]) inputs = tf.layers.average_pooling2d( inputs=inputs, pool_size=pool_size, strides=1, padding='VALID', data_format=data_format, name='final_avg_pool') inputs = tf.identity(inputs, 'final_avg_pool') inputs = tf.reshape(inputs, [-1, last_block_filters]) kernel_initializer = tf.variance_scaling_initializer() kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) if prune_last_layer: inputs = sparse_fully_connected( x=inputs, units=num_classes, sparsity_technique=pruning_method if prune_last_layer else 'baseline', kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name='final_dense') else: inputs = tf.layers.dense( inputs=inputs, units=num_classes, activation=None, use_bias=True, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name='final_dense') inputs = tf.identity(inputs, 'final_dense') return inputs 
    model.default_image_size = 224
    return model


def mobilenet_v2(num_classes,
                 pruning_method='baseline',
                 width=1.,
                 expansion_factor=6.,
                 prune_last_layer=True,
                 data_format='channels_first',
                 weight_decay=0.):
  """Returns the MobileNet v2 model for a given size and number of classes.

  Args:
    num_classes: Int number of possible classes for image classification.
    pruning_method: String that specifies the pruning method used to identify
      which weights to remove.
    width: Float multiplier of the number of filters in each layer.
    expansion_factor: How much to increase the number of filters before the
      depthwise conv.
    prune_last_layer: Whether or not to prune the last layer.
    data_format: String specifying either "channels_first" for `[batch,
      channels, height, width]` or "channels_last" for `[batch, height, width,
      channels]`.
    weight_decay: Weight for the l2 regularization loss.

  Returns:
    Model `function` that takes in `inputs` and `is_training` and returns the
    output `Tensor` of the MobileNet v2 model.
  """
  return mobilenet_v2_generator(num_classes, pruning_method, width,
                                expansion_factor, prune_last_layer,
                                data_format, weight_decay)


================================================
FILE: rigl/imagenet_resnet/pruning_layers.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tensorflow layers with parameters for implementing pruning."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.python.ops import init_ops


def get_model_variables(getter,
                        name,
                        shape=None,
                        dtype=None,
                        initializer=None,
                        regularizer=None,
                        trainable=True,
                        collections=None,
                        caching_device=None,
                        partitioner=None,
                        rename=None,
                        use_resource=None,
                        **_):
  """This ensures variables are retrieved consistently for core layers."""
  short_name = name.split('/')[-1]
  if rename and short_name in rename:
    name_components = name.split('/')
    name_components[-1] = rename[short_name]
    name = '/'.join(name_components)
  return variables.model_variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      collections=collections,
      trainable=trainable,
      caching_device=caching_device,
      partitioner=partitioner,
      custom_getter=getter,
      use_resource=use_resource)


def variable_getter(rename=None):
  """Ensures scope is respected and consistently used."""

  def layer_variable_getter(getter, *args, **kwargs):
    kwargs['rename'] = rename
    return get_model_variables(getter, *args, **kwargs)

  return layer_variable_getter


def sparse_conv2d(x,
                  units,
                  kernel_size,
                  activation=None,
                  use_bias=False,
                  kernel_initializer=None,
                  kernel_regularizer=None,
                  bias_initializer=None,
                  biases_regularizer=None,
                  sparsity_technique='baseline',
                  normalizer_fn=None,
                  strides=(1, 1),
                  padding='SAME',
                  data_format='channels_last',
                  name=None):
  """Function that constructs conv2d with any desired pruning method.
Args: x: Input, float32 tensor. units: Int representing size of output tensor. kernel_size: The size of the convolutional window, int of list of ints. activation: If None, a linear activation is used. use_bias: Boolean specifying whether bias vector should be used. kernel_initializer: Initializer for the convolution weights. kernel_regularizer: Regularization method for the convolution weights. bias_initializer: Initalizer of the bias vector. biases_regularizer: Optional regularizer for the bias vector. sparsity_technique: Method used to introduce sparsity. ['threshold', 'baseline'] normalizer_fn: function used to transform the output activations. strides: stride length of convolution, a single int is expected. padding: May be populated as 'VALID' or 'SAME'. data_format: Either 'channels_last', 'channels_first'. name: String speciying name scope of layer in network. Returns: Output: activations. Raises: ValueError: If the rank of the input is not greater than 2. """ if data_format == 'channels_last': data_format_channels = 'NHWC' elif data_format == 'channels_first': data_format_channels = 'NCHW' else: raise ValueError('Not a valid channel string:', data_format) layer_variable_getter = variable_getter({ 'bias': 'biases', 'kernel': 'weights', }) input_rank = x.get_shape().ndims if input_rank != 4: raise ValueError('Rank not supported {}'.format(input_rank)) with tf.variable_scope( name, 'Conv', [x], custom_getter=layer_variable_getter) as sc: input_shape = x.get_shape().as_list() if input_shape[-1] is None: raise ValueError('The last dimension of the inputs to `Convolution` ' 'should be defined. Found `None`.') pruning_methods = ['threshold'] if sparsity_technique in pruning_methods: return layers.masked_conv2d( inputs=x, num_outputs=units, kernel_size=kernel_size[0], stride=strides[0], padding=padding, data_format=data_format_channels, rate=1, activation_fn=activation, weights_initializer=kernel_initializer, weights_regularizer=kernel_regularizer, normalizer_fn=normalizer_fn, normalizer_params=None, biases_initializer=bias_initializer, biases_regularizer=biases_regularizer, outputs_collections=None, trainable=True, scope=sc) elif sparsity_technique == 'baseline': return tf.layers.conv2d( inputs=x, filters=units, kernel_size=kernel_size, strides=strides, padding=padding, use_bias=use_bias, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, data_format=data_format, name=name) else: raise ValueError( 'Unsupported sparsity technique {}'.format(sparsity_technique)) def sparse_fully_connected(x, units, activation=None, use_bias=True, kernel_initializer=None, kernel_regularizer=None, bias_initializer=init_ops.zeros_initializer(), biases_regularizer=None, sparsity_technique='baseline', name=None): """Constructs sparse_fully_connected with any desired pruning method. Args: x: Input, float32 tensor. units: Int representing size of output tensor. activation: If None, a linear activation is used. use_bias: Boolean specifying whether bias vector should be used. kernel_initializer: Initializer for the convolution weights. kernel_regularizer: Regularization method for the convolution weights. bias_initializer: Initalizer of the bias vector. biases_regularizer: Optional regularizer for the bias vector. sparsity_technique: Method used to introduce sparsity. ['baseline', 'threshold'] name: String speciying name scope of layer in network. Returns: Output: activations. Raises: ValueError: If the rank of the input is not greater than 2. 
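
  Example (illustrative sketch; `pooled_features` is a placeholder tensor):

    logits = sparse_fully_connected(
        x=pooled_features, units=10, sparsity_technique='threshold',
        name='final_dense')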
""" layer_variable_getter = variable_getter({ 'bias': 'biases', 'kernel': 'weights', }) with tf.variable_scope( name, 'Dense', [x], custom_getter=layer_variable_getter) as sc: input_shape = x.get_shape().as_list() if input_shape[-1] is None: raise ValueError('The last dimension of the inputs to `Dense` ' 'should be defined. Found `None`.') pruning_methods = ['threshold'] if sparsity_technique in pruning_methods: return layers.masked_fully_connected( inputs=x, num_outputs=units, activation_fn=activation, weights_initializer=kernel_initializer, weights_regularizer=kernel_regularizer, biases_initializer=bias_initializer, biases_regularizer=biases_regularizer, outputs_collections=None, trainable=True, scope=sc) elif sparsity_technique == 'baseline': return tf.layers.dense( inputs=x, units=units, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, bias_initializer=bias_initializer, bias_regularizer=biases_regularizer, name=name) else: raise ValueError( 'Unsupported sparsity technique {}'.format(sparsity_technique)) ================================================ FILE: rigl/imagenet_resnet/resnet_model.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ResNet modified to including pruning layers if specified. Residual networks (ResNets) were proposed in: [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun Deep Residual Learning for Image Recognition. arXiv:1512.03385 """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import math from absl import flags from rigl.imagenet_resnet.pruning_layers import sparse_conv2d from rigl.imagenet_resnet.pruning_layers import sparse_fully_connected import tensorflow.compat.v1 as tf from tensorflow.contrib import layers as contrib_layers from tensorflow.python.ops import init_ops FLAGS = flags.FLAGS BATCH_NORM_DECAY = 0.9 BATCH_NORM_EPSILON = 1e-5 def batch_norm_relu(inputs, is_training, relu=True, init_zero=False, data_format='channels_first'): """Performs a batch normalization followed by a ReLU. Args: inputs: `Tensor` of shape `[batch, channels, ...]`. is_training: `bool` for whether the model is training. relu: `bool` if False, omits the ReLU operation. init_zero: `bool` if True, initializes scale parameter of batch normalization with 0 instead of 1 (default). data_format: `str` either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. Returns: A normalized `Tensor` with the same `data_format`. 
""" if init_zero: gamma_initializer = tf.zeros_initializer() else: gamma_initializer = tf.ones_initializer() if data_format == 'channels_first': axis = 1 else: axis = 3 inputs = tf.layers.batch_normalization( inputs=inputs, axis=axis, momentum=BATCH_NORM_DECAY, epsilon=BATCH_NORM_EPSILON, center=True, scale=True, training=is_training, fused=True, gamma_initializer=gamma_initializer) if relu: inputs = tf.nn.relu(inputs) return inputs def fixed_padding(inputs, kernel_size, data_format='channels_first'): """Pads the input along the spatial dimensions independently of input size. Args: inputs: `Tensor` of size `[batch, channels, height, width]` or `[batch, height, width, channels]` depending on `data_format`. kernel_size: `int` kernel size to be used for `conv2d` or max_pool2d` operations. Should be a positive integer. data_format: `str` either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. Returns: A padded `Tensor` of the same `data_format` with size either intact (if `kernel_size == 1`) or padded (if `kernel_size > 1`). """ pad_total = kernel_size - 1 pad_beg = pad_total // 2 pad_end = pad_total - pad_beg if data_format == 'channels_first': padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]]) else: padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) return padded_inputs class RandomSparseInitializer(init_ops.Initializer): """An initializer that sets a fraction of values to zero.""" def __init__(self, sparsity, seed=None, dtype=tf.float32): if sparsity < 0. or sparsity > 1.: raise ValueError('sparsity must be in the range [0., 1.].') self.kernel_initializer = tf.variance_scaling_initializer(seed=seed, dtype=dtype) self.seed = seed self.dtype = dtype self.sparsity = float(sparsity) def __call__(self, *args, **kwargs): init_tensor = self.kernel_initializer(*args, **kwargs) rand_vals = tf.random_uniform(tf.shape(init_tensor)) threshold = tf.constant(self.sparsity) masked_tensor = tf.where(rand_vals < threshold, tf.zeros_like(rand_vals), init_tensor) return masked_tensor def get_config(self): return { 'seed': self.seed, 'dtype': self.dtype.name, 'sparsity': self.sparsity } class SparseConvVarianceScalingInitializer(init_ops.Initializer): """Define an initializer for an already sparse layer.""" def __init__(self, sparsity, seed=None, dtype=tf.float32): if sparsity < 0. or sparsity >= 1.: raise ValueError('sparsity must be in the range [0., 1.).') self.sparsity = sparsity self.seed = seed def __call__(self, shape, dtype=None, partition_info=None): if partition_info is not None: raise ValueError('partition_info not supported.') if dtype is None: dtype = self.dtype # Calculate number of non-zero weights nnz = 1. for d in shape: nnz *= d nnz *= (1. - self.sparsity) input_channels = shape[-2] n = nnz / input_channels variance = (2. / n)**.5 return tf.random_normal(shape, 0, variance, dtype, seed=self.seed) def get_config(self): return { 'seed': self.seed, 'dtype': self.dtype.name, } class SparseFCVarianceScalingInitializer(init_ops.Initializer): """Define an initializer for an already sparse layer.""" def __init__(self, sparsity, seed=None, dtype=tf.float32): if sparsity < 0. 
or sparsity >= 1.: raise ValueError('sparsity must be in the range [0., 1.).') self.sparsity = sparsity self.seed = seed def __call__(self, shape, dtype=None, partition_info=None): if partition_info is not None: raise ValueError('partition_info not supported.') if dtype is None: dtype = self.dtype if len(shape) != 2: raise ValueError('Weights must be 2-dimensional.') fan_in = shape[0] fan_out = shape[1] # Calculate number of non-zero weights nnz = 1. for d in shape: nnz *= d nnz *= (1. - self.sparsity) limit = math.sqrt(6. / (nnz / fan_out + nnz / fan_in)) return tf.random_uniform(shape, -limit, limit, dtype, seed=self.seed) def get_config(self): return { 'seed': self.seed, 'dtype': self.dtype.name, } def _pick_initializer(kernel_initializer, init_method, pruning_method, end_sparsity): """Updates the initializer selected, if necessary.""" if init_method == 'sparse': if pruning_method != 'threshold': raise ValueError( 'Unsupported combination of flags, pruning_method must be threshold' ' if init_method is `sparse`.') else: kernel_initializer = SparseFCVarianceScalingInitializer(end_sparsity) elif init_method == 'random_zeros': if pruning_method != 'baseline': raise ValueError( 'Unsupported combination of flags, pruning_method must be ' 'baseline if init_method is `random_zeros`.') else: kernel_initializer = RandomSparseInitializer(end_sparsity) return kernel_initializer def conv2d_fixed_padding(inputs, filters, kernel_size, strides, pruning_method='baseline', init_method='baseline', data_format='channels_first', end_sparsity=0., weight_decay=0., init_scale=1.0, name=None): """Strided 2-D convolution with explicit padding. The padding is consistent and is based only on `kernel_size`, not on the dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. filters: Int specifying number of filters for the first two convolutions. kernel_size: Int designating size of kernel to be used in the convolution. strides: Int specifying the stride. If stride >1, the input is downsampled. pruning_method: String that specifies the pruning method used to identify which weights to remove. init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' set random weights to zero using end_sparsoty parameter and used with 'baseline' method. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. init_scale: float, passed to the VarianceScalingInitializer. name: String that specifies name for model layer. Returns: The output activation tensor of size [batch, filters, height_out, width_out] Raises: ValueError: If the data_format provided is not a valid string. 
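
  Example (illustrative sketch; `images` is a placeholder input tensor):

    net = conv2d_fixed_padding(
        inputs=images, filters=64, kernel_size=7, strides=2,
        pruning_method='threshold', weight_decay=1e-4, name='initial_conv')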
""" if strides > 1: inputs = fixed_padding( inputs, kernel_size, data_format=data_format) padding = 'SAME' if strides == 1 else 'VALID' kernel_initializer = tf.variance_scaling_initializer(scale=init_scale) kernel_initializer = _pick_initializer(kernel_initializer, init_method, pruning_method, end_sparsity) kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) return sparse_conv2d( x=inputs, units=filters, activation=None, kernel_size=[kernel_size, kernel_size], use_bias=False, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, bias_initializer=None, biases_regularizer=None, sparsity_technique=pruning_method, normalizer_fn=None, strides=[strides, strides], padding=padding, data_format=data_format, name=name) def residual_block_(inputs, filters, is_training, strides, use_projection=False, pruning_method='baseline', init_method='baseline', data_format='channels_first', end_sparsity=0., weight_decay=0., name=''): """Standard building block for residual networks with BN after convolutions. Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. filters: Int specifying number of filters for the first two convolutions. is_training: Boolean specifying whether the model is training. strides: Int specifying the stride. If stride >1, the input is downsampled. use_projection: Boolean for whether the layer should use a projection shortcut Often, use_projection=True for the first block of a block group. pruning_method: String that specifies the pruning method used to identify which weights to remove. init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' sets random weights to zero using end_sparsoty parameter and used with 'baseline' method. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. name: String that specifies name for model layer. Returns: The output activation tensor. 
""" shortcut = inputs if use_projection: # Projection shortcut in first layer to match filters and strides end_point = 'residual_projection_%s' % name shortcut = conv2d_fixed_padding( inputs=inputs, filters=filters, kernel_size=1, strides=strides, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) shortcut = batch_norm_relu( shortcut, is_training, relu=False, data_format=data_format) end_point = 'residual_1_%s' % name inputs = conv2d_fixed_padding( inputs=inputs, filters=filters, kernel_size=3, strides=strides, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) inputs = batch_norm_relu( inputs, is_training, data_format=data_format) end_point = 'residual_2_%s' % name inputs = conv2d_fixed_padding( inputs=inputs, filters=filters, kernel_size=3, strides=1, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) inputs = batch_norm_relu( inputs, is_training, relu=False, init_zero=True, data_format=data_format) return tf.nn.relu(inputs + shortcut) def bottleneck_block_(inputs, filters, is_training, strides, use_projection=False, pruning_method='baseline', init_method='baseline', data_format='channels_first', end_sparsity=0., weight_decay=0., name=None): """Bottleneck block variant for residual networks with BN after convolutions. Args: inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height, width]. filters: Int specifying number of filters for the first two convolutions. is_training: Boolean specifying whether the model is training. strides: Int specifying the stride. If stride >1, the input is downsampled. use_projection: Boolean for whether the layer should use a projection shortcut Often, use_projection=True for the first block of a block group. pruning_method: String that specifies the pruning method used to identify which weights to remove. init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' set random weights to zero using end_sparsoty parameter and used with 'baseline' method. data_format: String that specifies either "channels_first" for [batch, channels, height,width] or "channels_last" for [batch, height, width, channels]. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. name: String that specifies name for model layer. Returns: The output activation tensor. """ shortcut = inputs if use_projection: # Projection shortcut only in first block within a group. Bottleneck blocks # end with 4 times the number of filters. 
filters_out = 4 * filters end_point = 'bottleneck_projection_%s' % name shortcut = conv2d_fixed_padding( inputs=inputs, filters=filters_out, kernel_size=1, strides=strides, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) shortcut = batch_norm_relu( shortcut, is_training, relu=False, data_format=data_format) end_point = 'bottleneck_1_%s' % name inputs = conv2d_fixed_padding( inputs=inputs, filters=filters, kernel_size=1, strides=1, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) inputs = batch_norm_relu( inputs, is_training, data_format=data_format) end_point = 'bottleneck_2_%s' % name inputs = conv2d_fixed_padding( inputs=inputs, filters=filters, kernel_size=3, strides=strides, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) inputs = batch_norm_relu( inputs, is_training, data_format=data_format) end_point = 'bottleneck_3_%s' % name inputs = conv2d_fixed_padding( inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) inputs = batch_norm_relu( inputs, is_training, relu=False, init_zero=True, data_format=data_format) return tf.nn.relu(inputs + shortcut) def block_group(inputs, filters, block_fn, blocks, strides, is_training, name, pruning_method='baseline', init_method='baseline', data_format='channels_first', end_sparsity=0., weight_decay=0.): """Creates one group of blocks for the ResNet model. Args: inputs: `Tensor` of size `[batch, channels, height, width]`. filters: `int` number of filters for the first convolution of the layer. block_fn: `function` for the block to use within the model blocks: `int` number of blocks contained in the layer. strides: `int` stride to use for the first convolution of the layer. If greater than 1, this layer will downsample the input. is_training: `bool` for whether the model is training. name: String specifying the Tensor output of the block layer. pruning_method: String that specifies the pruning method used to identify which weights to remove. init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' set random weights to zero using end_sparsoty parameter and used with 'baseline' method. data_format: `str` either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. Returns: The output `Tensor` of the block layer. """ with tf.name_scope(name): end_point = 'block_group_projection_%s' % name # Only the first block per block_group uses projection shortcut and strides. 
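    # Illustrative example: for ResNet-50's block_group2 (blocks=4, strides=2)
    # this first call builds one projection block with stride 2, and the loop
    # below adds three identity blocks with stride 1.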
inputs = block_fn( inputs, filters, is_training, strides, use_projection=True, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) for n in range(1, blocks): with tf.name_scope('block_group_%d' % n): end_point = '%s_%d_1' % (name, n) inputs = block_fn( inputs, filters, is_training, 1, pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name=end_point) return tf.identity(inputs, name) def resnet_v1_generator(block_fn, num_blocks, num_classes, pruning_method='baseline', init_method='baseline', width=1., prune_first_layer=True, prune_last_layer=True, data_format='channels_first', end_sparsity=0., weight_decay=0., name=None): """Generator for ResNet v1 models. Args: block_fn: String that defines whether to use a `residual_block` or `bottleneck_block`. num_blocks: list of Ints that denotes number of blocks to include in each block group. Each group consists of blocks that take inputs of the same resolution. num_classes: Int number of possible classes for image classification. pruning_method: String that specifies the pruning method used to identify which weights to remove. init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' set random weights to zero using end_sparsoty parameter and used with 'baseline' method. width: Float that scales the number of filters in each layer. prune_first_layer: Whether or not to prune the first layer. prune_last_layer: Whether or not to prune the last layer. data_format: String either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. name: String that specifies name for model layer. Returns: Model `function` that takes in `inputs` and `is_training` and returns the output `Tensor` of the ResNet model. 
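
  Example (illustrative sketch; `images` is a placeholder input tensor):

    network = resnet_v1_generator(residual_block_, [2, 2, 2, 2],
                                  num_classes=1000)
    logits = network(images, is_training=True)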
""" def model(inputs, is_training): """Creation of the model graph.""" with tf.variable_scope(name, 'resnet_model'): inputs = conv2d_fixed_padding( inputs=inputs, filters=int(64 * width), kernel_size=7, strides=2, pruning_method=pruning_method if prune_first_layer else 'baseline', init_method=init_method if prune_first_layer else 'baseline', data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay, name='initial_conv') inputs = tf.identity(inputs, 'initial_conv') inputs = batch_norm_relu( inputs, is_training, data_format=data_format) inputs = tf.layers.max_pooling2d( inputs=inputs, pool_size=3, strides=2, padding='SAME', data_format=data_format, name='initial_max_pool') inputs = tf.identity(inputs, 'initial_max_pool') inputs = block_group( inputs=inputs, filters=int(64 * width), block_fn=block_fn, blocks=num_blocks[0], strides=1, is_training=is_training, name='block_group1', pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay) inputs = block_group( inputs=inputs, filters=int(128 * width), block_fn=block_fn, blocks=num_blocks[1], strides=2, is_training=is_training, name='block_group2', pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay) inputs = block_group( inputs=inputs, filters=int(256 * width), block_fn=block_fn, blocks=num_blocks[2], strides=2, is_training=is_training, name='block_group3', pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay) inputs = block_group( inputs=inputs, filters=int(512 * width), block_fn=block_fn, blocks=num_blocks[3], strides=2, is_training=is_training, name='block_group4', pruning_method=pruning_method, init_method=init_method, data_format=data_format, end_sparsity=end_sparsity, weight_decay=weight_decay) pool_size = (inputs.shape[1], inputs.shape[2]) inputs = tf.layers.average_pooling2d( inputs=inputs, pool_size=pool_size, strides=1, padding='VALID', data_format=data_format, name='final_avg_pool') inputs = tf.identity(inputs, 'final_avg_pool') multiplier = 4 if block_fn is bottleneck_block_ else 1 fc_units = multiplier * int(512 * width) inputs = tf.reshape(inputs, [-1, fc_units]) kernel_initializer = tf.random_normal_initializer(stddev=.01) # If init_method==sparse and not pruning, skip. if init_method != 'sparse' or prune_last_layer: kernel_initializer = _pick_initializer(kernel_initializer, init_method, pruning_method, end_sparsity) kernel_regularizer = contrib_layers.l2_regularizer(weight_decay) inputs = sparse_fully_connected( x=inputs, units=num_classes, sparsity_technique=pruning_method if prune_last_layer else 'baseline', kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, name='final_dense') inputs = tf.identity(inputs, 'final_dense') return inputs model.default_image_size = 224 return model def resnet_v1_(resnet_depth, num_classes, pruning_method='baseline', init_method='baseline', width=1., prune_first_layer=True, prune_last_layer=True, data_format='channels_first', end_sparsity=0., weight_decay=0., name=None): """Returns the ResNet model for a given size and number of output classes. Args: resnet_depth: Int number of blocks in the architecture. num_classes: Int number of possible classes for image classification. pruning_method: String that specifies the pruning method used to identify which weights to remove. 
init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' set random weights to zero using end_sparsoty parameter and used with 'baseline' method. width: Float multiplier of the number of filters in each layer. prune_first_layer: Whether or not to prune the first layer. prune_last_layer: Whether or not to prune the last layer. data_format: String specifying either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. name: String that specifies the prefix for the scope. Raises: ValueError: If the resnet_depth int is not in the model_params dictionary. """ model_params = { 18: { 'block': residual_block_, 'layers': [2, 2, 2, 2] }, 34: { 'block': residual_block_, 'layers': [3, 4, 6, 3] }, 50: { 'block': bottleneck_block_, 'layers': [3, 4, 6, 3] }, 101: { 'block': bottleneck_block_, 'layers': [3, 4, 23, 3] }, 152: { 'block': bottleneck_block_, 'layers': [3, 8, 36, 3] }, 200: { 'block': bottleneck_block_, 'layers': [3, 24, 36, 3] } } if resnet_depth not in model_params: raise ValueError('Not a valid resnet_depth:', resnet_depth) params = model_params[resnet_depth] return resnet_v1_generator( params['block'], params['layers'], num_classes, pruning_method, init_method, width, prune_first_layer, prune_last_layer, data_format, end_sparsity, weight_decay, name) ================================================ FILE: rigl/imagenet_resnet/train_test.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Tests for the data_helper input pipeline and the training process. 
""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl import flags import absl.testing.parameterized as parameterized from rigl.imagenet_resnet.imagenet_train_eval import resnet_model_fn_w_pruning from rigl.imagenet_resnet.imagenet_train_eval import set_lr_schedule import tensorflow.compat.v1 as tf # tf from official.resnet import imagenet_input from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator FLAGS = flags.FLAGS class DataInputTest(tf.test.TestCase, parameterized.TestCase): def _retrieve_data(self, is_training, data_dir): dataset = imagenet_input.ImageNetInput( is_training=is_training, data_dir=data_dir, transpose_input=False, num_parallel_calls=8, use_bfloat16=False) return dataset @parameterized.parameters('snip', 'set', 'rigl', 'scratch') def testTrainingPipeline(self, training_method): output_directory = '/tmp/' g = tf.Graph() with g.as_default(): dataset = self._retrieve_data(is_training=False, data_dir=False) FLAGS.transpose_input = False FLAGS.use_tpu = False FLAGS.mode = 'train' FLAGS.mask_init_method = 'random' FLAGS.precision = 'float32' FLAGS.train_steps = 1 FLAGS.train_batch_size = 1 FLAGS.eval_batch_size = 1 FLAGS.steps_per_eval = 1 FLAGS.model_architecture = 'resnet' params = {} params['output_dir'] = output_directory params['training_method'] = training_method params['use_tpu'] = False set_lr_schedule() run_config = tpu_config.RunConfig( master=None, model_dir=None, save_checkpoints_steps=1, tpu_config=tpu_config.TPUConfig(iterations_per_loop=1, num_shards=1)) classifier = tpu_estimator.TPUEstimator( use_tpu=False, model_fn=resnet_model_fn_w_pruning, params=params, config=run_config, train_batch_size=1, eval_batch_size=1) classifier.train(input_fn=dataset.input_fn, max_steps=1) if __name__ == '__main__': tf.test.main() ================================================ FILE: rigl/imagenet_resnet/utils.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Helped functions to concatenate subset of noisy images to batch.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow.compat.v1 as tf from tensorflow.compat.v2 import summary IMG_SUMMARY_PREFIX = '_img_' def format_tensors(*dicts): """Format metrics to be callable as tf.summary scalars on tpu's. Args: *dicts: A set of metric dictionaries, containing metric name + value tensor. Returns: A single formatted dictionary that holds all tensors. Raises: ValueError: if any tensor is not a scalar. """ merged_summaries = {} for d in dicts: for metric_name, value in d.items(): shape = value.shape.as_list() if metric_name.startswith(IMG_SUMMARY_PREFIX): # If image, shape it into 2d. 
merged_summaries[metric_name] = tf.reshape(value, (1, -1, value.shape[-1], 1)) elif not shape: merged_summaries[metric_name] = tf.expand_dims(value, axis=0) elif shape == [1]: merged_summaries[metric_name] = value else: raise ValueError( 'Metric {} has value {} that is not reconciliable'.format( metric_name, value)) return merged_summaries def host_call_fn(model_dir, **kwargs): """host_call function used for creating training summaries when using TPU. Args: model_dir: String indicating the output_dir to save summaries in. **kwargs: Set of metric names and tensor values for all desired summaries. Returns: Summary op to be passed to the host_call arg of the estimator function. """ gs = kwargs.pop('global_step')[0] with summary.create_file_writer(model_dir).as_default(): # Always record summaries. with summary.record_if(True): for name, tensor in kwargs.items(): if name.startswith(IMG_SUMMARY_PREFIX): summary.image(name.replace(IMG_SUMMARY_PREFIX, ''), tensor, max_images=1) else: summary.scalar(name, tensor[0], step=gs) # Following function is under tf:1x, so we use it. return tf.summary.all_v2_summary_ops() def mask_summaries(masks, with_img=False): metrics = {} for mask in masks: metrics['pruning/{}/sparsity'.format( mask.op.name)] = tf.nn.zero_fraction(mask) if with_img: metrics[IMG_SUMMARY_PREFIX + 'mask/' + mask.op.name] = mask return metrics def initialize_parameters_from_ckpt(ckpt_path, model_dir, param_suffixes): """Load parameters from an existing checkpoint. Args: ckpt_path: str, loads the mask variables from this checkpoint. model_dir: str, if checkpoint exists in this folder no-op. param_suffixes: list or str, suffix of parameters to be load from checkpoint. """ already_has_ckpt = model_dir and tf.train.latest_checkpoint( model_dir) is not None if already_has_ckpt: tf.logging.info( 'Training already started on this model, not loading masks from' 'previously trained model') return reader = tf.train.NewCheckpointReader(ckpt_path) param_names = reader.get_variable_to_shape_map().keys() param_names = [x for x in param_names if x.endswith(param_suffixes)] variable_map = {} for var in tf.global_variables(): var_name = var.name.split(':')[0] if var_name in param_names: tf.logging.info('Loading parameter variable from checkpoint: %s', var_name) variable_map[var_name] = var elif var_name.endswith(param_suffixes): tf.logging.info( 'Cannot find parameter variable in checkpoint, skipping: %s', var_name) tf.train.init_from_checkpoint(ckpt_path, variable_map) ================================================ FILE: rigl/imagenet_resnet/vgg.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Contains model definitions for versions of the Oxford VGG network. These model definitions were introduced in the following technical report: Very Deep Convolutional Networks For Large-Scale Image Recognition Karen Simonyan and Andrew Zisserman arXiv technical report, 2015 PDF: http://arxiv.org/pdf/1409.1556.pdf ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf CC-BY-4.0 More information can be obtained from the VGG website: www.robots.ox.ac.uk/~vgg/research/very_deep/ Usage: with arg_scope(vgg.vgg_arg_scope()): outputs, end_points = vgg.vgg_net(inputs,scope='vgg_19') """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools from rigl.imagenet_resnet import resnet_model import tensorflow.compat.v1 as tf from tensorflow.contrib import layers network_cfg = { 'vgg_a': [1, 1, 2, 2, 2], 'vgg_16': [2, 2, 3, 3, 3], 'vgg_19': [2, 2, 4, 4, 4], } def vgg_net(inputs, num_classes=1000, spatial_squeeze=True, name='vgg_a', global_pool=True, pruning_method='baseline', init_method='baseline', data_format='channels_last', width=1., prune_last_layer=True, end_sparsity=0., weight_decay=0.): """Oxford Net VGG. Note: All the fully_connected layers have been transformed to conv2d layers. To use in classification mode, resize input to 224x224. Args: inputs: a tensor of size [batch_size, height, width, channels]. num_classes: number of predicted classes. If 0 or None, the logits layer is omitted and the input features to the logits layer are returned instead. spatial_squeeze: whether or not should squeeze the spatial dimensions of the outputs. Useful to remove unnecessary dimensions for classification. name: Optional scope for the variables. global_pool: Optional boolean flag. If True, the input to the classification layer is avgpooled to size 1x1, for any input size. (This is not part of the original VGG architecture.) pruning_method: String that specifies the pruning method used to identify which weights to remove. init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' set random weights to zero using end_sparsoty parameter and used with 'baseline' method. data_format: String specifying either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. width: Float multiplier of the number of filters in each layer. prune_last_layer: Whether or not to prune the last layer. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. Returns: net: the output of the logits layer (if num_classes is a non-zero integer), or the non-dropped-out input to the logits layer (if num_classes is 0 or None). end_points: a dict of tensors with intermediate activations. 
For backwards compatibility, some Tensors appear multiple times in the dict. """ net_cfg = network_cfg[name] sparse_conv2d = functools.partial( resnet_model.conv2d_fixed_padding, pruning_method=pruning_method, init_method=init_method, data_format=data_format, init_scale=2.0, # Heinit end_sparsity=end_sparsity, weight_decay=weight_decay) def new_sparse_conv2d(*args, **kwargs): kwargs['name'] = kwargs['scope'] del kwargs['scope'] activation_fn = 'relu' if 'activation_fn' in kwargs: activation_fn = kwargs['activation_fn'] del kwargs['activation_fn'] out = sparse_conv2d(*args, **kwargs) if activation_fn == 'relu': out = tf.nn.relu(out) return out with tf.variable_scope(name, name, values=[inputs]): net = layers.repeat( inputs, net_cfg[0], new_sparse_conv2d, int(64 * width), 3, strides=1, scope='conv1') net = layers.max_pool2d(net, [2, 2], scope='pool1') net = layers.repeat( net, net_cfg[1], new_sparse_conv2d, int(128 * width), 3, strides=1, scope='conv2') net = layers.max_pool2d(net, [2, 2], scope='pool2') net = layers.repeat( net, net_cfg[2], new_sparse_conv2d, int(256 * width), 3, strides=1, scope='conv3') net = layers.max_pool2d(net, [2, 2], scope='pool3') net = layers.repeat( net, net_cfg[3], new_sparse_conv2d, int(512 * width), 3, strides=1, scope='conv4') net = layers.max_pool2d(net, [2, 2], scope='pool4') net = layers.repeat( net, net_cfg[4], new_sparse_conv2d, int(512 * width), 3, strides=1, scope='conv5') # # Use conv2d instead of fully_connected layers. # net = new_sparse_conv2d(net, 512, [7, 7], strides=1, scope='fc6') # # net = layers.dropout(net, dropout_keep_prob, is_training=is_training, # # scope='dropout6') # net = new_sparse_conv2d(net, 512, [1, 1], strides=1, scope='fc7') if global_pool: net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') if num_classes: # net = layers.dropout(net, dropout_keep_prob, is_training=is_training, # scope='dropout7') if prune_last_layer: net = new_sparse_conv2d( net, num_classes, 1, activation_fn=None, strides=1, scope='fc8') else: net = layers.conv2d( net, num_classes, [1, 1], activation_fn=None, scope='fc8') if spatial_squeeze: net = tf.squeeze(net, [1, 2], name='fc8/squeezed') return net def vgg(vgg_type, num_classes, pruning_method='baseline', init_method='baseline', width=1., prune_last_layer=True, data_format='channels_last', end_sparsity=0., weight_decay=0.): """Returns the ResNet model for a given size and number of output classes. Args: vgg_type: Int number of blocks in the architecture. num_classes: Int number of possible classes for image classification. pruning_method: String that specifies the pruning method used to identify which weights to remove. init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard initialization or initialization that takes into the existing sparsity of the layer. 'sparse' only makes sense when combined with pruning_method == 'scratch'. 'random_zeros' set random weights to zero using end_sparsoty parameter and used with 'baseline' method. width: Float multiplier of the number of filters in each layer. prune_last_layer: Whether or not to prune the last layer. data_format: String specifying either "channels_first" for `[batch, channels, height, width]` or "channels_last for `[batch, height, width, channels]`. end_sparsity: Desired sparsity at the end of training. Necessary to initialize an already sparse network. weight_decay: Weight for the l2 regularization loss. Raises: ValueError: If the resnet_depth int is not in the model_params dictionary. 
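
  Example (illustrative sketch; `images` is a placeholder input tensor, and
  `vgg_type` is one of the string keys of `network_cfg`, e.g. 'vgg_16'):

    network = vgg('vgg_16', num_classes=1000, pruning_method='threshold')
    logits = network(images, is_training=True)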
""" def model_fn(inputs, is_training): del is_training return vgg_net( inputs, num_classes, name=vgg_type, pruning_method=pruning_method, init_method=init_method, data_format=data_format, width=width, prune_last_layer=prune_last_layer, end_sparsity=end_sparsity, weight_decay=weight_decay) return model_fn ================================================ FILE: rigl/mnist/mnist_train_eval.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""A configurable, multi-layer fully connected network trained on MNIST. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import time from absl import flags import numpy as np from rigl import sparse_optimizers from rigl import sparse_utils import tensorflow.compat.v1 as tf from tensorflow.contrib import layers as contrib_layers from tensorflow.contrib.model_pruning.python import pruning from tensorflow.contrib.model_pruning.python.layers import layers from tensorflow.examples.tutorials.mnist import input_data flags.DEFINE_string('mnist', '/tmp/data', 'Location of the MNIST ' 'dataset.') ## optimizer hyperparameters flags.DEFINE_integer('batch_size', 100, 'The number of samples in each batch') flags.DEFINE_float('learning_rate', .2, 'Initial learning rate.') flags.DEFINE_float('momentum', .9, 'Momentum.') flags.DEFINE_boolean('use_nesterov', True, 'Use nesterov momentum.') flags.DEFINE_integer('num_epochs', 200, 'Number of epochs to run.') flags.DEFINE_integer('lr_drop_epoch', 75, 'The epoch to start dropping lr.') flags.DEFINE_string('optimizer', 'momentum', 'Optimizer to use. sgd, momentum or adam') flags.DEFINE_float('l2_scale', 1e-4, 'l2 loss scale') flags.DEFINE_string('network_type', 'fc', 'Type of the network. See below for available options.') flags.DEFINE_enum( 'training_method', 'baseline', ('scratch', 'set', 'baseline', 'momentum', 'rigl', 'static', 'snip', 'prune'), 'Method used for training sparse network. `scratch` means initial mask is ' 'kept during training. `set` is for sparse evalutionary training and ' '`baseline` is for dense baseline.') flags.DEFINE_float('drop_fraction', 0.3, 'When changing mask dynamically, this fraction decides how ' 'much of the ') flags.DEFINE_string('drop_fraction_anneal', 'cosine', 'If not empty the drop fraction is annealed during sparse' ' training. One of the following: `constant`, `cosine` or ' '`exponential_(\\d*\\.?\\d*)$`. For example: ' '`exponential_3`, `exponential_.3`, `exponential_0.3`. ' 'The number after `exponential` defines the exponent.') flags.DEFINE_string('grow_init', 'zeros', 'Passed to the SparseInitializer, one of: zeros, ' 'initial_value, random_normal, random_uniform.') flags.DEFINE_float('s_momentum', 0.9, 'Momentum values for exponential moving average of ' 'gradients. Used when training_method="momentum".') flags.DEFINE_string( 'input_mask_path', '', 'If given, uses the first mask of the checkpoint to mask ' 'the input. 
If all the outgoing connections are masked ' 'in the mask, we mask that dimension of the input.') flags.DEFINE_float('sparsity_scale', 0.9, 'Relative sparsity of second layer.') flags.DEFINE_float('rigl_acc_scale', 0., 'Used to scale initial accumulated gradients for new ' 'connections.') flags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.') flags.DEFINE_integer('maskupdate_end_step', 50000, 'Step to end mask updates.') flags.DEFINE_integer('maskupdate_frequency', 100, 'Step interval between mask updates.') flags.DEFINE_integer('mask_record_frequency', 0, 'Step interval between mask logging.') flags.DEFINE_string( 'mask_init_method', default='random', help='If not empty string and mask is not loaded from a checkpoint, ' 'indicates the method used for mask initialization. One of the following: ' '`random`, `erdos_renyi`.') flags.DEFINE_integer('prune_begin_step', 2000, 'step to begin pruning') flags.DEFINE_integer('prune_end_step', 30000, 'step to end pruning') flags.DEFINE_float('end_sparsity', .98, 'desired sparsity of final model.') flags.DEFINE_integer('pruning_frequency', 500, 'how often to prune.') flags.DEFINE_float('threshold_decay', 0, 'threshold_decay for pruning.') flags.DEFINE_string('save_path', '', 'Where to save the model.') flags.DEFINE_boolean('save_model', True, 'Whether to save model or not.') flags.DEFINE_integer('seed', default=0, help=('Sets the random seed.')) FLAGS = flags.FLAGS # momentum = 0.9 # lr = 0.2 # batch = 100 # decay = 1e-4 def mnist_network_fc(input_batch, reuse=False, model_pruning=False): """Define a basic FC network.""" regularizer = contrib_layers.l2_regularizer(scale=FLAGS.l2_scale) if model_pruning: y = layers.masked_fully_connected( inputs=input_batch[0], num_outputs=300, activation_fn=tf.nn.relu, weights_regularizer=regularizer, reuse=reuse, scope='layer1') y1 = layers.masked_fully_connected( inputs=y, num_outputs=100, activation_fn=tf.nn.relu, weights_regularizer=regularizer, reuse=reuse, scope='layer2') logits = layers.masked_fully_connected( inputs=y1, num_outputs=10, reuse=reuse, activation_fn=None, weights_regularizer=regularizer, scope='layer3') else: y = tf.layers.dense( inputs=input_batch[0], units=300, activation=tf.nn.relu, kernel_regularizer=regularizer, reuse=reuse, name='layer1') y1 = tf.layers.dense( inputs=y, units=100, activation=tf.nn.relu, kernel_regularizer=regularizer, reuse=reuse, name='layer2') logits = tf.layers.dense(inputs=y1, units=10, reuse=reuse, kernel_regularizer=regularizer, name='layer3') cross_entropy = tf.losses.sparse_softmax_cross_entropy( labels=input_batch[1], logits=logits) cross_entropy += tf.losses.get_regularization_loss() predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32) accuracy = tf.reduce_mean( tf.cast(tf.equal(input_batch[1], predictions), tf.float32)) return cross_entropy, accuracy def get_compressed_fc(masks): """Given the masks of a sparse network returns the compact network.""" # Dead input pixels. inds = np.sum(masks[0], axis=1) != 0 masks[0] = masks[0][inds] compressed_masks = [] for i in range(len(masks)): w = masks[i] # Find neurons that doesn't have any incoming edges. do_w = np.sum(w, axis=0) != 0 if i < (len(masks) - 1): # Find neurons that doesn't have any outgoing edges. di_wnext = np.sum(masks[i+1], axis=1) != 0 # Kept neurons should have at least one incoming and one outgoing edges. do_w = np.logical_and(do_w, di_wnext) compressed_w = w[:, do_w] compressed_masks.append(compressed_w) if i < (len(masks) - 1): # Remove incoming edges from removed neurons. 
masks[i+1] = masks[i+1][do_w] sparsities = [np.sum(m == 0) / float(np.size(m)) for m in compressed_masks] sizes = [compressed_masks[0].shape[0]] for m in compressed_masks: sizes.append(m.shape[1]) return sparsities, sizes def main(unused_args): tf.set_random_seed(FLAGS.seed) tf.get_variable_scope().set_use_resource(True) np.random.seed(FLAGS.seed) # Load the MNIST data and set up an iterator. mnist_data = input_data.read_data_sets( FLAGS.mnist, one_hot=False, validation_size=0) train_images = mnist_data.train.images test_images = mnist_data.test.images if FLAGS.input_mask_path: reader = tf.train.load_checkpoint(FLAGS.input_mask_path) input_mask = reader.get_tensor('layer1/mask') indices = np.sum(input_mask, axis=1) != 0 train_images = train_images[:, indices] test_images = test_images[:, indices] dataset = tf.data.Dataset.from_tensor_slices( (train_images, mnist_data.train.labels.astype(np.int32))) num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0]) batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size) iterator = batched_dataset.make_one_shot_iterator() test_dataset = tf.data.Dataset.from_tensor_slices( (test_images, mnist_data.test.labels.astype(np.int32))) num_test_images = mnist_data.test.images.shape[0] test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images) test_iterator = test_dataset.make_one_shot_iterator() # Set up loss function. use_model_pruning = FLAGS.training_method != 'baseline' if FLAGS.network_type == 'fc': cross_entropy_train, _ = mnist_network_fc( iterator.get_next(), model_pruning=use_model_pruning) cross_entropy_test, accuracy_test = mnist_network_fc( test_iterator.get_next(), reuse=True, model_pruning=use_model_pruning) else: raise RuntimeError(FLAGS.network_type + ' is an unknown network type.') # Remove the duplicated entries. The current implementation adds the # variables twice to the collection; improve this hack. # TODO: test the following with the convnet or any other network. if use_model_pruning: for k in ('masks', 'masked_weights', 'thresholds', 'kernel'): # del tf.get_collection_ref(k)[2] # del tf.get_collection_ref(k)[2] collection = tf.get_collection_ref(k) del collection[len(collection)//2:] print(tf.get_collection_ref(k)) # Set up optimizer and update ops. global_step = tf.train.get_or_create_global_step() batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size if FLAGS.optimizer != 'adam': if not use_model_pruning: boundaries = [int(round(s * batch_per_epoch)) for s in [60, 70, 80]] else: boundaries = [int(round(s * batch_per_epoch)) for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]] learning_rate = tf.train.piecewise_constant( global_step, boundaries, values=[FLAGS.learning_rate / (3. ** i) for i in range(len(boundaries) + 1)]) else: learning_rate = FLAGS.learning_rate if FLAGS.optimizer == 'adam': opt = tf.train.AdamOptimizer(FLAGS.learning_rate) elif FLAGS.optimizer == 'momentum': opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum, use_nesterov=FLAGS.use_nesterov) elif FLAGS.optimizer == 'sgd': opt = tf.train.GradientDescentOptimizer(learning_rate) else: raise RuntimeError(FLAGS.optimizer + ' is an unknown optimizer type.') custom_sparsities = { 'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale, 'layer3': FLAGS.end_sparsity * 0 } if FLAGS.training_method == 'set': # We override the train op to also update the mask.
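# SET (Sparse Evolutionary Training) drops the lowest-magnitude connections # and grows new connections at random positions.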
opt = sparse_optimizers.SparseSETOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif FLAGS.training_method == 'static': # We override the train op to also update the mask. opt = sparse_optimizers.SparseStaticOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif FLAGS.training_method == 'momentum': # We override the train op to also update the mask. opt = sparse_optimizers.SparseMomentumOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, grow_init=FLAGS.grow_init, drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False) elif FLAGS.training_method == 'rigl': # We override the train op to also update the mask. opt = sparse_optimizers.SparseRigLOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal, initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False) elif FLAGS.training_method == 'snip': opt = sparse_optimizers.SparseSnipOptimizer( opt, mask_init_method=FLAGS.mask_init_method, default_sparsity=FLAGS.end_sparsity, custom_sparsity_map=custom_sparsities, use_tpu=False) elif FLAGS.training_method in ('scratch', 'baseline', 'prune'): pass else: raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method) train_op = opt.minimize(cross_entropy_train, global_step=global_step) if FLAGS.training_method == 'prune': hparams_string = ('begin_pruning_step={0},sparsity_function_begin_step={0},' 'end_pruning_step={1},sparsity_function_end_step={1},' 'target_sparsity={2},pruning_frequency={3},' 'threshold_decay={4}'.format( FLAGS.prune_begin_step, FLAGS.prune_end_step, FLAGS.end_sparsity, FLAGS.pruning_frequency, FLAGS.threshold_decay)) pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string) pruning_hparams.set_hparam('weight_sparsity_map', ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()]) print(pruning_hparams) pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) with tf.control_dependencies([train_op]): train_op = pruning_obj.conditional_mask_update_op() weight_sparsity_levels = pruning.get_weight_sparsity() global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks()) tf.summary.scalar('test_accuracy', accuracy_test) tf.summary.scalar('global_sparsity', global_sparsity) for k, v in zip(pruning.get_masks(), weight_sparsity_levels): tf.summary.scalar('sparsity/%s' % k.name, v) if FLAGS.training_method in ('prune', 'snip', 'baseline'): mask_init_op = tf.no_op() tf.logging.info('No mask is set, starting dense.') else: all_masks = pruning.get_masks() mask_init_op = sparse_utils.get_mask_init_fn( all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, custom_sparsities) if FLAGS.save_model: saver = tf.train.Saver() init_op = tf.global_variables_initializer() hyper_params_string = '_'.join([FLAGS.network_type, str(FLAGS.batch_size), str(FLAGS.learning_rate), str(FLAGS.momentum), FLAGS.optimizer, str(FLAGS.l2_scale), FLAGS.training_method, 
str(FLAGS.prune_begin_step), str(FLAGS.prune_end_step), str(FLAGS.end_sparsity), str(FLAGS.pruning_frequency), str(FLAGS.seed)]) tf.io.gfile.makedirs(FLAGS.save_path) filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt') merged_summary_op = tf.summary.merge_all() # Run session. if not use_model_pruning: with tf.Session() as sess: summary_writer = tf.summary.FileWriter(FLAGS.save_path, graph=tf.get_default_graph()) print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy') sess.run([init_op]) tic = time.time() with tf.io.gfile.GFile(filename, 'w') as outputfile: for i in range(FLAGS.num_epochs * num_batches): sess.run([train_op]) if (i % num_batches) == (-1 % num_batches): epoch_time = time.time() - tic loss, accuracy, summary = sess.run([cross_entropy_test, accuracy_test, merged_summary_op]) # Write logs at every test iteration. summary_writer.add_summary(summary, i) log_str = '%d, %.4f, %.4f, %.4f' % ( i // num_batches, epoch_time, loss, accuracy) print(log_str) print(log_str, file=outputfile) tic = time.time() if FLAGS.save_model: saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt')) else: with tf.Session() as sess: summary_writer = tf.summary.FileWriter(FLAGS.save_path, graph=tf.get_default_graph()) log_str = ','.join([ 'Epoch', 'Iteration', 'Test loss', 'Test accuracy', 'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1' ]) sess.run(init_op) sess.run(mask_init_op) tic = time.time() mask_records = {} with tf.io.gfile.GFile(filename, 'w') as outputfile: print(log_str) print(log_str, file=outputfile) for i in range(FLAGS.num_epochs * num_batches): if (FLAGS.mask_record_frequency > 0 and i % FLAGS.mask_record_frequency == 0): mask_vals = sess.run(pruning.get_masks()) # Cast into bool to save space. mask_records[i] = [a.astype(bool) for a in mask_vals] sess.run([train_op]) weight_sparsity, global_sparsity_val = sess.run( [weight_sparsity_levels, global_sparsity]) if (i % num_batches) == (-1 % num_batches): epoch_time = time.time() - tic loss, accuracy, summary = sess.run([cross_entropy_test, accuracy_test, merged_summary_op]) # Write logs at every test iteration. summary_writer.add_summary(summary, i) log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % ( i // num_batches, i, loss, accuracy, global_sparsity_val, weight_sparsity[0], weight_sparsity[1]) print(log_str) print(log_str, file=outputfile) mask_vals = sess.run(pruning.get_masks()) if FLAGS.network_type == 'fc': sparsities, sizes = get_compressed_fc(mask_vals) print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities, sizes)) print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities, sizes), file=outputfile) tic = time.time() if FLAGS.save_model: saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt')) if mask_records: np.save(os.path.join(FLAGS.save_path, 'mask_records'), mask_records) if __name__ == '__main__': tf.app.run() ================================================ FILE: rigl/mnist/visualize_mask_records.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. r"""Visualizes the dumped masks using matplotlib. We count the number of outgoing edges from the input dimensions. For the first layer, input dimensions correspond to the input pixels, so we can visualize them nicely. You can control which layer is visualized by changing `layer_id` and `new_shape`. Default is the first layer, where we visualize the number of outgoing connections from individual pixels. python visualize_mask_records.py --records_path=/tmp/mnist/mask_records.npy To save the results as a gif: python visualize_mask_records.py --records_path=/path/to/mask_records.npy \ --save_path=/path/to/mask.gif Modified from: https://eli.thegreenplace.net/2016/drawing-animated-gifs-with-matplotlib/ """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl import flags from matplotlib.animation import FuncAnimation import matplotlib.pyplot as plt import numpy as np import tensorflow.compat.v1 as tf flags.DEFINE_string('records_path', '/tmp/mnist/mask_records.npy', 'Path to load the mask records from.') flags.DEFINE_string('save_path', '', 'Path to save the animation.') flags.DEFINE_list('new_shape', '28,28', 'Shape used for reshaping the units.') flags.DEFINE_integer('interval', 100, 'Milliseconds between plot updates.') flags.DEFINE_integer('layer_id', 0, 'Index of the layer whose statistics are ' 'plotted during training.') flags.DEFINE_integer('skip_mask', 10, 'Number of checkpoints to skip for ' 'each frame.') flags.DEFINE_integer( 'slow_until', 50, 'Number of masks to show with slower ' 'speed. After this number of frames, we start skipping ' 'frames to make the video shorter.') FLAGS = flags.FLAGS def main(unused_args): fig, ax = plt.subplots() fig.set_tight_layout(True) # Query the figure's on-screen size and DPI. Note that when saving the figure # to a file, we need to provide a DPI for that separately. print('fig size: {0} DPI, size in inches {1}'.format(fig.get_dpi(), fig.get_size_inches())) # Load the mask records and create the initial image plot. mask_records = np.load(FLAGS.records_path, allow_pickle=True).item() sorted_keys = sorted(mask_records.keys()) new_shape = [int(a) for a in FLAGS.new_shape] reshape_fn = lambda mask: np.reshape(np.sum(mask, axis=1), new_shape) c_mask = mask_records[sorted_keys[0]][FLAGS.layer_id] im = plt.imshow(reshape_fn(c_mask), interpolation='none', vmin=0, vmax=30) fig.colorbar(im, ax=ax) def update(i): """Updates the plot.""" save_iter = sorted_keys[i] label = 'timestep {0}'.format(save_iter) print(label) # Update the image and the axes (with a new xlabel). Return a list of # "artists" that have to be redrawn for this frame. c_data = reshape_fn(mask_records[save_iter][FLAGS.layer_id]) im.set_data(c_data) ax.set_xlabel(label) return [im, ax] # FuncAnimation will call the 'update' function for each frame in `frames`, # with FLAGS.interval milliseconds between frames. iteration = FLAGS.slow_until frames = ( list(np.arange(0, iteration, 1)) + list(np.arange(iteration, len(sorted_keys), FLAGS.skip_mask))) anim = FuncAnimation(fig, update, frames=frames, interval=FLAGS.interval) if FLAGS.save_path: anim.save(FLAGS.save_path, dpi=80, writer='imagemagick') else: # plt.show() will just loop the animation forever.
plt.show() if __name__ == '__main__': tf.app.run(main) ================================================ FILE: rigl/requirements.txt ================================================ absl-py>=0.6.0 gin-config numpy>=1.15.4 six>=1.12.0 tensorflow>=1.12.0,<2.0 # change to 'tensorflow-gpu' for gpu support tensorflow-datasets==2.1 tensorflow-model-optimization ================================================ FILE: rigl/rigl_tf2/README.md ================================================ # Gradient Flow in Sparse Neural Networks and How Lottery Tickets Win **Paper**: [https://arxiv.org/abs/2010.03533](https://arxiv.org/abs/2010.03533) This code includes a TF-2 implementation of RigL and some other popular sparse training methods, along with pruning, scratch and lottery ticket experiments, in a unified codebase. Run pruning experiments: ``` python train.py --gin_config=configs/prune.gin ``` Run lottery experiments: ``` python train.py --logdir=/tmp/sparse_spectrum/lottery --seed=8 \ --gin_config=configs/lottery.gin ``` Run scratch training: ``` python train.py --logdir=/tmp/sparse_spectrum/scratch --seed=8 \ --gin_config=configs/scratch.gin ``` To assign different gin flags, use `--gin_bindings`, e.g. ``` --gin_bindings='network.weight_init_method="unit_scaled"' --gin_bindings='unit_scaled_init.init_method="faninout_uniform"' ``` Calculate eigenvalues of the Hessian; use `--logdir` to point at different checkpoints: ``` python train.py --mode=hessian \ --gin_config=configs/hessian.gin ``` Point at the configs under `mlp_configs/` to run MLP experiments: ``` python train.py --gin_config=mlp_configs/prune.gin ``` Run interpolation experiments as follows (the final `--logdir` sets the output directory): ``` python interpolate.py \ --gin_config=configs/interpolate.gin \ --ckpt_start=/path_to_lottery_logdir/cp-11719.ckpt \ --ckpt_end=/path_to_prune_logdir/cp-11719.ckpt \ --operative_gin=/path_to_logdir/operative_config.gin \ --logdir=/path_to_prune_logdir/ltsolution2prune/ ``` ## A journey with train.py 1) Check `main()`. - Load `preload_gin_config`. This is useful for scratch experiments to reuse the hyperparameters of the pruning experiments. We can overwrite these with the regular `gin_config`/`gin_bindings` flags. - Load data and create the network. The network might load its values from a checkpoint. These arguments are set through gin; see `utils.get_network` for details. - Then the code either trains the network (`mode=train_eval`) or calculates the Hessian (`mode=hessian`). 2) Check `train_model()`. - Create the optimizer and sample a validation set from the training set. The validation set is a subset of the training set and is used to get better estimates of certain metrics. - Create the `mask_updater` object. The returned value can be None, in which case the masks are not updated. - Perform pre-training updates to the network, e.g. meta-initialization. - Set up checkpointing so that if a checkpoint exists, training continues from where it left off. - Define the gradient function. This function is used during training and for certain other metrics. Note that we have to manually mask the gradients since they are dense. - Define the logging function for writing TensorBoard event summaries. - Main training loop: save, log, gradient step, mask update. A minimal sketch of this loop follows.
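The sketch below illustrates the manually-masked gradient step and the periodic mask update described above. It is a minimal, self-contained illustration assuming a single masked weight matrix and random data; the real `train_model()` builds the network with pruning wrappers and delegates mask updates to the `mask_updater` object, and the helper name `masked_step` is illustrative, not part of the codebase.

```
import tensorflow as tf

# Toy "masked" layer: one weight matrix with an explicit binary mask.
w = tf.Variable(tf.random.normal([784, 10]))
mask = tf.Variable(
    tf.cast(tf.random.uniform([784, 10]) > 0.9, tf.float32), trainable=False)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.05, momentum=0.9)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def masked_step(x, y):
  with tf.GradientTape() as tape:
    logits = x @ (w * mask)  # Forward pass through the masked weights.
    loss = loss_fn(y, logits)
  grad = tape.gradient(loss, w)
  # Gradients are dense, so mask them manually before applying the update.
  optimizer.apply_gradients([(grad * mask, w)])
  return loss

for step in range(1000):
  x = tf.random.normal([128, 784])
  y = tf.random.uniform([128], maxval=10, dtype=tf.int32)
  loss = masked_step(x, y)
  if step % 100 == 0:
    # Here train.py would invoke the mask_updater (e.g. RigL or SET) to drop
    # low-magnitude connections and grow new ones.
    pass
```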
================================================ FILE: rigl/rigl_tf2/colabs/MnistProp.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "e5O1UdsY202_" }, "source": [ "##### Copyright 2020 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "markdown", "metadata": { "id": "jUW1g2_jWmBk" }, "source": [ "## Measuring Signal Properties of Various Initializations\n", "For a random signal x ~ normal(0, 1), and a neural network denoted with f(x)=y, ensuring std(y)=1 at initialization is a common goal for popular NN initialization schemes. Here we measure signal propagation for different sparse initializations." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "4rvDSX8FFYTI" }, "outputs": [], "source": [ "#@title Imports and Definitions\n", "import numpy as np\n", "import os\n", "import tensorflow.compat.v2 as tf\n", "tf.enable_v2_behavior()\n", "\n", "import gin\n", "from rigl import sparse_utils\n", "from rigl.rigl_tf2 import init_utils\n", "from rigl.rigl_tf2 import utils\n", "from rigl.rigl_tf2 import train\n", "from rigl.rigl_tf2 import networks\n", "from rigl.rigl_tf2 import mask_updaters\n", "\n", "import functools\n", "\n", "pruning_params = utils.get_pruning_params(mode='constant', final_sparsity = 0., begin_step=int(1e10))\n", "INPUT_SHAPE = (28, 28, 3)\n", "num_classes = 10  # MNIST has 10 classes.\n", "class Lenet5(tf.keras.Model):\n", "\n", " def __init__(self,\n", " input_shape,\n", " num_classes,\n", " activation: str,\n", " hidden_sizes = (6, 16, 120, 84)):\n", " super(Lenet5, self).__init__()\n", " l = tf.keras.layers\n", " kwargs = {'activation': activation}\n", " filter_fn = lambda _: True\n", " wrap_fn = functools.partial(utils.maybe_prune_layer, params=pruning_params, filter_fn=filter_fn)\n", " self.conv1 = wrap_fn(l.Conv2D(hidden_sizes[0], 5, input_shape=input_shape, **kwargs))\n", " self.pool1 = l.MaxPool2D(pool_size=(2, 2))\n", " self.conv2 = wrap_fn(l.Conv2D(hidden_sizes[1], 5, input_shape=input_shape, **kwargs))\n", " self.pool2 = l.MaxPool2D(pool_size=(2, 2))\n", " self.flatten = l.Flatten()\n", " self.dense1 = wrap_fn(l.Dense(hidden_sizes[2], **kwargs))\n", " self.dense2 = wrap_fn(l.Dense(hidden_sizes[3], **kwargs))\n", " self.dense3 = wrap_fn(l.Dense(num_classes, **kwargs))\n", " self.build((1,)+input_shape)\n", "\n", " def call(self, inputs):\n", " x = inputs\n", " results = {}\n", " for l_name in ['conv1', 'pool1', 'conv2', 'pool2', 'flatten', 'dense1', 'dense2', 'dense3']:\n", " x = getattr(self, l_name)(x)\n", " results[l_name] = x \n", " return results\n", "\n", "def get_mask_random_numpy(mask_shape, sparsity):\n", " \"\"\"Creates a random sparse mask with deterministic sparsity.\n", "\n", " Args:\n", " mask_shape: list, used to obtain shape of the random mask.\n", " sparsity: float, between 0 and 1.\n", "\n", " Returns:\n", " numpy.ndarray\n", " \"\"\"\n", " all_ones = np.abs(np.ones(mask_shape))\n", " n_zeros = int(np.floor(sparsity * all_ones.size))\n", " rand_vals = np.random.uniform(size=mask_shape, high=range(1,mask_shape[-1]+1))\n", " randflat=rand_vals.flatten()\n", " randflat.sort()\n", " t = randflat[n_zeros]\n", " all_ones[rand_vals\u003c=t] = 0\n", " return all_ones\n", "\n", "def create_convnet(sparsity=0, weight_init_method = None, scale=2, method='fanin_normal'):\n", " model = Lenet5(INPUT_SHAPE, num_classes, 'relu')\n", " if sparsity \u003e 0:\n", " all_masks = [layer.pruning_vars[0][1] for layer in model.layers if 
isinstance(layer, utils.PRUNING_WRAPPER)]\n", " for mask in all_masks:\n", " new_mask = tf.cast(get_mask_random_numpy(mask.shape, sparsity), dtype=mask.dtype)\n", " mask.assign(new_mask)\n", " if weight_init_method:\n", " all_weights = [layer.pruning_vars[0][0] for layer in model.layers if isinstance(layer, utils.PRUNING_WRAPPER)]\n", " for mask, param in zip(all_masks, all_weights):\n", " if weight_init_method == 'unit':\n", " new_init = init_utils.unit_scaled_init(mask, method=method, scale=scale)\n", " elif weight_init_method == 'layer':\n", " new_init = init_utils.layer_scaled_init(mask, method=method, scale=scale)\n", " else:\n", " raise ValueError\n", " param.assign(new_init)\n", " return model" ] }, { "cell_type": "markdown", "metadata": { "id": "fkZ_GNjyYYqZ" }, "source": [ "Here we demonstrate how we can calculate the standard deviation of random noise at initialization for `layer-wise` scaled initialization of Liu et al." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NsmPRCuZnxDA" }, "outputs": [], "source": [ "# Let's create a 95% sparse Lenet-5.\n", "model = create_convnet(sparsity=0.95, weight_init_method='layer', scale=2, method='fanin_normal')\n", "# Random input signal\n", "random_input = tf.random.normal((1000,) + INPUT_SHAPE)\n", "output_dict = model(random_input)\n", "all_stds = []\n", "for k in ['dense1', 'dense2', 'dense3']:\n", " out_dim = output_dict[k].shape[-1]\n", " stds = np.std(np.reshape(output_dict[k], (-1,out_dim)),axis=0)\n", " all_stds.append(stds)\n", "print('Mean deviation per neuron', np.mean(np.concatenate(all_stds, axis=0)))\n", "print('Mean deviation per output neuron', np.mean(all_stds[-1]))\n", "print('Deviation at input', np.std(random_input))" ] }, { "cell_type": "markdown", "metadata": { "id": "l3ttY88rYovo" }, "source": [ "Now we define the code above as a function and use it on a grid to plot signal propagation at different sparsities." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "executionInfo": { "elapsed": 320, "status": "ok", "timestamp": 1613388807790, "user": { "displayName": "", "photoUrl": "", "userId": "" }, "user_tz": -180 }, "id": "4rfMGKciOOHf" }, "outputs": [], "source": [ "def propagate_signal(sparsity, init_method, batch_size=500):\n", " model = create_convnet(sparsity=sparsity, weight_init_method=init_method)\n", " random_input = tf.random.normal((batch_size,) + INPUT_SHAPE)\n", " # print(np.mean(random_input), np.std(random_input))\n", " output_dict = model(random_input)\n", " out_std = np.std(output_dict['dense3'])\n", " all_stds = []\n", " for k in ['dense1', 'dense2', 'dense3']:\n", " out_dim = output_dict[k].shape[-1]\n", " stds = np.std(np.reshape(output_dict[k], (-1,out_dim)),axis=0)\n", " all_stds.append(stds)\n", " meanstd = np.mean(np.concatenate(all_stds, axis=0))\n", " return meanstd, out_std" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F1rNPLXk7Ins" }, "outputs": [], "source": [ "import itertools, collections\n", "import numpy as np\n", "all_results = collections.defaultdict(dict)\n", "\n", "N_EXP = 3\n", "for s in np.linspace(0.8,0.98,5):\n", " print(s)\n", " for method, name in zip((None, 'unit', 'layer'), ('Masked Dense', 'Ours', 'Scaled-Init')):\n", " all_results[name][s] = [propagate_signal(s, method) for _ in range(N_EXP)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Sbjc7LxpVGl0" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "for k, v in all_results.items():\n", " # if k == 'Masked Dense':\n", " # continue\n", " x = sorted(v.keys())\n", " y = [np.mean([vv[1] for vv in v[kk]])+1e-5 for kk in x]\n", " plt.plot(x, y, label=k)\n", "plt.hlines(y=1, color='r', xmin=0, xmax=1)\n", "plt.yscale('log')\n", "plt.title('std(output)')\n", "plt.legend()\n", "plt.show()\n", "\n", "for k, v in all_results.items():\n", " # if k == 'Masked Dense':\n", " # continue\n", " x = sorted(v.keys())\n", " y = [np.mean([vv[0] for vv in v[kk]])+1e-5 for kk in x]\n", " plt.plot(x, y, label=k)\n", "plt.yscale('log')\n", "plt.hlines(y=1, color='r', xmin=0, xmax=1)\n", "plt.title('mean(std_per_neuron)')\n", "plt.legend()\n", "plt.show()" ] } ], "metadata": { "colab": { "collapsed_sections": [], "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "name": "Mnist propagation init sparse .ipynb", "provenance": [ { "file_id": "126QJDydlS0V4tQ-KhiN6bSlCOisqLV-Z", "timestamp": 1612472405306 }, { "file_id": "137QdNeUdTGoAOEPKpPMC09keiwlu12Bh", "timestamp": 1601472560303 } ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: rigl/rigl_tf2/configs/dense.gin ================================================ training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 500 steps. training.log_freq = 200 network.network_name = 'lenet5' network.weight_decay = 0.0005 # original_hidden_size/sqrt(20) -> 20 comes from 95% sparsity. # following lenet has 2399 params vs 2396 (95% sparse lenet5). lenet5.hidden_sizes = (6, 16, 120, 84) lenet5.use_batch_norm = False optimizer.name = "momentum" optimizer.learning_rate = 0.05 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT pruning.mode = 'constant' pruning.final_sparsity = 0. 
pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/configs/grasp.gin ================================================ training.use_metainit = False training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 training.gradient_regularization=0 optimizer.name = "momentum" optimizer.learning_rate = 0.1 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT network.weight_decay = 0.0002 # Disable GMP pruning. pruning.mode = 'constant' pruning.final_sparsity = 0. # Enable one shot pruning. training.oneshot_prune_fraction = 0.95 training.val_batch_size = 5000 pruning.begin_step = 100000000 # High begin_step, so it never starts. # Mask Updates mask_updater.update_alg = 'rigl_grasp' # Prune part of rigl_grasp corresponds to grasp. mask_updater.last_update_step=0 # Never updates. ================================================ FILE: rigl/rigl_tf2/configs/hessian.gin ================================================ hessian.batch_size = 60000 hessian.rows_at_once = 2 # range(0,100,5) + range(100,2000,100) + range(2000,11719,500) hessian.ckpt_ids = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, 10500, 11000, 11500] # range(4000,11719,50) # For Rigl updates # hessian.ckpt_ids = [-499, -999, -1499, -1999, -2499, -2999, -3499, -3999, -4499, -4999, -5499, -5999, -6499, -6999, -7499, -7999, -8499, -8999, -9499, -9999, -10499, -10999, -11499, -500, -1000, -1500, -2000, -2500, -3000, -3500, -4000, -4500, -5000, -5500, -6000, -6500, -7000, -7500, -8000, -8500, -9000, -9500, -10000, -10500, -11000, -11500] # hessian.ckpt_ids = [-100, -99, -199, -200, -500, -499, -999, -1999, -1499, -1500, -1000, -2000] hessian.overwrite = True ================================================ FILE: rigl/rigl_tf2/configs/interpolate.gin ================================================ interpolate.i_start = -0.20 interpolate.i_end = 1.20 interpolate.n_interpolation = 29 ================================================ FILE: rigl/rigl_tf2/configs/lottery.gin ================================================ training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 optimizer.name = "momentum" optimizer.learning_rate = 0.05 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_path = '/tmp/sparse_spectrum/ckpt-0' pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/configs/prune.gin ================================================ training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. 
training.log_freq = 200 network.network_name = 'lenet5' network.mask_init_path = None network.weight_decay = 0.0005 lenet5.use_batch_norm = False lenet5.hidden_sizes = (6, 16, 120, 84) optimizer.name = "momentum" optimizer.learning_rate = 0.05 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None pruning.mode = 'prune' pruning.initial_sparsity = 0.0 pruning.final_sparsity = 0.95 pruning.begin_step = 3000 pruning.end_step = 7000 pruning.frequency = 100 ================================================ FILE: rigl/rigl_tf2/configs/rigl.gin ================================================ training.use_metainit = False training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 training.gradient_regularization=0 optimizer.name = "momentum" optimizer.learning_rate = 0.05 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_method = None network.weight_decay = 0.0005 pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. unit_scaled_init.method='fanin_normal' # Mask Updates mask_updater.update_alg = 'rigl' mask_updater.schedule_alg = 'lr' mask_updater.update_freq = 100 mask_updater.init_drop_fraction = 0.3 mask_updater.last_update_step=-1 ================================================ FILE: rigl/rigl_tf2/configs/scratch.gin ================================================ training.use_metainit = False training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 training.gradient_regularization=0 optimizer.name = "momentum" optimizer.learning_rate = 0.05 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_method = None network.shuffle_mask = False network.weight_decay = 0.0005 pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/configs/set.gin ================================================ training.use_metainit = False training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 training.gradient_regularization=0 optimizer.name = "momentum" optimizer.learning_rate = 0.05 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_method = None network.weight_decay = 0.0005 pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. unit_scaled_init.method='fanin_normal' # Mask Updates mask_updater.update_alg = 'set' mask_updater.schedule_alg = 'lr' mask_updater.update_freq = 100 mask_updater.init_drop_fraction = 0.3 mask_updater.last_update_step=-1 ================================================ FILE: rigl/rigl_tf2/configs/small_dense.gin ================================================ training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. 
training.log_freq = 200 network.network_name = 'lenet5' network.weight_decay = 0.0005 # original_hidden_size/sqrt(20) -> 20 comes from 95% sparsity. # following lenet has 2399 params vs 2396 (95% sparse lenet5). lenet5.hidden_sizes = (3, 3, 27, 20) lenet5.use_batch_norm = False optimizer.name = "momentum" optimizer.learning_rate = 0.05 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/configs/snip.gin ================================================ training.use_metainit = False training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 training.gradient_regularization=0 optimizer.name = "momentum" optimizer.learning_rate = 0.1 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT network.weight_decay = 0.0002 # Disable GMP pruning. pruning.mode = 'constant' pruning.final_sparsity = 0. # Enable one shot pruning. training.oneshot_prune_fraction = 0.95 training.val_batch_size = 5000 pruning.begin_step = 100000000 # High begin_step, so it never starts. # Mask Updates mask_updater.update_alg = 'rigl_s' # Prune part of rigl_s corresponds to snip. mask_updater.last_update_step=0 # Never updates. ================================================ FILE: rigl/rigl_tf2/init_utils.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Implements initializations for sparse layers.""" import math import gin import tensorflow as tf @gin.configurable(denylist=['mask']) def unit_scaled_init(mask, method='fanavg_uniform', scale=1.0): """Scales the variance of each unit with correct fan_in.""" mode, distribution = method.strip().split('_') # Lets calculate all fan_ins. if len(mask.shape) == 4: mask_reduced2d = tf.reduce_sum(mask, axis=[0, 1]) elif len(mask.shape) == 2: mask_reduced2d = mask else: raise ValueError(f'mask.shape: {mask.shape} must be 4 or 2 dimensional.') fan_ins = tf.reduce_sum(mask_reduced2d, axis=-2) fan_outs = tf.reduce_sum(mask_reduced2d, axis=-1) non_zero_indices = tf.where(mask) # shape=(NZ, N_dim) # Lets sample each row with the correct fan_in. new_vals = [] # Following iterates over each output channel. for index in non_zero_indices: # Get fan_in and out of neurons that the non_zero connection connects. fan_in = fan_ins[index[-1]] fan_out = fan_outs[index[-2]] # Following code is modified from `tensorflow/python/ops/init_ops_v2.py`. if mode == 'fanin': current_scale = scale / max(1., fan_in) elif mode == 'fanout': current_scale = scale / max(1., fan_out) elif mode == 'fanavg': current_scale = scale / max(1., (fan_in + fan_out) / 2.) 
else: raise ValueError(f'mode: {mode} must be one of fanin, fanout, fanavg.') if distribution == 'normal': stddev = math.sqrt(current_scale) new_val = tf.random.normal((1,), 0.0, stddev, mask.dtype) elif distribution == 'uniform': limit = math.sqrt(3.0 * current_scale) new_val = tf.random.uniform((1,), -limit, limit, mask.dtype) else: raise ValueError( f'distribution: {distribution} must be one of normal, uniform.') new_vals.append(new_val) new_vals = tf.concat(new_vals, axis=-1) new_weights = tf.scatter_nd( indices=non_zero_indices, updates=new_vals, shape=mask.shape) return new_weights @gin.configurable(denylist=['mask']) def layer_scaled_init(mask, method='fanavg_uniform', scale=1.0): """Scales the variance of the whole layer to compensate for its sparsity.""" mode, distribution = method.strip().split('_') init_factory = tf.keras.initializers.VarianceScaling( mode=mode.replace('fan', 'fan_'), scale=scale, distribution=distribution) dense_init = init_factory(shape=mask.shape, dtype=mask.dtype) fraction_nnz = tf.reduce_sum(mask) / tf.size(mask, out_type=mask.dtype) new_weights = dense_init / tf.math.sqrt(fraction_nnz) return new_weights def unit_scaled_init_tf1(mask, method='fanavg_uniform', scale=1.0, dtype=tf.float32): """Scales the variance of each unit with the correct fan_in (TF1 version).""" mode, distribution = method.strip().split('_') # Let's calculate all fan_ins. if len(mask.shape) == 4: mask_reduced2d = tf.reduce_sum(mask, axis=[0, 1]) elif len(mask.shape) == 2: mask_reduced2d = mask else: raise ValueError(f'mask.shape: {mask.shape} must be 4 or 2 dimensional.') fan_ins = tf.reduce_sum(mask_reduced2d, axis=-2) fan_outs = tf.reduce_sum(mask_reduced2d, axis=-1) non_zero_indices = tf.where(mask) # shape=(NZ, N_dim) # Let's sample each row with the correct fan_in. def new_val_fn(index): # Get fan_in and out of neurons that the non_zero connection connects. fan_in = fan_ins[index[-1]] fan_out = fan_outs[index[-2]] # Following code is modified from `tensorflow/python/ops/init_ops_v2.py`. if mode == 'fanin': current_scale = scale / tf.math.maximum(1., fan_in) elif mode == 'fanout': current_scale = scale / tf.math.maximum(1., fan_out) elif mode == 'fanavg': current_scale = scale / tf.math.maximum(1., (fan_in + fan_out) / 2.) else: raise ValueError(f'mode: {mode} must be one of fanin, fanout, fanavg.') if distribution == 'normal': stddev = tf.math.sqrt(current_scale) new_val = tf.random.normal((1,), 0.0, stddev, dtype) elif distribution == 'uniform': limit = tf.math.sqrt(3.0 * current_scale) new_val = tf.random.uniform((1,), -limit, limit, dtype) else: raise ValueError( f'distribution: {distribution} must be one of normal, uniform.') return new_val # Following iterates over each output channel. new_vals = tf.squeeze(tf.map_fn(new_val_fn, non_zero_indices, dtype=dtype)) new_weights = tf.scatter_nd( indices=non_zero_indices, updates=new_vals, shape=mask.shape) return new_weights ================================================ FILE: rigl/rigl_tf2/interpolate.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. r"""Script for interpolating between checkpoints. """ import os from absl import app from absl import flags from absl import logging import gin import numpy as np from rigl.rigl_tf2 import utils import tensorflow.compat.v2 as tf from pyglib import timer FLAGS = flags.FLAGS flags.DEFINE_string('logdir', '/tmp/sparse_spectrum/interpolation', 'Directory to save experiment in.') flags.DEFINE_string('ckpt_start', '/tmp/sparse_spectrum/cp-0001.ckpt', 'Checkpoint to start the interpolation from.') flags.DEFINE_string('ckpt_end', '/tmp/sparse_spectrum/cp-0041.ckpt', 'Checkpoint to end the interpolation at.') flags.DEFINE_string( 'preload_gin_config', '', 'If non-empty reads a gin file ' 'before parsing gin_config and bindings. This is useful ' 'when you want to start from a configuration of another ' 'run. Values are then overwritten by additional configs ' 'and bindings provided.') flags.DEFINE_bool('use_tpu', True, 'Whether to run on TPU or not.') flags.DEFINE_bool('eval_on_train', True, 'Whether to evaluate on training set.') flags.DEFINE_integer('load_mask_from', 0, '0 means start checkpoint, 1 means ' 'end checkpoint. -1 means no mask loaded.') flags.DEFINE_enum('mode', 'train_eval', ('train_eval', 'hessian'), 'Whether to run training/eval or the hessian calculation.') flags.DEFINE_string( 'tpu_job_name', 'tpu_worker', 'Name of the TPU worker job. This is required when having ' 'multiple TPU worker jobs.') flags.DEFINE_string('master', None, 'TPU worker.') flags.DEFINE_multi_string('gin_config', [], 'List of paths to the config files.') flags.DEFINE_multi_string('gin_bindings', [], 'Newline separated list of Gin parameter bindings.') def test_model(model, d_test, batch_size=1000): """Tests the model and calculates cross entropy loss and accuracy.""" test_loss = tf.keras.metrics.Mean(name='test_loss') test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( name='test_accuracy') loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) for x, y in d_test.batch(batch_size): predictions = model(x, training=False) batch_loss = loss_object(y, predictions) test_loss.update_state(batch_loss) test_accuracy.update_state(y, predictions) logging.info('Test loss: %f', test_loss.result().numpy()) logging.info('Test accuracy: %f', test_accuracy.result().numpy()) return test_loss.result().numpy(), test_accuracy.result().numpy() @gin.configurable( 'interpolate', denylist=['model_start', 'model_end', 'model_inter', 'd_set']) def interpolate(model_start, model_end, model_inter, d_set, i_start=-0.2, i_end=1.2, n_interpolation=29): """Linearly interpolates between two sparse networks and evaluates.""" interpolation_coefs = np.linspace(i_start, i_end, n_interpolation) all_scores = {} for i_coef in interpolation_coefs: logging.info('Interpolating with: %f', i_coef) for var_start, var_end, var_inter in zip(model_start.trainable_variables, model_end.trainable_variables, model_inter.trainable_variables): new_value = (1 - i_coef) * var_start + i_coef * var_end var_inter.assign(new_value) scores = test_model(model_inter, d_set) all_scores[i_coef] = scores return all_scores def main(unused_argv): init_timer = timer.Timer() init_timer.Start() if FLAGS.preload_gin_config: # Load default values from the original experiment, always the first one.
with gin.unlock_config(): gin.parse_config_file(FLAGS.preload_gin_config, skip_unknown=True) logging.info('Operative Gin configurations loaded from: %s', FLAGS.preload_gin_config) gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings) data_train, data_test, info = utils.get_dataset() input_shape = info.features['image'].shape num_classes = info.features['label'].num_classes logging.info('Input Shape: %s', input_shape) logging.info('train samples: %s', info.splits['train'].num_examples) logging.info('test samples: %s', info.splits['test'].num_examples) data_eval = data_train if FLAGS.eval_on_train else data_test pruning_params = utils.get_pruning_params(mode='constant') mask_load_dict = {-1: None, 0: FLAGS.ckpt_start, 1: FLAGS.ckpt_end} mask_path = mask_load_dict[FLAGS.load_mask_from] # Currently we interpolate only in the same sparse space. model_start = utils.get_network( pruning_params, input_shape, num_classes, mask_init_path=mask_path, weight_init_path=FLAGS.ckpt_start) model_start.summary() model_end = utils.get_network( pruning_params, input_shape, num_classes, mask_init_path=mask_path, weight_init_path=FLAGS.ckpt_end) model_end.summary() # Create a third network for interpolation. model_inter = utils.get_network( pruning_params, input_shape, num_classes, mask_init_path=mask_path, weight_init_path=FLAGS.ckpt_end) logging.info('Performance at init (model_start):') test_model(model_start, data_eval) logging.info('Performance at init (model_end):') test_model(model_end, data_eval) all_results = interpolate(model_start=model_start, model_end=model_end, model_inter=model_inter, d_set=data_eval) tf.io.gfile.makedirs(FLAGS.logdir) results_path = os.path.join(FLAGS.logdir, 'all_results') with tf.io.gfile.GFile(results_path, 'wb') as f: np.save(f, all_results) logging.info('Total runtime: %.3f s', init_timer.GetDuration()) logconfigfile_path = os.path.join(FLAGS.logdir, 'operative_config.gin') with tf.io.gfile.GFile(logconfigfile_path, 'w') as f: f.write('# Gin-Config:\n %s' % gin.config.operative_config_str()) if __name__ == '__main__': tf.enable_v2_behavior() app.run(main) ================================================ FILE: rigl/rigl_tf2/mask_updaters.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Implements RigL and other mask update algorithms and schedules.""" import gin from rigl.rigl_tf2 import utils import tensorflow as tf def get_all_layers(model, filter_fn=lambda _: True): """Gets all layers of a model, recursing into layers that are keras.Models.""" all_layers = [] for l in model.layers: if hasattr(l, 'layers'): all_layers.extend(get_all_layers(l, filter_fn=filter_fn)) elif filter_fn(l): all_layers.append(l) return all_layers def is_pruned(layer): return isinstance(layer, utils.PRUNING_WRAPPER) and layer.trainable class MaskUpdater(object): """Base class for mask update algorithms. Attributes: model: tf.keras.Model optimizer: tf.train.Optimizer use_stateless: bool, if True stateless operations are used.
This is important for multi-worker jobs not to diverge. stateless_seed_offset: int, added to the seed of stateless operations. Use this to create randomness without divergence across workers. """ def __init__(self, model, optimizer, use_stateless=True, stateless_seed_offset=0, loss_fn=None): self._model = model self._optimizer = optimizer self._use_stateless = use_stateless self._stateless_seed_offset = stateless_seed_offset self._loss_fn = loss_fn self.val_x = self.val_y = None def prune_masks(self, prune_fraction): """Prunes a fraction of the connections in each layer.""" all_masks, all_vars = self.get_vars_and_masks() drop_scores = self.get_drop_scores(all_vars, all_masks) grow_score = None for mask, var, drop_score in zip(all_masks, all_vars, drop_scores): self.generic_mask_update(mask, var, drop_score, grow_score, prune_fraction) def update_masks(self, drop_fraction): """Drops and regrows a fraction of the connections in each layer.""" all_masks, all_vars = self.get_vars_and_masks() drop_scores = self.get_drop_scores(all_vars, all_masks) grow_scores = self.get_grow_scores(all_vars, all_masks) for mask, var, drop_score, grow_score in zip(all_masks, all_vars, drop_scores, grow_scores): self.generic_mask_update(mask, var, drop_score, grow_score, drop_fraction) def get_all_pruning_layers(self): """Returns all pruned layers from the model.""" if hasattr(self._model, 'layers'): return get_all_layers(self._model, filter_fn=is_pruned) else: return [self._model] if is_pruned(self._model) else [] def get_vars_and_masks(self): """Gets all masked variables and corresponding masks.""" all_masks = [] all_vars = [] for layer in self.get_all_pruning_layers(): for var, mask, _ in layer.pruning_vars: all_vars.append(var) all_masks.append(mask) return all_masks, all_vars def get_drop_scores(self, all_vars, all_masks): raise NotImplementedError def get_grow_scores(self, all_vars, all_masks): raise NotImplementedError def generic_mask_update(self, mask, var, score_drop, score_grow, drop_fraction, reinit_when_same=False): """Prunes+grows connections, all tensors same shape.""" n_total = tf.size(score_drop) n_ones = tf.cast(tf.reduce_sum(mask), dtype=tf.int32) n_prune = tf.cast( tf.cast(n_ones, dtype=tf.float32) * drop_fraction, tf.int32) n_keep = n_ones - n_prune # Sort the entire array since the k needs to be constant for TPU. _, sorted_indices = tf.math.top_k( tf.reshape(score_drop, [-1]), k=n_total) sorted_indices_ex = tf.expand_dims(sorted_indices, 1) # We will have zeros after having `n_keep` many ones. new_values = tf.where( tf.range(n_total) < n_keep, tf.ones_like(sorted_indices, dtype=mask.dtype), tf.zeros_like(sorted_indices, dtype=mask.dtype)) mask1 = tf.scatter_nd(sorted_indices_ex, new_values, new_values.shape) if score_grow is not None: # Flatten the scores. score_grow = tf.reshape(score_grow, [-1]) # Set scores of the enabled connections (ones) to min(s) - 1, so that they # have the lowest scores. score_grow_lifted = tf.where( tf.math.equal(mask1, 1), tf.ones_like(mask1) * (tf.reduce_min(score_grow) - 1), score_grow) _, sorted_indices = tf.math.top_k(score_grow_lifted, k=n_total) sorted_indices_ex = tf.expand_dims(sorted_indices, 1) new_values = tf.where( tf.range(n_total) < n_prune, tf.ones_like(sorted_indices, dtype=mask.dtype), tf.zeros_like(sorted_indices, dtype=mask.dtype)) mask2 = tf.scatter_nd(sorted_indices_ex, new_values, new_values.shape) # Ensure masks are disjoint. tf.debugging.assert_near(tf.reduce_sum(mask1 * mask2), 0.) # Let's set the weights of the grown connections.
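# Grown connections are initialized to zero (the grow_tensor below); kept # connections retain their current weights.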
mask2_reshaped = tf.reshape(mask2, mask.shape) # Set the values of the new connections. grow_tensor = tf.zeros_like(var, dtype=var.dtype) if reinit_when_same: # If dropped and grown, we re-initialize. new_connections = tf.math.equal(mask2_reshaped, 1) else: new_connections = tf.math.logical_and( tf.math.equal(mask2_reshaped, 1), tf.math.equal(mask, 0)) new_weights = tf.where(new_connections, grow_tensor, var) var.assign(new_weights) # Ensure there is no stale momentum for the new connections. self.reset_momentum(var, new_connections) mask_combined = tf.reshape(mask1 + mask2, mask.shape) else: mask_combined = tf.reshape(mask1, mask.shape) mask.assign(mask_combined) def reset_momentum(self, var, new_connections): for s_name in self._optimizer.get_slot_names(): # For slots like momentum, we reset the aggregated values to zero. optim_var = self._optimizer.get_slot(var, s_name) new_values = tf.where(new_connections, tf.zeros_like(optim_var), optim_var) optim_var.assign(new_values) def _random_uniform(self, *args, **kwargs): if self._use_stateless: c_seed = self._stateless_seed_offset + kwargs['seed'] kwargs['seed'] = tf.cast( tf.stack([c_seed, self._optimizer.iterations]), tf.int32) return tf.random.stateless_uniform(*args, **kwargs) else: return tf.random.uniform(*args, **kwargs) def _random_normal(self, *args, **kwargs): if self._use_stateless: c_seed = self._stateless_seed_offset + kwargs['seed'] kwargs['seed'] = tf.cast( tf.stack([c_seed, self._optimizer.iterations]), tf.int32) return tf.random.stateless_normal(*args, **kwargs) else: return tf.random.normal(*args, **kwargs) def set_validation_data(self, val_x, val_y): self.val_x, self.val_y = val_x, val_y def _get_gradients(self, all_vars): """Returns the gradients of the given weights using the validation data.""" with tf.GradientTape() as tape: batch_loss = self._loss_fn(self.val_x, self.val_y) grads = tape.gradient(batch_loss, all_vars) if grads: grads = tf.distribute.get_replica_context().all_reduce('sum', grads) return grads class SET(MaskUpdater): """Implementation of dynamic sparsity optimizers. Implementation of SET. See https://www.nature.com/articles/s41467-018-04316-3 This optimizer wraps a regular optimizer and performs updates on the masks according to the given schedule. """ def get_drop_scores(self, all_vars, all_masks, noise_std=0): def score_fn(mask, var): score = tf.math.abs(mask*var) if noise_std != 0: score += self._random_normal( score.shape, stddev=noise_std, dtype=score.dtype, seed=(hash(var.name + 'drop'))) return score return [score_fn(mask, var) for mask, var in zip(all_masks, all_vars)] def get_grow_scores(self, all_vars, all_masks): return [self._random_uniform(var.shape, seed=hash(var.name + 'grow')) for var in all_vars] class RigL(MaskUpdater): """Implementation of dynamic sparsity optimizers. Implementation of RigL. """ def get_drop_scores(self, all_vars, all_masks, noise_std=0): def score_fn(mask, var): score = tf.math.abs(mask*var) if noise_std != 0: score += self._random_normal( score.shape, stddev=noise_std, dtype=score.dtype, seed=(hash(var.name + 'drop'))) return score return [score_fn(mask, var) for mask, var in zip(all_masks, all_vars)] def get_grow_scores(self, all_vars, all_masks): return [tf.abs(g) for g in self._get_gradients(all_vars)] class RigLInverted(RigL): """Inverted variant of RigL that grows the connections with the smallest gradient magnitude. """ def get_grow_scores(self, all_vars, all_masks): return [-tf.abs(g) for g in self._get_gradients(all_vars)] class UpdateSchedule(object): """Base class for mask update schedules. Attributes: mask_updater: MaskUpdater, to invoke. update_freq: int, frequency of mask updates. init_drop_fraction: float, initial drop fraction. """ def __init__(self, mask_updater, init_drop_fraction, update_freq, last_update_step): self._mask_updater = mask_updater self.update_freq = update_freq self.last_update_step = last_update_step self.init_drop_fraction = tf.convert_to_tensor(init_drop_fraction) self.last_drop_fraction = 0 def get_drop_fraction(self, step): raise NotImplementedError def is_update_iter(self, step): """Returns true if it is a valid mask update step.""" # last_update_step < 0 means there is no last step. # last_update_step == 0 means never update. tf.debugging.Assert(step >= 0, [step]) if self.last_update_step < 0: is_valid_step = True elif self.last_update_step == 0: is_valid_step = False else: is_valid_step = step <= self.last_update_step return tf.logical_and(is_valid_step, step % self.update_freq == 0) def update(self, step, check_update_iter=True): if check_update_iter: tf.debugging.Assert(self.is_update_iter(step), [step]) self.last_drop_fraction = self.get_drop_fraction(step) def true_fn(): self._mask_updater.update_masks(self.last_drop_fraction) tf.cond(self.last_drop_fraction > 0., true_fn, lambda: None) def prune(self, prune_fraction): self.last_drop_fraction = prune_fraction self._mask_updater.prune_masks(self.last_drop_fraction) def set_validation_data(self, val_x, val_y): self._mask_updater.set_validation_data(val_x, val_y) class ConstantUpdateSchedule(UpdateSchedule): """Updates a constant fraction of connections.""" def get_drop_fraction(self, step): return self.init_drop_fraction class CosineUpdateSchedule(UpdateSchedule): """Anneals the drop fraction with a cosine decay schedule.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._drop_fraction_fn = tf.keras.experimental.CosineDecay( self.init_drop_fraction, self.last_update_step, alpha=0.0, name='cosine_drop_fraction') def get_drop_fraction(self, step): return self._drop_fraction_fn(step) class ScaledLRUpdateSchedule(UpdateSchedule): """Scales the drop fraction with learning rate.""" def __init__(self, mask_updater, init_drop_fraction, update_freq, last_update_step, optimizer): self._optimizer = optimizer self._initial_lr = self._get_lr(0) super(ScaledLRUpdateSchedule, self).__init__( mask_updater, init_drop_fraction, update_freq, last_update_step) def _get_lr(self, step): if isinstance(self._optimizer.lr, tf.Variable): return self._optimizer.lr.numpy() else: return self._optimizer.lr(step) def get_drop_fraction(self, step): current_lr = self._get_lr(step) return (self.init_drop_fraction / self._initial_lr) * current_lr @gin.configurable( 'mask_updater', allowlist=[ 'update_alg', 'schedule_alg', 'update_freq', 'init_drop_fraction', 'last_update_step', 'use_stateless', ]) def get_mask_updater( model, optimizer, loss_fn, update_alg='', schedule_alg='lr', update_freq=100, init_drop_fraction=0.3, last_update_step=-1, use_stateless=True): """Retrieves the update algorithm and passes it to the schedule object.""" if not update_alg: return None elif update_alg == 'set': mask_updater = SET(model, optimizer, use_stateless=use_stateless) elif update_alg == 'rigl': mask_updater = RigL( model, optimizer, loss_fn=loss_fn, use_stateless=use_stateless) elif update_alg == 'rigl_inverted': mask_updater = RigLInverted(
model, optimizer, loss_fn=loss_fn, use_stateless=use_stateless) else: raise ValueError('update_alg:%s is not valid.' % update_alg) if schedule_alg == 'lr': update_schedule = ScaledLRUpdateSchedule( mask_updater, init_drop_fraction, update_freq, last_update_step, optimizer) elif schedule_alg == 'cosine': update_schedule = CosineUpdateSchedule( mask_updater, init_drop_fraction, update_freq, last_update_step) elif schedule_alg == 'constant': update_schedule = ConstantUpdateSchedule(mask_updater, init_drop_fraction, update_freq, last_update_step) else: raise ValueError('schedule_alg:%s is not valid.' % schedule_alg) return update_schedule ================================================ FILE: rigl/rigl_tf2/metainit.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MetaInit algorithm to dynamically initialize neural nets.""" import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf class ScaleSGD(tf1.train.Optimizer): """SGD optimizer that only trains the scales of the parameters. This optimizer only tunes the scales of weight matrices. """ def __init__(self, learning_rate=0.1, momentum=0.9, mindim=3, use_locking=False, name="ScaleSGD"): super(ScaleSGD, self).__init__(use_locking, name) self._lr = learning_rate self._momentum = momentum self._mindim = mindim # Tensor versions of the constructor arguments, created in _prepare(). self._lr_t = None self._momentum_t = None def _prepare(self): self._lr_t = tf1.convert_to_tensor(self._lr, name="learning_rate") self._momentum_t = tf1.convert_to_tensor(self._momentum, name="momentum_t") def _create_slots(self, var_list): for v in var_list: self._get_or_make_slot_with_initializer(v, tf1.constant_initializer(0), tf1.TensorShape([]), tf1.float32, "m", self._name) def _resource_apply_dense(self, grad, handle): var = handle m = self.get_slot(var, "m") if len(var.shape) < self._mindim: return tf.group(*[var, m]) lr_t = tf1.cast(self._lr_t, var.dtype.base_dtype) momentum_t = tf1.cast(self._momentum_t, var.dtype.base_dtype) scale = tf1.sqrt(tf1.reduce_sum(var ** 2)) dscale = tf1.sign(tf1.reduce_sum(var * grad) / (scale + 1e-12)) m_t = m.assign(momentum_t * m - lr_t * dscale) new_scale = scale + m_t var_update = tf1.assign(var, var * new_scale / (scale + 1e-12)) return tf1.group(*[var_update, m_t]) def _apply_dense(self, grad, var): return self._resource_apply_dense(grad, var) def _apply_sparse(self, grad, var): raise NotImplementedError("Sparse gradient updates are not supported.") def meta_init(model, loss, x_shape, y_shape, n_params, learning_rate=0.001, momentum=0.9, meta_steps=1000, eps=1e-5, mask_gradient_fn=None): """Run MetaInit algorithm. 
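  Only the per-tensor scales of the initial weights are trained (via
  `ScaleSGD`), on random inputs and labels, to minimize a gradient-quotient
  meta-loss.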
See `https://papers.nips.cc/paper/9427-metainit-initializing-learning-by-learning-to-initialize`""" optimizer = ScaleSGD(learning_rate, momentum=momentum) for _ in range(meta_steps): x = np.random.normal(0, 1, x_shape) y = np.random.randint(0, y_shape[1], y_shape[0]) with tf.GradientTape(persistent=True) as tape: batch_loss = loss(y, model(x, training=True)) grad = tape.gradient(batch_loss, model.trainable_variables) if mask_gradient_fn is not None: grad = mask_gradient_fn(model, grad, model.trainable_variables) prod = tape.gradient(tf.reduce_sum([tf.reduce_sum(g**2) / 2 for g in grad]), model.trainable_variables) if mask_gradient_fn is not None: prod = mask_gradient_fn(model, prod, model.trainable_variables) meta_loss = [tf.abs(1 - ((g - p) / (g + eps * tf.stop_gradient( (2 * tf.cast(tf.greater_equal(g, 0), tf.float32)) - 1)))) for g, p in zip(grad, prod)] if mask_gradient_fn is not None: meta_loss = mask_gradient_fn(model, meta_loss, model.trainable_variables) meta_loss = sum([tf.reduce_sum(m) for m in meta_loss]) / n_params tf.summary.scalar("meta_loss", meta_loss) gradients = tape.gradient(meta_loss, model.trainable_variables) if mask_gradient_fn is not None: gradients = mask_gradient_fn(model, gradients, model.trainable_variables) optimizer.apply_gradients(zip(gradients, model.trainable_variables)) ================================================ FILE: rigl/rigl_tf2/mlp_configs/dense.gin ================================================ training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 500 steps. training.log_freq = 200 network.network_name = 'mlp' network.weight_decay = 0.0001 optimizer.name = "momentum" optimizer.learning_rate = 0.2 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/mlp_configs/lottery.gin ================================================ # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_path = '/tmp/sparse_spectrum/ckpt-0' pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/mlp_configs/prune.gin ================================================ training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 network.network_name = 'mlp' network.mask_init_path = None network.weight_decay = 0.0001 optimizer.name = "momentum" optimizer.learning_rate = 0.2 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None pruning.mode = 'prune' pruning.initial_sparsity = 0.0 pruning.final_sparsity = 0.98 pruning.begin_step = 3000 pruning.end_step = 7000 pruning.frequency = 100 ================================================ FILE: rigl/rigl_tf2/mlp_configs/rigl.gin ================================================ training.use_metainit = False # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_method = None pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. 
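# Sparse-aware weight init; 'fanin_normal' presumably rescales a normal
# initializer by the fan-in of the active (unmasked) connections.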
unit_scaled_init.method='fanin_normal' # Mask Updates mask_updater.update_alg = 'rigl' mask_updater.schedule_alg = 'lr' mask_updater.update_freq = 500 mask_updater.init_drop_fraction = 0.3 mask_updater.last_update_step=-1 ================================================ FILE: rigl/rigl_tf2/mlp_configs/scratch.gin ================================================ training.use_metainit = False # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_method = None network.shuffle_mask = False pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/mlp_configs/set.gin ================================================ training.use_metainit = False # NON-DEFAULT network.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719' network.weight_init_method = None pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. unit_scaled_init.method='fanin_normal' # Mask Updates mask_updater.update_alg = 'set' mask_updater.schedule_alg = 'lr' mask_updater.update_freq = 500 mask_updater.init_drop_fraction = 0.3 mask_updater.last_update_step=-1 ================================================ FILE: rigl/rigl_tf2/mlp_configs/small_dense.gin ================================================ training.total_steps = 11719 # 6e4/128*25 epochs=11719 training.batch_size = 128 training.save_freq = 500 # Log every 5 steps. training.log_freq = 200 network.network_name = 'mlp' network.weight_decay = 0.0001 # (28*28*300 + 300*100 + 100*10)*0.02 + 410 = 5734 params # (28*28*8 + 8*8 + 8*10) + 8+8+10 = 6442 mlp.hidden_sizes = (8, 8) optimizer.name = "momentum" optimizer.learning_rate = 0.2 optimizer.momentum = 0.9 optimizer.clipvalue = None optimizer.clipnorm = None # NON-DEFAULT pruning.mode = 'constant' pruning.final_sparsity = 0. pruning.begin_step = 100000000 # High begin_step, so it never starts. ================================================ FILE: rigl/rigl_tf2/networks.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """This module has networks used in experiments. """ from typing import Optional, Tuple # Non-expensive-to-import types. 
import gin
import tensorflow.compat.v2 as tf


@gin.configurable(allowlist=['hidden_sizes', 'use_batch_norm'])
def lenet5(input_shape, num_classes, activation, kernel_regularizer,
           use_batch_norm = False, hidden_sizes = (6, 16, 120, 84)):
  """Lenet5 implementation."""
  network = tf.keras.Sequential()
  kwargs = {
      'activation': activation,
      'kernel_regularizer': kernel_regularizer,
  }

  def maybe_add_batchnorm():
    if use_batch_norm:
      network.add(tf.keras.layers.BatchNormalization())

  network.add(tf.keras.layers.Conv2D(
      hidden_sizes[0], 5, input_shape=input_shape, **kwargs))
  network.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2)))
  maybe_add_batchnorm()
  network.add(tf.keras.layers.Conv2D(hidden_sizes[1], 5, **kwargs))
  network.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2)))
  maybe_add_batchnorm()
  network.add(tf.keras.layers.Flatten())
  network.add(tf.keras.layers.Dense(hidden_sizes[2], **kwargs))
  maybe_add_batchnorm()
  network.add(tf.keras.layers.Dense(hidden_sizes[3], **kwargs))
  maybe_add_batchnorm()
  kwargs['activation'] = None
  network.add(tf.keras.layers.Dense(num_classes, **kwargs))
  return network


@gin.configurable(allowlist=['hidden_sizes', 'use_batch_norm'])
def mlp(input_shape, num_classes, activation, kernel_regularizer,
        use_batch_norm = False, hidden_sizes = (300, 100)):
  """MLP implementation."""
  network = tf.keras.Sequential()
  kwargs = {
      'activation': activation,
      'kernel_regularizer': kernel_regularizer
  }

  def maybe_add_batchnorm():
    if use_batch_norm:
      network.add(tf.keras.layers.BatchNormalization())

  network.add(tf.keras.layers.Flatten(input_shape=input_shape))
  network.add(tf.keras.layers.Dense(hidden_sizes[0], **kwargs))
  maybe_add_batchnorm()
  network.add(tf.keras.layers.Dense(hidden_sizes[1], **kwargs))
  maybe_add_batchnorm()
  kwargs['activation'] = None
  network.add(tf.keras.layers.Dense(num_classes, **kwargs))
  return network

================================================
FILE: rigl/rigl_tf2/train.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Training script for running experiments.
"""

import os
from typing import List  # Non-expensive-to-import types.

from absl import app
from absl import flags
from absl import logging
import gin
import jax
from jax.scipy.linalg import eigh
import numpy as np
from rigl.rigl_tf2 import mask_updaters
from rigl.rigl_tf2 import metainit
from rigl.rigl_tf2 import utils
import tensorflow.compat.v2 as tf

from pyglib import timer

FLAGS = flags.FLAGS
flags.DEFINE_string('logdir', '/tmp/sparse_spectrum',
                    'Directory to save experiment in.')
flags.DEFINE_string('preload_gin_config', '',
                    'If non-empty reads a gin file '
                    'before parsing gin_config and bindings. This is useful '
                    'when you want to start from a configuration of another '
                    'run. Values are then overwritten by additional configs '
                    'and bindings provided.')
flags.DEFINE_bool('use_tpu', True, 'Whether to run on TPU or not.')
flags.DEFINE_enum('mode', 'train_eval', ('train_eval', 'hessian'),
                  'Whether to train and evaluate the model, or to calculate '
                  'the Hessian spectrum of its checkpoints.')
flags.DEFINE_string(
    'tpu_job_name', 'tpu_worker',
    'Name of the TPU worker job. This is required when having '
    'multiple TPU worker jobs.')
flags.DEFINE_integer('seed', default=0, help=('Sets the random seed.'))
flags.DEFINE_multi_string('gin_config', [],
                          'List of paths to the config files.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Newline separated list of Gin parameter bindings.')


@tf.function
def get_rows(model, variables, masks, ind_l, indices, x_batch, y_batch,
             is_dense_spectrum):
  """Calculates the rows (given by `ind_l`) of the Hessian."""
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  with tf.GradientTape(persistent=True) as tape:
    predictions = model(x_batch, training=True)
    loss = loss_object(y_batch, predictions)
    grads, = tape.gradient(loss, [variables[ind_l]])
    # Since the variables are masked before, not during, the forward pass,
    # gradients are dense. We need to ensure they are sparse.
    sparse_grads = grads * masks[ind_l]
    single_grad = tf.reshape(sparse_grads, [-1])
    s_grads = tf.gather(single_grad, indices)
  flattened_list = []
  hessians_slice_vars = tape.jacobian(
      s_grads, variables, experimental_use_pfor=False)
  for h, m in zip(hessians_slice_vars, masks):
    if is_dense_spectrum:
      # We apply the masks since the weights are not hard-constrained to be
      # sparse.
      vals = tf.reshape(h * m, (h.shape[0], -1))
    else:
      boolean_mask = tf.broadcast_to(tf.equal(m, 1), h.shape)
      vals = tf.reshape(h[boolean_mask], (h.shape[0], -1))
    flattened_list.append(vals)
  res = tf.concat(flattened_list, 1)
  return res


def sparse_hessian_calculator(model, data, rows_at_once, eigvals_path,
                              overwrite, is_dense_spectrum=False):
  """Calculates the Hessian of the model parameters. Biases are dense."""
  # Read all data at once.
  x_batch, y_batch = list(data.batch(100000))[0]
  if tf.io.gfile.exists(eigvals_path) and overwrite:
    logging.info('Deleting existing Eigvals: %s', eigvals_path)
    tf.io.gfile.rmtree(eigvals_path)
  if tf.io.gfile.exists(eigvals_path):
    with tf.io.gfile.GFile(eigvals_path, 'rb') as f:
      eigvals = np.load(f)
    logging.info('Eigvals exists, skipping: %s', eigvals_path)
    return eigvals
  # First let's create lists that indicate the valid dimensions of each
  # variable. If we want to calculate the sparse spectrum, we have to omit the
  # masked dimensions. Biases are dense and therefore have masks of 1's.
  masks = []
  variables = []
  layer_group_indices = []
  for l in model.layers:
    if isinstance(l, utils.PRUNING_WRAPPER):
      # TODO following the outcome of b/148083099, update following.
      # Add the weight, mask and the valid dimensions.
      weight = l.weights[0]
      variables.append(weight)
      mask = l.weights[2]
      masks.append(mask)
      logging.info(mask.shape)
      if is_dense_spectrum:
        n_params = tf.size(mask)
        layer_group_indices.append(tf.range(n_params))
      else:
        fmask = tf.reshape(mask, [-1])
        indices = tf.where(tf.equal(fmask, 1))[:, 0]
        layer_group_indices.append(indices)
      # Add the bias mask of ones and all of its dimensions.
      bias = l.weights[1]
      variables.append(bias)
      masks.append(tf.ones_like(bias))
      layer_group_indices.append(tf.range(tf.size(bias)))
    else:
      # For now we assume all parameterized layers are wrapped with
      # PruneLowMagnitude.
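      # Parameter-free layers (e.g. Flatten, MaxPool) are expected here; any
      # other trainable layer would be silently missing from the Hessian, so
      # fail fast.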
assert not l.trainable_variables result_all = [] init_timer = timer.Timer() init_timer.Start() n_total = 0 logging.info('Calculating Hessian...') for i, inds in enumerate(layer_group_indices): n_split = np.ceil(tf.size(inds).numpy() / rows_at_once) logging.info('Nsplit: %d', n_split) for c_slice in np.array_split(inds.numpy(), n_split): res = get_rows(model, variables, masks, i, c_slice, x_batch, y_batch, is_dense_spectrum) result_all.append(res.numpy()) n_total += res.shape[0] target_n = float(res.shape[1]) logging.info('%.3f %% ..', (n_total / target_n)) # We convert in numpy so that it is on cpu automatically and we don't get OOM. c_hessian = np.concatenate(result_all, 0) logging.info('Total runtime for hessian: %.3f s', init_timer.GetDuration()) init_timer.Start() eigens = jax.jit(eigh, backend='cpu')(c_hessian) eigvals = np.asarray(eigens[0]) with tf.io.gfile.GFile(eigvals_path, 'wb') as f: np.save(f, eigvals) logging.info('EigVals saved: %s', eigvals_path) logging.info('Total runtime for eigvals: %.3f s', init_timer.GetDuration()) return eigvals @gin.configurable(denylist=['model', 'ds_train', 'logdir']) def hessian(model, ds_train, logdir, ckpt_ids = gin.REQUIRED, overwrite = False, batch_size = 1000, rows_at_once = 10, is_dense_spectrum = False): """Loads checkpoints under a folder and calculates their hessian spectrum.""" # Note that hessian is calculated using the same batch in different runs. # This is needed since if the job dies and restarted we want it to be same. data_hessian = ds_train.take(batch_size) for ckpt_id in ckpt_ids: # `cp-0005.ckpt.index` -> 15012 ckpt = tf.train.Checkpoint(model=model) c_path = os.path.join(logdir, 'ckpt-%d' % ckpt_id) ckpt.restore(c_path) logging.info('Loaded from: %s', c_path) eigvals_path = c_path + '.eigvals' sparse_hessian_calculator( model=model, data=data_hessian, eigvals_path=eigvals_path, overwrite=overwrite, is_dense_spectrum=is_dense_spectrum, rows_at_once=rows_at_once) def update_prune_step(model, step): for layer in model.layers: if isinstance(layer, utils.PRUNING_WRAPPER): # Assign iteration count to the layer pruning_step. layer.pruning_step.assign(step) def log_sparsities(model): for layer in model.layers: if isinstance(layer, utils.PRUNING_WRAPPER): for _, mask, threshold in layer.pruning_vars: scalar_name = f'sparsity/{mask.name}' sparsity = 1 - tf.reduce_mean(mask) tf.summary.scalar(scalar_name, sparsity) tf.summary.scalar(f'threshold/{threshold.name}', threshold) def cosine_distance(x, y): """Calculates the distance between 2 tensors of same shape.""" normalizedx = tf.math.l2_normalize(x) normalizedy = tf.math.l2_normalize(y) return 1. 
- tf.reduce_sum(tf.multiply(normalizedx, normalizedy)) def flatten_list_of_vars(var_list): flat_vars = [tf.reshape(v, -1) for v in var_list] return tf.concat(flat_vars, axis=-1) def var_to_img(tensor): if len(tensor.shape) <= 1: gray_image = tf.reshape(tensor, [1, -1]) elif len(tensor.shape) == 2: gray_image = tensor else: gray_image = tf.reshape(tensor, [-1, tensor.shape[-1]]) # (H, W) -> (1, H, W, 1) return tf.expand_dims(tf.expand_dims(gray_image, 0), -1) def mask_gradients(model, gradients, variables): name_to_grad = {var.name: grad for grad, var in zip(gradients, variables)} for layer in model.layers: if isinstance(layer, utils.PRUNING_WRAPPER): for weights, mask, _ in layer.pruning_vars: if weights.name in name_to_grad: name_to_grad[weights.name] = name_to_grad[weights.name] * mask masked_gradients = [name_to_grad[var.name] for var in variables] return masked_gradients @gin.configurable( 'training', denylist=['model', 'ds_train', 'ds_test', 'logdir']) def train_model(model, ds_train, ds_test, logdir, total_steps = 5000, batch_size = 128, val_batch_size = 1000, save_freq = 5, log_freq = 250, use_metainit = False, oneshot_prune_fraction = 0., gradient_regularization=0): """Training of the CNN on MNIST.""" logging.info('Writing training logs to %s', logdir) writer = tf.summary.create_file_writer(os.path.join(logdir, 'train_logs')) optimizer = utils.get_optimizer(total_steps) loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) train_batch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( name='train_batch_accuracy') # Let's create 2 disjoint validation sets. (val_x, val_y), (val2_x, val2_y) = [ d for d in ds_train.take(val_batch_size * 2).batch(val_batch_size) ] # We use a separate set than the one we are using in our training. def loss_fn(x, y): loss_object = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction=tf.keras.losses.Reduction.NONE) predictions = model(x, training=True) reg_loss = tf.add_n(model.losses) if model.losses else 0 return loss_object(y, predictions) + reg_loss mask_updater = mask_updaters.get_mask_updater(model, optimizer, loss_fn) if mask_updater: mask_updater.set_validation_data(val2_x, val2_y) update_prune_step(model, 0) if oneshot_prune_fraction > 0: logging.info('Running one shot prunning at the beginning.') if not mask_updater: raise ValueError('mask_updater does not exists. Please set ' 'mask_updater.update_alg flag for one shot pruning.') mask_updater.prune(oneshot_prune_fraction) if use_metainit: n_params = 0 for layer in model.layers: if isinstance(layer, utils.PRUNING_WRAPPER): for _, mask, _ in layer.pruning_vars: n_params += tf.reduce_sum(mask) metainit.meta_init(model, loss_object, (128, 28, 28, 1), (128, 10), n_params, mask_gradient_fn=mask_gradients) # This is used to calculate some distances, would give incorrect results when # we restart the training. initial_params = list(map(lambda a: a.numpy(), model.trainable_variables)) # Create the checkpoint object and restore if there is a checkpoint in the # folder. ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer) ckpt_manager = tf.train.CheckpointManager( checkpoint=ckpt, directory=logdir, max_to_keep=None) if ckpt_manager.latest_checkpoint: logging.info('Restored from %s', ckpt_manager.latest_checkpoint) ckpt.restore(ckpt_manager.latest_checkpoint) is_restored = True else: logging.info('Starting from scratch.') is_restored = False # Obtain global_step after loading checkpoint. 
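  # `optimizer.iterations` is incremented by every `apply_gradients` call and
  # is part of the checkpoint, so restoring above also restores the step.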
global_step = optimizer.iterations tf.summary.experimental.set_step(global_step) trainable_vars = model.trainable_variables def get_gradients(x, y, log_batch_gradient=False, is_regularized=True): """Gets spars gradients and possibly logs some statistics.""" is_grad_regularized = gradient_regularization != 0 with tf.GradientTape(persistent=is_grad_regularized) as tape: predictions = model(x, training=True) batch_loss = loss_object(y, predictions) if is_regularized and is_grad_regularized: gradients = tape.gradient(batch_loss, trainable_vars) gradients = mask_gradients(model, gradients, trainable_vars) grad_vec = flatten_list_of_vars(gradients) batch_loss += tf.nn.l2_loss(grad_vec) * gradient_regularization # Regularization might have been disabled. reg_loss = tf.add_n(model.losses) if model.losses else 0 if is_regularized: batch_loss += reg_loss gradients = tape.gradient(batch_loss, trainable_vars) # Gradients are dense, we should mask them to ensure updates are sparse; # So is the norm calculation. gradients = mask_gradients(model, gradients, trainable_vars) # If batch gradient log it. if log_batch_gradient: tf.summary.scalar('train_batch_loss', batch_loss) tf.summary.scalar('train_batch_reg_loss', reg_loss) train_batch_accuracy.update_state(y, predictions) tf.summary.scalar('train_batch_accuracy', train_batch_accuracy.result()) train_batch_accuracy.reset_states() return gradients def log_fn(): logging.info('Logging at iter: %d', global_step.numpy()) log_sparsities(model) test_loss, test_acc = test_model(model, ds_test) tf.summary.scalar('test_loss', test_loss) tf.summary.scalar('test_acc', test_acc) # Log gradient norm. # We want to obtain/log gradients without regularization term. gradients = get_gradients(val_x, val_y, log_batch_gradient=False, is_regularized=False) for var, grad in zip(trainable_vars, gradients): tf.summary.scalar(f'gradnorm/{var.name}', tf.norm(grad)) # Log all gradients together all_norm = tf.norm(flatten_list_of_vars(gradients)) tf.summary.scalar('.allparams/gradnorm', all_norm) # Log momentum values: for s_name in optimizer.get_slot_names(): # Currently we only log momentum. if s_name not in ['momentum']: continue all_slots = [optimizer.get_slot(var, s_name) for var in trainable_vars] all_norm = tf.norm(flatten_list_of_vars(all_slots)) tf.summary.scalar(f'.allparams/norm_{s_name}', all_norm) # Log distance to init. for initial_val, val in zip(initial_params, model.trainable_variables): tf.summary.scalar(f'dist_init_l2/{val.name}', tf.norm(initial_val - val)) cos_distance = cosine_distance(initial_val, val) tf.summary.scalar(f'dist_init_cosine/{val.name}', cos_distance) # Mask update logs: if mask_updater: tf.summary.scalar('drop_fraction', mask_updater.last_drop_fraction) # Log all distances together. 
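    # Concatenate all parameters into one flat vector so a single global
    # L2/cosine distance is logged next to the per-variable ones above.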
flat_initial = flatten_list_of_vars(initial_params) flat_current = flatten_list_of_vars(model.trainable_variables) tf.summary.scalar('.allparams/dist_init_l2/', tf.norm(flat_initial - flat_current)) tf.summary.scalar('.allparams/dist_init_cosine/', cosine_distance(flat_initial, flat_current)) # Log masks for layer in model.layers: if isinstance(layer, utils.PRUNING_WRAPPER): for _, mask, _ in layer.pruning_vars: tf.summary.image('mask/%s' % mask.name, var_to_img(mask)) writer.flush() def save_fn(step=None): save_step = step if step else global_step saved_ckpt = ckpt_manager.save(checkpoint_number=save_step) logging.info('Saved checkpoint: %s', saved_ckpt) with writer.as_default(): for x, y in ds_train.repeat().shuffle( buffer_size=60000).batch(batch_size): if global_step >= total_steps: logging.info('Total steps: %d is completed', global_step.numpy()) save_fn() break update_prune_step(model, global_step) if tf.equal(global_step, 0): logging.info('Seed: %s First 10 Label: %s', FLAGS.seed, y[:10]) if global_step % save_freq == 0: # If just loaded, don't save it again. if is_restored: is_restored = False else: save_fn() if global_step % log_freq == 0: log_fn() gradients = get_gradients(x, y, log_batch_gradient=True) tf.summary.scalar('lr', optimizer.lr(global_step)) optimizer.apply_gradients(zip(gradients, trainable_vars)) if mask_updater and mask_updater.is_update_iter(global_step): # Save the network before mask_update, we want to use negative integers # for this. save_fn(step=(-global_step + 1)) # Gradient norm before. gradients = get_gradients( val_x, val_y, log_batch_gradient=False, is_regularized=False) norm_before = tf.norm(flatten_list_of_vars(gradients)) results = mask_updater.update(global_step) # Save network again save_fn(step=-global_step) if results: for mask_name, drop_frac in results.items(): tf.summary.scalar('drop_fraction/%s' % mask_name, drop_frac) # Gradient norm after mask update. gradients = get_gradients( val_x, val_y, log_batch_gradient=False, is_regularized=False) norm_after = tf.norm(flatten_list_of_vars(gradients)) tf.summary.scalar('.allparams/gradnorm_mask_update_improvment', norm_after - norm_before) logging.info('Performance after training:') log_fn() return model def test_model(model, d_test, batch_size=1000): """Tests the model and calculates cross entropy loss and accuracy.""" test_loss = tf.keras.metrics.Mean(name='test_loss') test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy( name='test_accuracy') loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) for x, y in d_test.batch(batch_size): predictions = model(x, training=False) batch_loss = loss_object(y, predictions) test_loss.update_state(batch_loss) test_accuracy.update_state(y, predictions) logging.info('Test loss: %f', test_loss.result().numpy()) logging.info('Test accuracy: %f', test_accuracy.result().numpy()) return test_loss.result(), test_accuracy.result() def main(unused_argv): tf.random.set_seed(FLAGS.seed) init_timer = timer.Timer() init_timer.Start() if FLAGS.mode == 'hessian': # Load default values from the original experiment. FLAGS.preload_gin_config = os.path.join(FLAGS.logdir, 'operative_config.gin') # Maybe preload a gin config. 
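  # This reconstructs the network/pruning configuration of the original run
  # before any additional configs and bindings are applied on top of it.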
if FLAGS.preload_gin_config: config_path = FLAGS.preload_gin_config gin.parse_config_file(config_path) logging.info('Gin configuration pre-loaded from: %s', config_path) gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings) ds_train, ds_test, info = utils.get_dataset() input_shape = info.features['image'].shape num_classes = info.features['label'].num_classes logging.info('Input Shape: %s', input_shape) logging.info('train samples: %s', info.splits['train'].num_examples) logging.info('test samples: %s', info.splits['test'].num_examples) pruning_params = utils.get_pruning_params() model = utils.get_network(pruning_params, input_shape, num_classes) model.summary(print_fn=logging.info) if FLAGS.mode == 'train_eval': train_model(model, ds_train, ds_test, FLAGS.logdir) elif FLAGS.mode == 'hessian': test_model(model, ds_test) hessian(model, ds_train, FLAGS.logdir) logging.info('Total runtime: %.3f s', init_timer.GetDuration()) logconfigfile_path = os.path.join( FLAGS.logdir, 'hessian_' if FLAGS.mode == 'hessian' else '' + 'operative_config.gin') with tf.io.gfile.GFile(logconfigfile_path, 'w') as f: f.write('# Gin-Config:\n %s' % gin.config.operative_config_str()) if __name__ == '__main__': tf.enable_v2_behavior() app.run(main) ================================================ FILE: rigl/rigl_tf2/utils.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for training. """ import functools from typing import Optional, Tuple from absl import flags from absl import logging import gin from rigl.rigl_tf2 import init_utils from rigl.rigl_tf2 import networks import tensorflow.compat.v2 as tf import tensorflow_datasets as tfds from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper FLAGS = flags.FLAGS PRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude PRUNED_LAYER_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.Dense) @gin.configurable('data') def get_dataset(): """Loads the dataset.""" # the data, shuffled and split between train and test sets. 
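  # With `with_info=True`, tfds.load returns a dict of tf.data.Datasets plus
  # a DatasetInfo object (image shape, class count, split sizes).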
datasets, info = tfds.load('mnist', with_info=True) ds_train, ds_test = datasets['train'].cache(), datasets['test'].cache() preprocess_fn = lambda x: (tf.cast(x['image'], tf.float32) / 255., x['label']) ds_train = ds_train.map(preprocess_fn) ds_test = tfds.load('mnist', split='test').cache() ds_test = ds_test.map(preprocess_fn) return ds_train, ds_test, info @gin.configurable('pruning') def get_pruning_params(mode='prune', initial_sparsity=0.0, final_sparsity=0.8, begin_step=2000, end_step=4000, frequency=200): """Gets pruning hyper-parameters.""" p_params = {} if mode == 'prune': p_params['pruning_schedule'] = pruning_schedule.PolynomialDecay( initial_sparsity=initial_sparsity, final_sparsity=final_sparsity, begin_step=begin_step, end_step=end_step, frequency=frequency) elif mode == 'constant': p_params['pruning_schedule'] = pruning_schedule.ConstantSparsity( target_sparsity=final_sparsity, begin_step=begin_step) else: raise ValueError('Mode: %s, is not valid' % mode) return p_params # Forked from tensorflow_model_optimization/python/core/sparsity/keras/prune.py def maybe_prune_layer(layer, params, filter_fn): if filter_fn(layer): return PRUNING_WRAPPER(layer, **params) return layer @gin.configurable('network') def get_network( pruning_params, input_shape, num_classes, activation = 'relu', network_name = 'lenet5', mask_init_path = None, shuffle_mask = False, weight_init_path = None, weight_init_method = None, weight_decay = 0., noise_stddev = 0., pruned_layer_types = PRUNED_LAYER_TYPES): """Creates the network.""" kernel_regularizer = ( tf.keras.regularizers.l2(weight_decay) if (weight_decay > 0) else None) # (1) Create keras model. model = getattr(networks, network_name)( input_shape, num_classes, activation=activation, kernel_regularizer=kernel_regularizer) model.summary(print_fn=logging.info) # (2) Adding wrappers. i.e. sparsify if conv or dense. filter_fn = lambda layer: isinstance(layer, pruned_layer_types) clone_fn = functools.partial(maybe_prune_layer, params=pruning_params, filter_fn=filter_fn) model = tf.keras.models.clone_model(model, clone_function=clone_fn) # (3) Update parameters of the model as necessary. if mask_init_path: logging.info('Loading masks from: %s', mask_init_path) mask_init_model = tf.keras.models.clone_model(model) ckpt = tf.train.Checkpoint(model=mask_init_model) ckpt.restore(mask_init_path) for l_source, l_target in zip(mask_init_model.layers, model.layers): if isinstance(l_source, PRUNING_WRAPPER): # l.pruning_vars[0][1] is the mask. mask = l_target.pruning_vars[0][1] n_active = tf.reduce_sum(mask) n_dense = tf.cast(tf.size(mask), dtype=n_active.dtype) logging.info('Before: %s, %.2f', l_target.name, (n_active / n_dense).numpy()) loaded_mask = l_source.pruning_vars[0][1] if shuffle_mask: # tf shuffle shuffles along the first dim, so we need to flatten. 
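          # E.g. a (784, 300) mask is flattened to (235200,), shuffled, and
          # reshaped back: overall sparsity is preserved while the
          # connectivity structure is randomized.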
loaded_mask = tf.reshape( tf.random.shuffle(tf.reshape(loaded_mask, -1)), loaded_mask.shape) mask.assign(loaded_mask) n_active = tf.reduce_sum(mask) n_dense = tf.cast(tf.size(mask), dtype=n_active.dtype) logging.info('After: %s, %.2f', l_target.name, (n_active / n_dense).numpy()) del mask_init_model if weight_init_path: logging.info('Loading weights from: %s', weight_init_path) weight_init_model = tf.keras.models.clone_model(model) ckpt = tf.train.Checkpoint(model=weight_init_model) ckpt.restore(weight_init_path) for l_source, l_target in zip(weight_init_model.layers, model.layers): for var_source, var_target in zip(l_source.trainable_variables, l_target.trainable_variables): var_target.assign(var_source) logging.info('Weight %s loaded from ckpt.', var_target.name) del weight_init_model elif weight_init_method == 'unit_scaled': logging.info('Using unit_scaled initialization.') for layer in model.layers: if isinstance(layer, PRUNING_WRAPPER): # TODO following the outcome of b/148083099, update following. # Add the weight, mask and the valid dimensions. weight = layer.weights[0] mask = layer.weights[2] new_init = init_utils.unit_scaled_init(mask) weight.assign(new_init) logging.info('Weight %s updated init.', weight.name) elif weight_init_method == 'layer_scaled': logging.info('Using layer_scaled initialization.') for layer in model.layers: if isinstance(layer, PRUNING_WRAPPER): # TODO following the outcome of b/148083099, update following. # Add the weight, mask and the valid dimensions. weight = layer.weights[0] mask = layer.weights[2] new_init = init_utils.layer_scaled_init(mask) weight.assign(new_init) logging.info('Weight %s updated init.', weight.name) if noise_stddev > 0.: logging.info('Adding noise to the initial point') for layer in model.layers: for var in layer.trainable_variables: noise = tf.random.normal(var.shape, mean=0, stddev=noise_stddev) var.assign_add(noise) # Do this call to mask the weights with existing masks if it is not done # already. This is needed for example when we use initial parameters to cal- # culate distance. model(tf.expand_dims(tf.ones(input_shape), 0)) return model @gin.configurable('optimizer', denylist=['total_steps']) def get_optimizer(total_steps, name = 'adam', learning_rate = 0.001, clipnorm = None, clipvalue = None, momentum = None): """Creates the optimizer according to the arguments.""" name = name.lower() # We use cosine decay. lr_decayed_fn = tf.keras.experimental.CosineDecay(learning_rate, total_steps) kwargs = {} if clipnorm: # Not correct implementation, see http://b/152868229 . kwargs['clipnorm'] = clipnorm if clipvalue: kwargs['clipvalue'] = clipvalue if name == 'adam': return tf.keras.optimizers.Adam(lr_decayed_fn, **kwargs) if name == 'momentum': return tf.keras.optimizers.SGD(lr_decayed_fn, momentum=momentum, **kwargs) if name == 'sgd': return tf.keras.optimizers.SGD(lr_decayed_fn, **kwargs) if name == 'rmsprop': return tf.keras.optimizers.RMSprop( lr_decayed_fn, momentum=momentum, **kwargs) raise NotImplementedError(f'Optimizers {name} not implemented.') ================================================ FILE: rigl/rl/README.md ================================================ # The State of Sparse Training in Deep Reinforcement Learning [**Paper**] [goo.gle/sparserl-paper](https://goo.gle/sparserl-paper) [**Video**] [goo.gle/sparserl-video](https://goo.gle/sparserl-video) This code requires Tensorflow 2.0; therefore we need to use a separate requirements file. Please follow the instructions below: First clone this repo. 
```bash
git clone https://github.com/google-research/rigl.git
cd rigl
```

We use the [NeurIPS 2019 MicroNet Challenge](https://micronet-challenge.github.io/) code for counting the operations and size of our networks. Let's clone the google_research repo and add the current folder to the Python path.

```bash
git clone https://github.com/google-research/google-research.git
mv google-research/ google_research/
export PYTHONPATH=$PYTHONPATH:$PWD
```

Now we can run some tests. The following script creates a virtual environment, installs the necessary libraries, and finally runs a few tests.

```bash
virtualenv -p python3 env_sparserl
source env_sparserl/bin/activate
pip install -r rigl/rl/requirements.txt
python -m rigl.sparse_utils_test
```

Follow the instructions here to install MuJoCo: https://github.com/openai/mujoco-py#install-mujoco

To run PPO:

```
python3 rigl/rl/tfagents/ppo_train_eval.py \
  --gin_file=rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin \
  --root_dir=/tmp/sparserl/ --is_mujoco=True
```

To run SAC:

```
python3 rigl/rl/tfagents/sac_train_eval.py \
  --gin_file=rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin \
  --root_dir=/tmp/sparserl/ --is_mujoco=True
```

**Citation**:

```
@InProceedings{graesser22a,
  title = {The State of Sparse Training in Deep Reinforcement Learning},
  author = {Graesser, Laura and Evci, Utku and Elsen, Erich and Castro, Pablo Samuel},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  pages = {7766--7792},
  year = {2022},
  editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
  volume = {162},
  series = {Proceedings of Machine Learning Research},
  month = {17--23 Jul},
  publisher = {PMLR},
  pdf = {https://proceedings.mlr.press/v162/graesser22a/graesser22a.pdf},
  url = {https://proceedings.mlr.press/v162/graesser22a.html},
}
```

================================================
FILE: rigl/rl/dqn_agents.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Variants of DQN with sparsity."""

import functools
import math

from absl import logging
from dopamine.agents.dqn import dqn_agent
from dopamine.discrete_domains import atari_lib
import gin
from rigl.rl import sparse_utils
import tensorflow as tf
import tensorflow.compat.v1 as tf1

# One of ('dense', 'prune', 'rigl', 'static', 'set'). If 'dense', no
# modification is done. If 'prune', the agent is gradually pruned during
# training. If ('rigl', 'static', 'set'), the corresponding sparse-to-sparse
# training algorithm is used.
LEARNER_MODES = ('dense', 'prune', 'rigl', 'static', 'set')


def flatten_list_of_vars(var_list):
  flat_vars = [tf.reshape(v, [-1]) for v in var_list]
  return tf.concat(flat_vars, axis=-1)


def _get_bn_layer_name(block_id, i):
  return f'batch_norm_{block_id},{i}'


def _get_conv_layer_name(block_id, i):
  return f'conv_{block_id},{i}'


class _Stack(tf.keras.Model):
  """Stack of pooling and convolutional blocks with residual connections.
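
  Downscales with a convolution (plus an optional strided max-pool), then
  applies `num_blocks` two-convolution residual blocks, as in the IMPALA
  torso.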
""" def __init__(self, num_ch, num_blocks, use_max_pooling=True, use_batch_norm=False, name='stack'): super(_Stack, self).__init__(name=name) self._conv = tf.keras.layers.Conv2D(num_ch, 3, strides=1, padding='same') self.use_max_pooling = use_max_pooling self.use_batch_norm = use_batch_norm self.num_blocks = num_blocks if self.use_batch_norm: self._batch_norm = tf.keras.layers.BatchNormalization() if self.use_max_pooling: self._max_pool = tf.keras.layers.MaxPool2D( pool_size=3, padding='same', strides=2) for block_id in range(num_blocks): for i in range(2): name = _get_conv_layer_name(block_id, i) layer = tf.keras.layers.Conv2D( num_ch, 3, strides=1, padding='same', name=f'res_{block_id}/conv2d_{i}') setattr(self, name, layer) if self.use_batch_norm: name = _get_bn_layer_name(block_id, i) setattr(self, name, tf.keras.layers.BatchNormalization()) def call(self, conv_out, training=False): # Downscale. conv_out = self._conv(conv_out) if self.use_max_pooling: conv_out = self._max_pool(conv_out) if self.use_batch_norm: conv_out = self._batch_norm(conv_out, training=training) # Residual block(s). for block_id in range(self.num_blocks): block_input = conv_out for i in range(2): conv_out = tf.nn.relu(conv_out) conv_layer = getattr(self, _get_conv_layer_name(block_id, i)) conv_out = conv_layer(conv_out) if self.use_batch_norm: bn_layer = getattr(self, _get_bn_layer_name(block_id, i)) conv_out = bn_layer(conv_out, training=training) conv_out += block_input return conv_out @gin.configurable class ImpalaNetwork(tf.keras.Model): """Agent with ResNet, but without LSTM and additional inputs. The deep model used for DQN which follows "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" by Espeholt, Soyer, Munos et al. Original implementation by Rishabh Agarwal, with minor modifications as follows: * rename nn_scale to width to fit with the sparserl API * allow for non-integer widths. * add training mode. * removed the option to have multiple heads. * modified the call function to return a compatible type. * added custom logic for sparse training. """ def __init__(self, num_actions, width=1.0, mode='dense', name='impala_deep_network', prune_allow_key='', use_batch_norm=False): super().__init__(name=name) self._width = width self._mode = mode def _scale_width(n): return int(math.ceil(n * width)) self.num_actions = num_actions self.use_batch_norm = use_batch_norm logging.info('Using batch norm in %s: %s', name, use_batch_norm) stack_fn = functools.partial(_Stack, use_batch_norm=use_batch_norm) # Parameters and layers for _torso. 
self._stacks = [ stack_fn(_scale_width(32), 2, name='stack1'), stack_fn(_scale_width(64), 2, name='stack2'), stack_fn(_scale_width(64), 2, name='stack3'), ] self._dense1 = tf.keras.layers.Dense(_scale_width(256)) self._dense2 = tf.keras.layers.Dense( self.num_actions, name='policy_logits') layer_shape_dict = { '_dense1': (7744, 512), '_dense2': (512, self.num_actions), } def add_stack_shapes(name, in_width, out_width): # First conv layer_shape_dict[f'{name}/_conv'] = (3, 3, in_width, out_width) for i in range(2): for j in range(2): l_name = _get_conv_layer_name(i, j) layer_shape_dict[f'{name}/{l_name}'] = (3, 3, out_width, out_width) add_stack_shapes('stack0', 4, _scale_width(32)) add_stack_shapes('stack1', _scale_width(32), _scale_width(64)) add_stack_shapes('stack2', _scale_width(64), _scale_width(64)) if mode != 'dense': custom_sparsities = sparse_utils.get_pruning_sparsities(layer_shape_dict) for l_name, sparsity in custom_sparsities.items(): logging.info('pruning, layer: %s, sparsity: %.4f', l_name, sparsity) if l_name.startswith('stack'): # stack1 -> 1 stack_id = int(l_name[len('stack')]) c_module = self._stacks[stack_id] # `stack1/_conv` -> `_conv` l_name = l_name.split('/')[1] else: c_module = self if mode == 'prune': if prune_allow_key and (prune_allow_key not in l_name): sparsity = 0 logging.info('%s not pruned since, prune_allow_key: %s', l_name, prune_allow_key) wrapped_layer = sparse_utils.maybe_prune_layer( getattr(c_module, l_name), params=sparse_utils.get_pruning_params( mode, final_sparsity=sparsity)) else: wrapped_layer = sparse_utils.maybe_prune_layer( getattr(c_module, l_name), params=sparse_utils.get_pruning_params(mode)) setattr(c_module, l_name, wrapped_layer) def get_features(self, state, training=True): x = tf.cast(state, tf.float32) x /= 255 conv_out = x for stack in self._stacks: conv_out = stack(conv_out, training=training) conv_out = tf.nn.relu(conv_out) conv_out = tf.keras.layers.Flatten()(conv_out) out = self._dense1(conv_out) out = tf.nn.relu(out) out = self._dense2(out) return out def call(self, state, training=True): out = self.get_features(state, training=training) return atari_lib.DQNNetworkType(out) @gin.configurable class NatureDQNNetwork(tf.keras.Model): """The convolutional network used to compute the agent's Q-values.""" def __init__(self, num_actions, width=1, mode='dense', name=None): """Creates the layers used for calculating Q-values. Args: num_actions: int, number of actions. width: float, Scales the width of the network uniformly. mode: str, one of LEARNER_MODES. name: str, used to create scope for network parameters. """ super().__init__(name=name) self.num_actions = num_actions self._width = width self._mode = mode def _scale_width(n): return int(math.ceil(n * width)) # Defining layers. activation_fn = tf.keras.activations.relu # Setting names of the layers manually to make variable names more similar # with tf.slim variable names/checkpoints. 
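    # Standard Nature DQN trunk: 8x8/4, 4x4/2 and 3x3/1 convolutions followed
    # by a 512-unit dense layer, with all widths scaled by `width`.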
self.conv1 = tf.keras.layers.Conv2D( _scale_width(32), [8, 8], strides=4, padding='same', activation=activation_fn, name='Conv') self.conv2 = tf.keras.layers.Conv2D( _scale_width(64), [4, 4], strides=2, padding='same', activation=activation_fn, name='Conv') self.conv3 = tf.keras.layers.Conv2D( _scale_width(64), [3, 3], strides=1, padding='same', activation=activation_fn, name='Conv') self.flatten = tf.keras.layers.Flatten() self.dense1 = tf.keras.layers.Dense( _scale_width(512), activation=activation_fn, name='fully_connected') self.dense2 = tf.keras.layers.Dense(num_actions, name='fully_connected') layer_shape_dict = { 'conv1': (_scale_width(32), 8, 8, 4), 'conv2': (_scale_width(64), 4, 4, _scale_width(32)), 'conv3': (_scale_width(64), 3, 3, _scale_width(64)), 'dense1': (7744, _scale_width(512)), 'dense2': (_scale_width(512), num_actions) } if mode == 'dense': pass elif mode == 'prune': custom_sparsities = sparse_utils.get_pruning_sparsities(layer_shape_dict) for l_name, sparsity in custom_sparsities.items(): logging.info('pruning, layer: %s, sparsity: %.4f', l_name, sparsity) wrapped_layer = sparse_utils.maybe_prune_layer( getattr(self, l_name), params=sparse_utils.get_pruning_params( mode, final_sparsity=sparsity)) setattr(self, l_name, wrapped_layer) else: # static, rigl, set. for l_name in layer_shape_dict: wrapped_layer = sparse_utils.maybe_prune_layer( getattr(self, l_name), params=sparse_utils.get_pruning_params(mode)) setattr(self, l_name, wrapped_layer) def call(self, state): """Creates the output tensor/op given the state tensor as input. See https://www.tensorflow.org/api_docs/python/tf/keras/Model for more information on this. Note that tf.keras.Model implements `call` which is wrapped by `__call__` function by tf.keras.Model. Parameters created here will have scope according to the `name` argument given at `.__init__()` call. Args: state: Tensor, input tensor. Returns: collections.namedtuple, output ops (graph mode) or output tensors (eager). """ x = tf.cast(state, tf.float32) x = x / 255 x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.flatten(x) x = self.dense1(x) return atari_lib.DQNNetworkType(self.dense2(x)) @gin.configurable class SparseDQNAgent(dqn_agent.DQNAgent): """A variant of DQN that is trained with sparse backbones.""" def __init__(self, sess, num_actions, mode='dense', weight_decay=0., summary_writer=None): """Initializes the agent and constructs graph components. Args: sess: tf.Session, for executing ops. num_actions: int, number of actions the agent can take at any state. mode: str, one of LEARNER_MODES. weight_decay: float, used to regularize online_convnet. summary_writer: tf.SummaryWriter, for Tensorboard. """ self._weight_decay = weight_decay if mode in LEARNER_MODES: self._mode = mode else: raise ValueError(f'mode:{mode} not one of {LEARNER_MODES}') self._global_step = tf1.train.get_or_create_global_step() # update_period=1, we always update as the supervisor is fixed. super().__init__( sess, num_actions, summary_writer=summary_writer) def _create_network(self, name): network = self.network( self.num_actions, name=name + 'learner', mode=self._mode) return network def _set_additional_ops(self): if self._mode == 'dense': self.step_update_op = tf.no_op() self.mask_update_op = tf.no_op() self.mask_init_op = tf.no_op() elif self._mode in ['rigl', 'set', 'static']: self.step_update_op = sparse_utils.update_prune_step( self.online_convnet, self._global_step) # This ensures sparse masks are applied before each run. 
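      # (`init_masks` below presumably creates the initial sparse topology
      # once; the update op then re-applies the binary masks each step.)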
self.mask_update_op = sparse_utils.update_prune_masks(self.online_convnet) self.mask_init_op = sparse_utils.init_masks(self.online_convnet) # Wrap the optimizer. if self._mode == 'rigl': self.optimizer = sparse_utils.UpdatedRigLOptimizer(self.optimizer) self.optimizer.set_model(self.online_convnet) elif self._mode == 'set': self.optimizer = sparse_utils.UpdatedSETOptimizer(self.optimizer) self.optimizer.set_model(self.online_convnet) elif self._mode == 'prune': self.step_update_op = sparse_utils.update_prune_step( self.online_convnet, self._global_step) self.mask_update_op = sparse_utils.update_prune_masks(self.online_convnet) self.mask_init_op = tf.no_op() else: raise ValueError(f'Invalid mode: {self._mode}') def _build_train_op(self): """Builds a training op. Returns: train_op: An op performing one step of training from replay data. """ replay_action_one_hot = tf.one_hot( self._replay.actions, self.num_actions, 1., 0., name='action_one_hot') replay_chosen_q = tf.reduce_sum( self._replay_net_outputs.q_values * replay_action_one_hot, axis=1, name='replay_chosen_q') target = tf.stop_gradient(self._build_target_q_op()) loss = tf1.losses.huber_loss( target, replay_chosen_q, reduction=tf.losses.Reduction.NONE) loss = tf.reduce_mean(loss) if self.summary_writer is not None: tf1.summary.scalar('Losses/HuberLoss', loss) reg_loss = 0. if self._weight_decay: for v in self.online_convnet.trainable_variables: if 'bias' not in v.name: reg_loss += tf.nn.l2_loss(v) * self._weight_decay loss += reg_loss tf1.summary.scalar('Losses/RegLoss', reg_loss) tf1.summary.scalar('Losses/TotalLoss', loss) sparse_utils.log_sparsities(self.online_convnet) self._set_additional_ops() grads_and_vars = self.optimizer.compute_gradients(loss) train_op = self.optimizer.apply_gradients( grads_and_vars, global_step=self._global_step) self._create_summary_ops(grads_and_vars) return train_op def _create_summary_ops(self, grads_and_vars): with tf1.variable_scope('Norm'): all_norm = tf.norm( flatten_list_of_vars(self.online_convnet.trainable_variables)) tf1.summary.scalar('online_convnet/weights_norm', all_norm) all_norm = tf.norm( flatten_list_of_vars(self.target_convnet.trainable_variables)) tf1.summary.scalar('target_convnet/weights_norm', all_norm) all_grad_norm = tf.norm( flatten_list_of_vars([ g for g, v in grads_and_vars if v in self.online_convnet.trainable_variables ])) tf1.summary.scalar('online_convnet/grad_norm', all_grad_norm) total_params, nparam_dict = sparse_utils.get_total_params( self.online_convnet) tf1.summary.scalar('params/total', total_params) for k, val in nparam_dict.items(): tf1.summary.scalar('params/' + k, val) if self._mode == 'rigl': tf1.summary.scalar('drop_fraction', self.optimizer.drop_fraction) def update_prune_step(self): self._sess.run(self.step_update_op) def maybe_update_and_apply_masks(self): self._sess.run(self.mask_update_op) def maybe_init_masks(self): # If `dense`; no initialization. 
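    # (`mask_init_op` is a tf.no_op() for the 'dense' and 'prune' modes, so
    # running it unconditionally is safe.)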
    self._sess.run(self.mask_init_op)

  def _train_step(self):
    if self._replay.memory.add_count > self.min_replay_history:
      if self.training_steps % self.update_period == 0:
        self.update_prune_step()
        self.maybe_update_and_apply_masks()
        self._sess.run(self._train_op)
        c_step = self._sess.run(self._global_step)
        if (self.summary_writer is not None and
            self._merged_summaries is not None and
            c_step % self.summary_writing_frequency == 0):
          summary = self._sess.run(self._merged_summaries)
          self.summary_writer.add_summary(summary, c_step)
      if self.training_steps % self.target_update_period == 0:
        # Mask weights before syncing.
        self.maybe_update_and_apply_masks()
        self._sess.run(self._sync_qt_ops)
    self.training_steps += 1

  def _build_sync_op(self):
    """Builds ops for assigning weights from online to target network.

    Returns:
      ops: A list of ops assigning weights from online to target network.
    """
    # Get trainable variables from online and target DQNs.
    sync_qt_ops = []
    online_vars = sparse_utils.get_all_variables_and_masks(self.online_convnet)
    target_vars = sparse_utils.get_all_variables_and_masks(self.target_convnet)
    for (v_online, v_target) in zip(online_vars, target_vars):
      # Assign weights from online to target network.
      sync_qt_ops.append(v_target.assign(v_online, use_locking=True))
    return sync_qt_ops

  def _build_networks(self):
    """Builds the Q-value network computations needed for acting and training.

    Same as the `super` class except that training=True flags are passed.

    These are:
      self.online_convnet: For computing the current state's Q-values.
      self.target_convnet: For computing the next state's target Q-values.
      self._net_outputs: The actual Q-values.
      self._q_argmax: The action maximizing the current state's Q-values.
      self._replay_net_outputs: The replayed states' Q-values.
      self._replay_next_target_net_outputs: The replayed next states' target
        Q-values (see Mnih et al., 2015 for details).
    """
    self.online_convnet = self._create_network(name='Online')
    self.target_convnet = self._create_network(name='Target')
    self._net_outputs = self.online_convnet(self.state_ph, training=True)
    self._q_argmax = tf.argmax(self._net_outputs.q_values, axis=1)[0]
    self._replay_net_outputs = self.online_convnet(self._replay.states,
                                                   training=True)
    self._replay_next_target_net_outputs = self.target_convnet(
        self._replay.next_states)

================================================
FILE: rigl/rl/requirements.txt
================================================
absl-py>=0.6.0
dopamine-rl==4.0.5
gin-config
mujoco-py<2.2,>=2.1
numpy>=1.15.4
six>=1.12.0
tensorflow==2.9.1  # change to 'tensorflow-gpu' for gpu support
tensorflow-datasets==2.1
tensorflow-model-optimization==0.7.2
tf-agents[reverb]==0.13.0

================================================
FILE: rigl/rl/run.sh
================================================
#!/bin/bash
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e
set -x

virtualenv -p python3 .
source ./bin/activate pip install tensorflow pip install -r sparse_rl/requirements.txt python -m sparse_rl.tfagents.sac_train_eval \ --gin_file=sparse_rl/tfagents/configs/sac_mujoco_sparse_config.gin ================================================ FILE: rigl/rl/run_experiment.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Run policy evaluation as supervised learning, reloading representations.""" import sys from absl import logging from dopamine.discrete_domains import gym_lib from dopamine.discrete_domains import run_experiment import gin import numpy as np from rigl.rl import dqn_agents import tensorflow.compat.v1 as tf1 # The last 10% of training is averaged to get the final reward. AVG_REWARD_FRAC = 0.1 @gin.configurable def create_sparse_agent(sess, num_actions, agent=None, summary_writer=None): """Creates a sparse agent. Args: sess: tf.Session. num_actions: int, number of actions. agent: str, type of learner/actor agent to create. summary_writer: tf.SummaryWriter, for TensorBoard. Returns: A learner/actor agent. """ assert agent is not None if agent == 'dqn': return dqn_agents.SparseDQNAgent( sess, num_actions, summary_writer=summary_writer) else: raise ValueError('Unknown learner agent: {}'.format(agent)) @gin.configurable class SparseTrainRunner(run_experiment.Runner): """Policy evaluation as supervised learning, from a loaded representation.""" def __init__(self, base_dir, agent_type, checkpoint_file_prefix='ckpt', logging_file_prefix='log', log_every_n=1, num_iterations=200, training_steps=250000, evaluation_steps=125000, max_steps_per_episode=27000, load_env_fn=gym_lib.create_gym_environment, clip_rewards=True, atari_100k_eval=False, num_eval_episodes=100, observation_noise=None): """Initialize SparseTrainRunner in charge of running the experiment. Args: base_dir: str, the base directory to host all required sub-directories. agent_type: str, defines the type of targets to be learned. Can be one of {'dqn', 'rainbow'}. checkpoint_file_prefix: str, the prefix to use for checkpoint files. logging_file_prefix: str, prefix to use for the log files. log_every_n: int, the frequency for writing logs. num_iterations: int, the iteration number threshold (must be greater than start_iteration). training_steps: int, the number of training steps to perform. evaluation_steps: int, the number of evaluation steps to perform. max_steps_per_episode: int, maximum number of steps after which an episode terminates. load_env_fn: fn, function which loads and returns an environment. clip_rewards: bool, whether to clip rewards in [-1, 1]. atari_100k_eval: bool, whether we are using the eval for Atari 100K. num_eval_episodes: int, the number of full episodes to run during eval, only used if atari_100k_eval is True. observation_noise: float (optional), the stddev to use to add noise to the observations before sending to the agent.
""" self._logging_file_prefix = logging_file_prefix self._log_every_n = log_every_n self._num_iterations = num_iterations self._training_steps = training_steps self._evaluation_steps = evaluation_steps self._max_steps_per_episode = max_steps_per_episode self._clip_rewards = clip_rewards self._atari_100k_eval = atari_100k_eval self._num_eval_episodes = num_eval_episodes self._base_dir = base_dir self._create_directories() self._summary_writer = tf1.summary.FileWriter(self._base_dir) self._observation_noise = observation_noise self._environment = load_env_fn() num_actions = self._environment.action_space.n config = tf1.ConfigProto(allow_soft_placement=True) # Allocate only subset of the GPU memory as needed which allows for running # multiple agents/workers on the same GPU. config.gpu_options.allow_growth = True # Set up a session and initialize variables. self._sess = tf1.Session('local', config=config) self._agent = create_sparse_agent( self._sess, num_actions, agent=agent_type, summary_writer=self._summary_writer) self._summary_writer.add_graph(graph=tf1.get_default_graph()) self._sess.run(tf1.global_variables_initializer()) self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix) def _run_one_phase_fix_episodes(self, max_episodes, statistics): """Run one eval phase for the Atari 100k benchmark. As opposed to the standard eval phase which runs for a fixed number of steps, this will run for a fixed number of episodes, producing less noisy results. Args: max_episodes: int, max number of episodes to run. statistics: `IterationStatistics` object which records the experimental results. Returns: Tuple containing the number of steps taken in this phase (int), the sum of returns (float), and the number of episodes performed (int). """ step_count = 0 num_episodes = 0 sum_returns = 0. while num_episodes < max_episodes: episode_length, episode_return = self._run_one_episode() statistics.append({ 'eval_episode_lengths': episode_length, 'eval_episode_returns': episode_return }) step_count += episode_length sum_returns += episode_return num_episodes += 1 # We use sys.stdout.write instead of logging so as to flush frequently # without generating a line break. 
sys.stdout.write('Steps executed: {} '.format(step_count) + 'Episode length: {} '.format(episode_length) + 'Num episodes: {} '.format(num_episodes) + 'Return: {}\r'.format(episode_return)) sys.stdout.flush() return step_count, sum_returns, num_episodes def _run_eval_phase(self, statistics): if not self._atari_100k_eval: return super()._run_eval_phase(statistics) self._agent.eval_mode = True _, sum_returns, num_episodes = self._run_one_phase_fix_episodes( self._num_eval_episodes, statistics) average_return = sum_returns / num_episodes if num_episodes > 0 else 0.0 logging.info('Average undiscounted return per evaluation episode: %.2f', average_return) statistics.append({'eval_average_return': average_return}) return num_episodes, average_return def _run_one_step(self, action): """Maybe adds noise to observations.""" observation, reward, is_terminal, _ = self._environment.step(action) if self._observation_noise is not None: observation += np.random.normal( scale=self._observation_noise, size=observation.shape).astype(observation.dtype) return observation, reward, is_terminal def run_experiment(self): """Runs a full experiment, spread over multiple iterations.""" logging.info('Beginning training...') if self._num_iterations <= self._start_iteration: logging.warning('num_iterations (%d) < start_iteration(%d)', self._num_iterations, self._start_iteration) return self._agent.update_prune_step() self._agent.maybe_init_masks() all_eval_returns = [] for iteration in range(self._start_iteration, self._num_iterations): statistics = self._run_one_iteration(iteration) all_eval_returns.append(statistics['eval_average_return'][-1]) self._log_experiment(iteration, statistics) self._checkpoint_experiment(iteration) last_n = int(self._num_iterations * AVG_REWARD_FRAC) avg_return = np.mean(all_eval_returns[-last_n:]) logging.info('Step %d, Average Return: %f', iteration, avg_return) ================================================ FILE: rigl/rl/sparse_utils.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Defines pruning and sparse training utilities.""" import functools import re import gin from rigl import sparse_optimizers_base as sparse_opt_base from rigl import sparse_utils from rigl.rigl_tf2 import init_utils import tensorflow as tf import tensorflow.compat.v1 as tf1 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper PRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude PRUNED_LAYER_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.Dense) def get_total_params(model): """Obtains total active parameters of a given network.""" all_layers = get_all_layers(model) total_count = 0. nparams_dict = {} for layer in all_layers: n_param = 0. 
if isinstance(layer, PRUNING_WRAPPER): mask = layer.pruning_vars[0][1] n_param += tf.reduce_sum(mask) n_param += tf.size(layer.weights[1], out_type=tf.float32) else: for w in layer.weights: n_param += tf.size(w, out_type=tf.float32) nparams_dict[layer.name] = n_param total_count += n_param return total_count, nparams_dict @gin.configurable(denylist=['layer_dict']) def get_pruning_sparsities( layer_dict, mask_init_method='erdos_renyi_kernel', target_sparsity=0.9, erk_power_scale=1., custom_sparsity_map=None): """Creates name/sparsity dict using the name/shapes dict (layer_dict).""" if target_sparsity == 0: return {k: 0 for k in layer_dict.keys()} if custom_sparsity_map is None: custom_sparsity_map = {} extract_name_fn = lambda x: re.findall('(.+):0', x)[0] dummy_masks_dict = {k: tf.ones(v) for k, v in layer_dict.items()} reverse_dict = {v.name: k for k, v in dummy_masks_dict.items()} sparsity_dict = sparse_utils.get_sparsities( list(dummy_masks_dict.values()), mask_init_method, target_sparsity, custom_sparsity_map, extract_name_fn=extract_name_fn, erk_power_scale=erk_power_scale) renamed_sparsity_dict = {reverse_dict[k]: float(v) for k, v in sparsity_dict.items()} return renamed_sparsity_dict @gin.configurable('pruning') def get_pruning_params(mode, initial_sparsity=0.0, final_sparsity=0.95, begin_step=30000, end_step=100000, frequency=1000): """Gets pruning hyper-parameters.""" p_params = {} if mode == 'prune': p_params['pruning_schedule'] = pruning_schedule.PolynomialDecay( initial_sparsity=initial_sparsity, final_sparsity=final_sparsity, begin_step=begin_step, end_step=end_step, frequency=frequency) elif mode in ('rigl', 'static', 'set'): # For sparse training methods we don't use the pruning library to update the # masks. Therefore we need to disable it. Following `pruning` flags serve # that purpose. # 1B. High begin_step, so it never starts. p_params['pruning_schedule'] = pruning_schedule.ConstantSparsity( target_sparsity=0, begin_step=1000000000) else: raise ValueError('Mode: %s, is not valid' % mode) return p_params def maybe_prune_layer(layer, params, filter_fn=None): if filter_fn is None: filter_fn = lambda l: isinstance(l, PRUNED_LAYER_TYPES) if filter_fn(layer): return PRUNING_WRAPPER(layer, **params) return layer def get_wrap_fn(mode): """Creates a function that wraps a given layer conditionally. Args: mode: str, If 'dense' no modification done. Otherwise the layer is pruned. Returns: function that accepts layer and returns a possibly wrapped one. """ if mode == 'dense': # Do not wrap the layer. wrap_fn = lambda x: x else: wrap_fn = functools.partial( maybe_prune_layer, params=get_pruning_params(mode)) return wrap_fn def update_prune_step(model, step): """Updates the pruning steps of each pruning layer.""" assign_ops = [] for layer in get_all_pruning_layers(model): # Assign iteration count to the layer pruning_step. # pruning wrapper requires step to be >0. 
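# A standalone sketch (illustrative, not repo code) of what the assigned step
# feeds into: a tfmot pruning schedule, called with a step tensor, returns a
# (should_prune, target_sparsity) pair, and the wrapper expects step >= 1,
# hence the tf.maximum(step, 1) clamp below.
import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule

schedule = pruning_schedule.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.95,
    begin_step=30000, end_step=100000, frequency=1000)
step = tf.maximum(tf.constant(0), 1)  # clamp to >= 1, as done below
should_prune, sparsity = schedule(step)  # before begin_step, should_prune is False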
assign_op = tf1.assign(layer.pruning_step, tf.maximum(step, 1)) assign_ops.append(assign_op) return tf.group(assign_ops) def update_prune_masks(model): """Updates the masks if it is an update iteration.""" update_ops = [op for op in model.updates if 'prune_low_magnitude' in op.name] return tf.group(update_ops) def get_all_layers(model, filter_fn=lambda _: True): """Gets all layers of a model and layers of a layer if it is a keras.Model.""" all_layers = [] for l in model.layers: if hasattr(l, 'layers'): all_layers.extend(get_all_layers(l, filter_fn=filter_fn)) elif filter_fn(l): all_layers.append(l) return all_layers def get_all_variables_and_masks(model): """Gets all trainable variables (+their masks) of a model.""" all_layers = get_all_layers(model) all_variables = [] for l in all_layers: all_variables.extend(l.trainable_variables) if isinstance(l, PRUNING_WRAPPER): all_variables.append(l.pruning_vars[0][1]) # Adding mask. return all_variables def get_all_pruning_layers(model): """Gets all pruned layers of a model and layers of a layer if keras.Model.""" return get_all_layers( model, filter_fn=lambda l: isinstance(l, PRUNING_WRAPPER)) def log_sparsities(model): for layer in get_all_pruning_layers(model): for _, mask, threshold in layer.pruning_vars: scalar_name = f'sparsity/{mask.name}' sparsity = 1 - tf.reduce_mean(mask) if len(mask.shape) == 2: reshaped_mask = tf.expand_dims(tf.expand_dims(mask, 0), -1) tf1.summary.image(f'img/{mask.name}', reshaped_mask) tf1.summary.scalar(scalar_name, sparsity) tf1.summary.scalar(f'threshold/{threshold.name}', threshold) class SparseOptTf2Mixin: """Tf2 model_optimization pruning library specific variable retrieval.""" def compute_gradients(self, *args, **kwargs): """Wraps the compute gradient of passed optimizer.""" return self._optimizer.compute_gradients(*args, **kwargs) def set_model(self, model): self.model = model def get_weights(self): all_weights = [ layer.pruning_vars[0][0] for layer in get_all_pruning_layers(self.model) ] return all_weights def get_masks(self): all_masks = [ layer.pruning_vars[0][1] for layer in get_all_pruning_layers(self.model) ] return all_masks def get_masked_weights(self): all_masked_weights = [ w * m for w, m in zip(self.get_weights(), self.get_masks()) ] return all_masked_weights @gin.configurable() class UpdatedSETOptimizer(SparseOptTf2Mixin, sparse_opt_base.SparseSETOptimizerBase): def _before_apply_gradients(self, grads_and_vars): return tf1.no_op() @gin.configurable() class UpdatedRigLOptimizer(SparseOptTf2Mixin, sparse_opt_base.SparseRigLOptimizerBase): def _before_apply_gradients(self, grads_and_vars): """Updates momentum before updating the weights with gradient.""" self._weight2masked_grads = {w.name: g for g, w in grads_and_vars} return tf1.no_op() @gin.configurable() def init_masks(model, mask_init_method='random', sparsity=0.9, erk_power_scale=1., custom_sparsity_map=None, fixed_sparse_init=False): """Inits the masks randomly according to the given sparsity.""" if sparsity == 0: return None if custom_sparsity_map is None: custom_sparsity_map = {} all_masks = [ layer.pruning_vars[0][1] for layer in get_all_pruning_layers(model) ] assigner = sparse_utils.get_mask_init_fn( all_masks, mask_init_method, sparsity, custom_sparsity_map, erk_power_scale=erk_power_scale) if fixed_sparse_init: all_weights = [ layer.pruning_vars[0][0] for layer in get_all_pruning_layers(model) ] with tf.control_dependencies([assigner]): assign_ops = [] for param, mask in zip(all_weights, all_masks): new_init = 
init_utils.unit_scaled_init_tf1(mask) assign_ops.append(tf1.assign(param, new_init)) assigner = tf.group(assign_ops) return assigner ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_dense.gin ================================================ include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin' import rigl.rl.dqn_agents DQNAgent.network = @dqn_agents.NatureDQNNetwork DQNAgent.optimizer = @tf.train.AdamOptimizer() tf.train.AdamOptimizer.learning_rate = 0.00025 WrappedReplayBuffer.batch_size = 32 # Same as original SparseDQNAgent.mode = 'dense' SparseDQNAgent.weight_decay = 0.0 atari_lib.create_atari_environment.game_name = 'Pong' SparseTrainRunner.load_env_fn = @atari_lib.create_atari_environment SparseTrainRunner.agent_type = 'dqn' SparseTrainRunner.num_iterations = 40 SparseTrainRunner.training_steps = 250000 SparseTrainRunner.evaluation_steps = 125000 SparseTrainRunner.max_steps_per_episode = 27000 # Default max episode length. ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin ================================================ include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin' import rigl.rl.dqn_agents DQNAgent.network = @dqn_agents.ImpalaNetwork DQNAgent.optimizer = @tf.train.AdamOptimizer() tf.train.AdamOptimizer.learning_rate = 0.0001 tf.train.AdamOptimizer.epsilon = 0.0003125 WrappedReplayBuffer.batch_size = 32 # Same as original SparseDQNAgent.mode = 'dense' SparseDQNAgent.weight_decay = 1e-05 atari_lib.create_atari_environment.game_name = 'Pong' SparseTrainRunner.load_env_fn = @atari_lib.create_atari_environment SparseTrainRunner.agent_type = 'dqn' SparseTrainRunner.num_iterations = 40 SparseTrainRunner.training_steps = 250000 SparseTrainRunner.evaluation_steps = 125000 SparseTrainRunner.max_steps_per_episode = 27000 # Default max episode length. ImpalaNetwork.use_batch_norm = False ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_prune.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin' SparseDQNAgent.mode = 'prune' get_pruning_sparsities.target_sparsity = 0.95 get_pruning_sparsities.mask_init_method = 'erdos_renyi_kernel' pruning.initial_sparsity = 0.0 # 0.5M = 20% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). pruning.begin_step = 500000 # 500k # 2M = 80% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). pruning.end_step = 2000000 # 2M pruning.frequency = 5000 ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_prune_impala_net.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin' SparseDQNAgent.mode = 'prune' get_pruning_sparsities.target_sparsity = 0.95 get_pruning_sparsities.mask_init_method = 'erdos_renyi_kernel' pruning.initial_sparsity = 0.0 # 0.5M = 20% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). pruning.begin_step = 500000 # 500k # 2M = 80% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). 
pruning.end_step = 2000000 # 2M pruning.frequency = 5000 ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_rigl.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin' SparseDQNAgent.mode = 'rigl' # For sparse training methods we don't use the pruning library to update the # masks. Therefore we need to disable it. Following `pruning` flags serve that # purpose. pruning.final_sparsity = 0. pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts. init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.fixed_sparse_init = True init_masks.sparsity = 0.9 UpdatedRigLOptimizer.begin_step = 0 # 2M = 80% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). UpdatedRigLOptimizer.end_step = 2000000 UpdatedRigLOptimizer.frequency = 5000 UpdatedRigLOptimizer.drop_fraction_anneal = 'cosine' UpdatedRigLOptimizer.drop_fraction = 0.3 ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_rigl_impala_net.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin' SparseDQNAgent.mode = 'rigl' # For sparse training methods we don't use the pruning library to update the # masks. Therefore we need to disable it. Following `pruning` flags serve that # purpose. pruning.final_sparsity = 0. pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts. init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.fixed_sparse_init = True init_masks.sparsity = 0.9 UpdatedRigLOptimizer.begin_step = 0 # 2M = 80% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). UpdatedRigLOptimizer.end_step = 2000000 UpdatedRigLOptimizer.frequency = 5000 UpdatedRigLOptimizer.drop_fraction_anneal = 'cosine' UpdatedRigLOptimizer.drop_fraction = 0.3 ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_set.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin' SparseDQNAgent.mode = 'set' # For sparse training methods we don't use the pruning library to update the # masks. Therefore we need to disable it. Following `pruning` flags serve that # purpose. pruning.final_sparsity = 0. pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts. init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.fixed_sparse_init = True init_masks.sparsity = 0.9 UpdatedSETOptimizer.begin_step = 0 # 2M = 80% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). UpdatedSETOptimizer.end_step = 2000000 UpdatedSETOptimizer.frequency = 5000 UpdatedSETOptimizer.drop_fraction_anneal = 'cosine' UpdatedSETOptimizer.drop_fraction = 0.3 ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_set_impala_net.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin' SparseDQNAgent.mode = 'set' # For sparse training methods we don't use the pruning library to update the # masks. Therefore we need to disable it. Following `pruning` flags serve that # purpose. pruning.final_sparsity = 0. 
pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts. init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.fixed_sparse_init = True init_masks.sparsity = 0.9 UpdatedSETOptimizer.begin_step = 0 # 2M = 80% optimizer steps when training for 40M env steps with a frame skip # of 4 (= 10M transitions), and training every 4th env transition (2.5M train # steps in total). UpdatedSETOptimizer.end_step = 2000000 UpdatedSETOptimizer.frequency = 5000 UpdatedSETOptimizer.drop_fraction_anneal = 'cosine' UpdatedSETOptimizer.drop_fraction = 0.3 ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_static.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin' SparseDQNAgent.mode = 'static' # For sparse training methods we don't use the pruning library to update the # masks. Therefore we need to disable it. Following `pruning` flags serve that # purpose. pruning.final_sparsity = 0. pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts. init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.sparsity = 0.95 ================================================ FILE: rigl/rl/sparsetrain_configs/dqn_atari_static_impala_net.gin ================================================ include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin' SparseDQNAgent.mode = 'static' # For sparse training methods we don't use the pruning library to update the # masks. Therefore we need to disable it. Following `pruning` flags serve that # purpose. pruning.final_sparsity = 0. pruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts. init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.sparsity = 0.95 ================================================ FILE: rigl/rl/tfagents/configs/dqn_gym_dense_config.gin ================================================ # Configs to run DQN training for dense networks on classic control environments. train_eval.env_name='CartPole-v0' train_eval.fc_layer_params = (512, 512) train_eval.target_update_period = 100 train_eval.batch_size = 128 # Environment:train steps ratio is 1:1 train_eval.num_iterations = 100000 train_eval.weight_decay = 1e-6 train_eval.width = 1.0 train_eval.policy_save_interval = 10000 train_eval.epsilon_greedy = 0.01 train_eval.eval_interval = 2000 train_eval.eval_episodes = 20 train_eval.sparse_output_layer = False train_eval.train_mode = 'dense' mask_updater.update_alg = '' mask_updater.schedule_alg = '' log_snr.freq=5000 ================================================ FILE: rigl/rl/tfagents/configs/dqn_gym_pruning_config.gin ================================================ include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin' # Configs to run DQN training for pruning on classic control environments. 
train_eval.sparse_output_layer = True train_eval.train_mode = 'sparse' # This must be set to 0 when pruning to avoid # initializing the masks. init_masks.sparsity = 0.0 wrap_all_layers.mode = 'prune' wrap_all_layers.initial_sparsity = 0.0 wrap_all_layers.final_sparsity = 0.9 wrap_all_layers.mask_init_method = 'erdos_renyi_kernel' # Environment:train steps ratio is 1:1 # We start pruning after 20% training (20,000) and stop after 75% (75,000) wrap_all_layers.begin_step = 20000 wrap_all_layers.end_step = 75000 wrap_all_layers.frequency = 1000 log_sparsities.log_images = False ================================================ FILE: rigl/rl/tfagents/configs/dqn_gym_sparse_config.gin ================================================ include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin' # Configs to run DQN training for static, set, and rigl on classic control # environments. train_eval.sparse_output_layer = True train_eval.train_mode = 'sparse' init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.fixed_sparse_init = True init_masks.sparsity = 0.9 # For static, set this to '' # For rigl, set this to 'rigl' # For set, set this to 'set' mask_updater.update_alg = '' mask_updater.schedule_alg = 'cosine' mask_updater.update_freq = 1000 mask_updater.init_drop_fraction = 0.5 # Environment:train steps ratio is 1:1, we stop after 75% training = 75,000 mask_updater.last_update_step = 75000 mask_updater.use_stateless = False wrap_all_layers.mode = 'constant' log_sparsities.log_images = False ================================================ FILE: rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin ================================================ # Config to run PPO training for dense networks on mujoco environments. train_eval.env_name='HalfCheetah-v2' train_eval.actor_fc_layers = (64, 64) train_eval.value_fc_layers = (64, 64) # In order to execute ~1M environment steps, we run 489 iterations # (`--num_iterations=489`), which results in 1,001,472 environment steps. Each # iteration results in 320 training steps (or 320 gradient updates; this is # calculated from environment_steps * num_epochs / minibatch_size) and 2,048 # environment steps. Thus 489 * 2,048 = 1,001,472 environment steps and # 489 * 320 = 156,480 training steps.
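# Worked check of the arithmetic above: each iteration collects 2,048
# environment steps and trains for num_epochs = 10 passes over them, i.e.
# 2,048 * 10 = 20,480 gradient-sample passes per iteration; at 320 gradient
# updates per iteration this implies a minibatch size of 20,480 / 320 = 64.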
train_eval.num_iterations = 489 train_eval.weight_decay = 1e-6 train_eval.width = 1.0 train_eval.policy_save_interval = 51000 train_eval.num_epochs = 10 train_eval.eval_interval = 2000 train_eval.eval_episodes = 20 train_eval.sparse_output_layer = False train_eval.train_mode_actor = 'dense' train_eval.train_mode_value = 'dense' mask_updater.update_alg = '' mask_updater.schedule_alg = '' log_snr.freq=5000 ================================================ FILE: rigl/rl/tfagents/configs/ppo_mujoco_pruning_config.gin ================================================ include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin' train_eval.sparse_output_layer = True train_eval.train_mode_actor = 'sparse' train_eval.train_mode_value = 'sparse' # This must be set to 0 when pruning to avoid # initializing the masks init_masks.sparsity = 0.0 wrap_all_layers.mode = 'prune' wrap_all_layers.initial_sparsity = 0.0 wrap_all_layers.final_sparsity = 0.9 wrap_all_layers.mask_init_method = 'erdos_renyi_kernel' # 156,480 steps total # Start at ~20% = 31,296 # End at ~75% = 117,360 wrap_all_layers.begin_step = 32000 wrap_all_layers.end_step = 120000 wrap_all_layers.frequency = 500 log_sparsities.log_images = False ================================================ FILE: rigl/rl/tfagents/configs/ppo_mujoco_sparse_config.gin ================================================ include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin' # Config to run PPO training for static, set, and rigl on mujoco environments. train_eval.sparse_output_layer = True train_eval.train_mode_actor = 'sparse' train_eval.train_mode_value = 'sparse' train_eval.weight_decay = 1e-4 init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.fixed_sparse_init = True init_masks.sparsity = 0.9 # For static, set this to '' # For rigl set this to 'rigl' # For set set this to 'set' mask_updater.update_alg = '' mask_updater.schedule_alg = 'cosine' mask_updater.update_freq = 250 mask_updater.init_drop_fraction = 0.3 # 156,480 steps total, end at 75% = 117,360 mask_updater.last_update_step = 120000 mask_updater.use_stateless = False wrap_all_layers.mode = 'constant' log_sparsities.log_images = False ================================================ FILE: rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin ================================================ # Config to run SAC training for dense on mujoco environments. train_eval.env_name = 'Humanoid-v2' train_eval.initial_collect_steps = 1000 train_eval.num_iterations = 1000000 # 1M train_eval.width = 1.0 train_eval.weight_decay = 1e-4 ================================================ FILE: rigl/rl/tfagents/configs/sac_mujoco_pruning_config.gin ================================================ include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin' # Configs to run SAC training for pruning on mujoco environments. 
train_eval.train_mode_actor = 'sparse' # Both critics train_eval.train_mode_value = 'sparse' train_eval.sparse_output_layer = True init_masks.fixed_sparse_init = True # This must be set to 0 when pruning to avoid # initializing the masks init_masks.sparsity = 0.0 wrap_all_layers.mode = 'prune' wrap_all_layers.initial_sparsity = 0.0 wrap_all_layers.final_sparsity = 0.9 wrap_all_layers.mask_init_method = 'erdos_renyi_kernel' # 1M steps total # Start at 20%, end at 80% wrap_all_layers.begin_step = 200000 wrap_all_layers.end_step = 800000 wrap_all_layers.frequency = 1000 log_sparsities.log_images = False ================================================ FILE: rigl/rl/tfagents/configs/sac_mujoco_sparse_config.gin ================================================ include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin' # Configs to run SAC training for static, set, and rigl on mujoco # environments. train_eval.sparse_output_layer = True train_eval.train_mode_actor = 'sparse' # Both critics train_eval.train_mode_value = 'sparse' train_eval.actor_critic_sparsities_str = '' train_eval.weight_decay = 1e-6 init_masks.mask_init_method = 'erdos_renyi_kernel' init_masks.fixed_sparse_init = True init_masks.sparsity = 0.9 mask_updater.update_alg = '' mask_updater.schedule_alg = 'cosine' mask_updater.update_freq = 1000 mask_updater.init_drop_fraction = 0.5 # 1M / train_eval.num_iterations * 0.8 mask_updater.last_update_step = 800000 mask_updater.use_stateless = False wrap_all_layers.mode = 'constant' log_sparsities.log_images = False ================================================ FILE: rigl/rl/tfagents/dqn_train_eval.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Sparse training DQN using actor/learner in a gym environment. 
""" import functools import os from typing import Tuple from absl import app from absl import flags from absl import logging import gin import numpy as np import reverb from rigl.rigl_tf2 import mask_updaters from rigl.rl import sparse_utils from rigl.rl.tfagents import tf_sparse_utils import tensorflow.compat.v2 as tf from tf_agents.agents.dqn import dqn_agent from tf_agents.environments import suite_atari from tf_agents.environments import suite_gym from tf_agents.metrics import py_metrics from tf_agents.networks import sequential from tf_agents.policies import py_tf_eager_policy from tf_agents.policies import random_py_policy from tf_agents.replay_buffers import reverb_replay_buffer from tf_agents.replay_buffers import reverb_utils from tf_agents.specs import tensor_spec from tf_agents.system import system_multiprocessing as multiprocessing from tf_agents.train import actor from tf_agents.train import learner from tf_agents.train import triggers from tf_agents.train.utils import train_utils from tf_agents.utils import common from tf_agents.utils import eager_utils FLAGS = flags.FLAGS flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), 'Root directory for writing logs/summaries/checkpoints.') flags.DEFINE_integer( 'reverb_port', None, 'Port for reverb server, if None, use a randomly chosen unused port.') flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.') flags.DEFINE_multi_string( 'gin_bindings', [], 'Gin bindings to override the values set in the config files ' '(e.g. "train_eval.env_name=Acrobot-v1",' ' "init_masks.sparsity=0.9").') flags.DEFINE_float( 'average_last_fraction', 0.1, 'Tells what fraction latest evaluation scores are averaged. This is used' ' to reduce variance.') @gin.configurable class SparseDqnAgent(dqn_agent.DqnAgent): """Wrapped DqnAgent that supports sparse training.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) _ = sparse_utils.init_masks(self._q_network) def loss_fn(experience_data, weights_data): # The following is just to fit to the existing API. loss_info = self._loss( experience_data, td_errors_loss_fn=self._td_errors_loss_fn, gamma=self._gamma, reward_scale_factor=self._reward_scale_factor, weights=weights_data, training=True) return loss_info.extra.td_loss # Create mask updater if doesn't exists self._mask_updater = mask_updaters.get_mask_updater( self._q_network, self._optimizer, loss_fn) def _train(self, experience, weights): tf.compat.v2.summary.experimental.set_step(self.train_step_counter) tf_sparse_utils.update_prune_step(self._q_network, self._train_step_counter) with tf.GradientTape(persistent=True) as tape: loss_info = self._loss( experience, td_errors_loss_fn=self._td_errors_loss_fn, gamma=self._gamma, reward_scale_factor=self._reward_scale_factor, weights=weights, training=True) tf.debugging.check_numerics(loss_info.loss, 'Loss is inf or nan') variables_to_train = self._q_network.trainable_weights non_trainable_weights = self._q_network.non_trainable_weights assert list(variables_to_train), "No variables in the agent's q_network." grads = tape.gradient(loss_info.loss, variables_to_train) tf_sparse_utils.log_snr(tape, loss_info.extra.td_loss, self.train_step_counter, variables_to_train) # Tuple is used for py3, where zip is a generator producing values once. grads_and_vars = list(zip(grads, variables_to_train)) def _mask_update_step(): # Second argument is not used. 
self._mask_updater.set_validation_data(experience, weights) self._mask_updater.update(self.train_step_counter) with tf.name_scope('/'): tf.summary.scalar( name='drop_fraction', data=self._mask_updater.last_drop_fraction) tf_sparse_utils.log_sparsities(self._q_network) if self._mask_updater is not None: is_update = self._mask_updater.is_update_iter(self.train_step_counter) tf.cond(is_update, _mask_update_step, lambda: None) if self._gradient_clipping is not None: grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars, self._gradient_clipping) if self._summarize_grads_and_vars: grads_and_vars_with_non_trainable = ( grads_and_vars + [(None, v) for v in non_trainable_weights]) eager_utils.add_variables_summaries(grads_and_vars_with_non_trainable, self.train_step_counter) eager_utils.add_gradients_summaries(grads_and_vars, self.train_step_counter) self._optimizer.apply_gradients(grads_and_vars) self.train_step_counter.assign_add(1) self._update_target() return loss_info def _scale_width(num_units, width): assert width > 0 return int(max(1, num_units * width)) def build_network( fc_layer_params, num_actions, is_sparse, input_dim, width = 1.0, weight_decay = 0.0, sparse_output_layer = True ): """Builds a Sequential model.""" def dense_layer(num_units): return tf.keras.layers.Dense( num_units, activation=tf.keras.activations.relu, kernel_initializer=tf.keras.initializers.VarianceScaling( scale=2.0, mode='fan_in', distribution='truncated_normal'), kernel_regularizer=tf.keras.regularizers.L2(weight_decay),) # QNetwork consists of a sequence of Dense layers followed by a dense layer # with `num_actions` units to generate one q_value per available action as # its output. all_layers = [ dense_layer(_scale_width(num_units, width=width) ) for num_units in fc_layer_params] all_layers.append( tf.keras.layers.Dense( num_actions, activation=None, kernel_initializer=tf.keras.initializers.RandomUniform( minval=-0.03, maxval=0.03), bias_initializer=tf.keras.initializers.Constant(-0.2))) if is_sparse: if sparse_output_layer: all_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim) else: all_layers = (tf_sparse_utils.wrap_all_layers(all_layers[:-1], input_dim) + all_layers[-1:]) return sequential.Sequential(all_layers) @gin.configurable def train_eval( root_dir, env_name='CartPole-v0', # Training params update_frequency=1, initial_collect_steps=1000, num_iterations=100000, fc_layer_params=(100,), # Agent params epsilon_greedy=0.1, epsilon_decay_period=250000, batch_size=64, learning_rate=1e-3, n_step_update=1, gamma=0.99, target_update_tau=1.0, target_update_period=100, reward_scale_factor=1.0, # Replay params reverb_port=None, replay_capacity=100000, # Others policy_save_interval=1000, eval_interval=1000, eval_episodes=10, weight_decay = 0.0, width = 1.0, debug_summaries=False, sparse_output_layer=True, train_mode='dense'): """Trains and evaluates DQN.""" logging.info('DQN params: Fc layer params: %s', fc_layer_params) logging.info('DQN params: Train mode: %s', train_mode) logging.info('DQN params: Target update period: %s', target_update_period) logging.info('DQN params: Policy save interval: %s', policy_save_interval) logging.info('DQN params: Eval interval: %s', eval_interval) logging.info('DQN params: Environment name: %s', env_name) logging.info('DQN params: Weight decay: %s', weight_decay) logging.info('DQN params: Width: %s', width) logging.info('DQN params: Batch size: %s', batch_size) logging.info('DQN params: Target update period: %s', target_update_period) logging.info('DQN params: 
Learning rate: %s', learning_rate) logging.info('DQN params: Num iterations: %s', num_iterations) logging.info('DQN params: Sparse output layer: %s', sparse_output_layer) collect_env = suite_gym.load(env_name) eval_env = suite_gym.load(env_name) logging.info('Collect env: %s', collect_env) logging.info('Eval env: %s', eval_env) time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec()) action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec()) train_step = train_utils.create_train_step() num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1 observation_shape = collect_env.observation_spec().shape # Build network and get pruning params is_atari = False if not is_atari: q_net = build_network( fc_layer_params=fc_layer_params, num_actions=num_actions, is_sparse=(train_mode == 'sparse'), # observation_shape is 1-dimensional. We need this so that we can # calculate the dimensions of the first layer. input_dim=observation_shape[-1], width=width, weight_decay=weight_decay, sparse_output_layer=sparse_output_layer) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) loss = common.element_wise_squared_loss decay_fn = epsilon_greedy agent = SparseDqnAgent( time_step_tensor_spec, action_tensor_spec, q_network=q_net, epsilon_greedy=decay_fn, n_step_update=n_step_update, target_update_tau=target_update_tau, target_update_period=target_update_period, optimizer=optimizer, td_errors_loss_fn=loss, gamma=gamma, reward_scale_factor=reward_scale_factor, train_step_counter=train_step, debug_summaries=debug_summaries) table_name = 'uniform_table' table = reverb.Table( table_name, max_size=replay_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1)) reverb_server = reverb.Server([table], port=reverb_port) reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=2, table_name=table_name, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddTrajectoryObserver( reverb_replay.py_client, table_name, sequence_length=2, stride_length=1) dataset = reverb_replay.as_dataset( num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2).prefetch(3) experience_dataset_fn = lambda: dataset saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}), triggers.StepPerSecondLogTrigger(train_step, interval=100), ] dqn_learner = learner.Learner( root_dir, train_step, agent, experience_dataset_fn, triggers=learning_triggers, run_optimizer_variable_init=False) # If we haven't trained yet make sure we collect some random samples first to # fill up the Replay Buffer with some experience. 
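# A quick standalone check (illustrative) of the _scale_width() helper defined
# earlier in this file: widths scale multiplicatively, and the max(1, ...)
# guard keeps every layer at least one unit wide for very small multipliers.
assert int(max(1, 512 * 0.25)) == 128  # same arithmetic as _scale_width(512, 0.25)
assert int(max(1, 4 * 0.1)) == 1       # guard prevents zero-width layers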
random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(), collect_env.action_spec()) initial_collect_actor = actor.Actor( collect_env, random_policy, train_step, steps_per_run=initial_collect_steps, observers=[rb_observer]) logging.info('Doing initial collect.') initial_collect_actor.run() tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor( collect_env, collect_policy, train_step, steps_per_run=update_frequency, observers=[rb_observer, env_step_metric], metrics=actor.collect_metrics(10), reference_metrics=[env_step_metric], summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), ) tf_greedy_policy = agent.policy greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy, use_tf_function=True) eval_actor = actor.Actor( eval_env, greedy_policy, train_step, episodes_per_run=eval_episodes, metrics=actor.eval_metrics(eval_episodes), reference_metrics=[env_step_metric], summary_dir=os.path.join(root_dir, 'eval'), ) average_returns = [] if eval_interval: logging.info('Evaluating.') eval_actor.run_and_log() for metric in eval_actor.metrics: if isinstance(metric, py_metrics.AverageReturnMetric): average_returns.append(metric._buffer.mean()) logging.info('Training.') for _ in range(num_iterations): collect_actor.run() dqn_learner.run(iterations=1) if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0: logging.info('Evaluating.') eval_actor.run_and_log() for metric in eval_actor.metrics: if isinstance(metric, py_metrics.AverageReturnMetric): average_returns.append(metric._buffer.mean()) # Log last section of evaluation scores for the final metric. idx = int(FLAGS.average_last_fraction * len(average_returns)) avg_return = np.mean(average_returns[-idx:]) logging.info('Step %d, Average Return: %f', env_step_metric.result(), avg_return) rb_observer.close() reverb_server.stop() def main(_): tf.config.experimental_run_functions_eagerly(False) logging.set_verbosity(logging.INFO) tf.enable_v2_behavior() gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings) logging.info('Gin bindings: %s', FLAGS.gin_bindings) train_eval( FLAGS.root_dir, reverb_port=FLAGS.reverb_port) if __name__ == '__main__': flags.mark_flag_as_required('root_dir') multiprocessing.handle_main(functools.partial(app.run, main)) ================================================ FILE: rigl/rl/tfagents/ppo_train_eval.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Sparse training PPO using actor/learner in a gym environment. 
""" import collections import functools import os from typing import Optional from absl import app from absl import flags from absl import logging import gin import numpy as np import reverb from rigl.rigl_tf2 import mask_updaters from rigl.rl import sparse_utils from rigl.rl.tfagents import sparse_ppo_actor_network from rigl.rl.tfagents import sparse_ppo_discrete_actor_network from rigl.rl.tfagents import sparse_value_network from rigl.rl.tfagents import tf_sparse_utils import tensorflow.compat.v2 as tf from tf_agents.agents import tf_agent from tf_agents.agents.ppo import ppo_clip_agent from tf_agents.agents.ppo import ppo_utils from tf_agents.environments import suite_gym from tf_agents.environments import suite_mujoco from tf_agents.metrics import py_metrics from tf_agents.networks import network from tf_agents.policies import py_tf_eager_policy from tf_agents.replay_buffers import reverb_replay_buffer from tf_agents.replay_buffers import reverb_utils from tf_agents.specs import tensor_spec from tf_agents.system import system_multiprocessing as multiprocessing from tf_agents.train import actor from tf_agents.train import learner from tf_agents.train import ppo_learner from tf_agents.train import triggers from tf_agents.train.utils import spec_utils from tf_agents.train.utils import train_utils from tf_agents.trajectories import time_step as ts from tf_agents.typing import types from tf_agents.utils import common from tf_agents.utils import eager_utils from tf_agents.utils import nest_utils from tf_agents.utils import object_identity FLAGS = flags.FLAGS flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), 'Root directory for writing logs/summaries/checkpoints.') flags.DEFINE_integer( 'reverb_port', None, 'Port for reverb server, if None, use a randomly chosen unused port.') flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.') flags.DEFINE_multi_string( 'gin_bindings', [], 'Gin bindings to override the values set in the config files ' '(e.g. "train_eval.env_name=Acrobot-v1",' ' "init_masks.sparsity=0.9").') # Env params flags.DEFINE_bool('is_atari', False, 'Whether the env is an atari game.') flags.DEFINE_bool('is_mujoco', False, 'Whether the env is a mujoco game.') flags.DEFINE_bool('is_classic', False, 'Whether the env is a classic control game.') flags.DEFINE_float( 'average_last_fraction', 0.1, 'Tells what fraction latest evaluation scores are averaged. This is used' ' to reduce variance.') SparsePPOLossInfo = collections.namedtuple('SparsePPOLossInfo', ( 'policy_gradient_loss', 'value_estimation_loss', 'l2_regularization_loss', 'entropy_regularization_loss', 'kl_penalty_loss', 'total_loss_per_sample', )) def _normalize_advantages(advantages, axes=(0,), variance_epsilon=1e-8): adv_mean, adv_var = tf.nn.moments(advantages, axes=axes, keepdims=True) normalized_advantages = tf.nn.batch_normalization( advantages, adv_mean, adv_var, offset=None, scale=None, variance_epsilon=variance_epsilon) return normalized_advantages @gin.configurable class SparsePPOAgent(ppo_clip_agent.PPOClipAgent): """Wrapped PPOClipAgent that supports sparse training.""" def __init__(self, *args, policy_l2_reg=0.0, value_function_l2_reg=0.0, shared_vars_l2_reg=0.0, **kwargs): super().__init__(*args, policy_l2_reg=policy_l2_reg, value_function_l2_reg=value_function_l2_reg, shared_vars_l2_reg=shared_vars_l2_reg, **kwargs) # Name scoping has been removed here so # debug_summaries are permenantly disabled. To restore with proper # scoping. 
self._debug_summaries = False # Pruning layer requires the pruning_step to be >1 during forward pass. tf_sparse_utils.update_prune_step( self._actor_net, self.train_step_counter + 1) tf_sparse_utils.update_prune_step( self._value_net, self.train_step_counter + 1) _ = sparse_utils.init_masks(self._actor_net) _ = sparse_utils.init_masks(self._value_net) # BEGIN: sparse training create mask updaters def loss_fn(experience_data, weights_data): # The following is just to fit to the existing API. (time_steps, actions, old_act_log_probs, returns, normalized_advantages, old_action_distribution_parameters, masked_weights, old_value_predictions) = self._process_experience_weights( experience_data, weights_data) loss_info = self.get_loss( time_steps, actions, old_act_log_probs, returns, normalized_advantages, old_action_distribution_parameters, masked_weights, self.train_step_counter, False, old_value_predictions=old_value_predictions, training=True) return loss_info.extra.total_loss_per_sample self._mask_updater_actor = mask_updaters.get_mask_updater( self._actor_net, self._optimizer, loss_fn) self._mask_updater_value = mask_updaters.get_mask_updater( self._value_net, self._optimizer, loss_fn) # END: sparse training create mask updaters logging.info('SparsePPOAgent: policy_l2_reg %.5f.', policy_l2_reg) logging.info('SparsePPOAgent: value_function_l2_reg %.5f.', value_function_l2_reg) logging.info('SparsePPOAgent: shared_vars_l2_reg %.5f.', shared_vars_l2_reg) def _process_experience_weights(self, experience, weights): experience = self._as_trajectory(experience) if self._compute_value_and_advantage_in_train: processed_experience = self._preprocess(experience) else: processed_experience = experience # Mask trajectories that cannot be used for training. valid_mask = ppo_utils.make_trajectory_mask(processed_experience) if weights is None: masked_weights = valid_mask else: masked_weights = weights * valid_mask # Reconstruct per-timestep policy distribution from stored distribution # parameters. old_action_distribution_parameters = processed_experience.policy_info[ 'dist_params'] old_actions_distribution = ( ppo_utils.distribution_from_spec( self._action_distribution_spec, old_action_distribution_parameters, legacy_distribution_network=isinstance( self._actor_net, network.DistributionNetwork))) # Compute log probability of actions taken during data collection, using the # collect policy distribution. 
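# A standalone sketch (illustrative; assumes tensorflow_probability, which
# tf-agents depends on) of why these old log-probs are computed: PPO's
# importance ratio is exp(log_prob_new - log_prob_old), so only the log
# probability of each action under the data-collecting policy must be kept.
import tensorflow as tf
import tensorflow_probability as tfp

old_dist = tfp.distributions.Normal(loc=0.0, scale=1.0)
new_dist = tfp.distributions.Normal(loc=0.1, scale=1.0)
action = tf.constant(0.2)
ratio = tf.exp(new_dist.log_prob(action) - old_dist.log_prob(action))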
old_act_log_probs = common.log_probability(old_actions_distribution, processed_experience.action, self._action_spec) if self._debug_summaries and not tf.config.list_logical_devices('TPU'): actions_list = tf.nest.flatten(processed_experience.action) show_action_index = len(actions_list) != 1 for i, single_action in enumerate(actions_list): action_name = ('actions_{}'.format(i) if show_action_index else 'actions') tf.compat.v2.summary.histogram( name=action_name, data=single_action, step=self.train_step_counter) time_steps = ts.TimeStep( step_type=processed_experience.step_type, reward=processed_experience.reward, discount=processed_experience.discount, observation=processed_experience.observation) actions = processed_experience.action returns = processed_experience.policy_info['return'] advantages = processed_experience.policy_info['advantage'] normalized_advantages = _normalize_advantages(advantages, variance_epsilon=1e-8) if self._debug_summaries and not tf.config.list_logical_devices('TPU'): tf.compat.v2.summary.histogram( name='advantages_normalized', data=normalized_advantages, step=self.train_step_counter) old_value_predictions = processed_experience.policy_info['value_prediction'] return (time_steps, actions, old_act_log_probs, returns, normalized_advantages, old_action_distribution_parameters, masked_weights, old_value_predictions) def _train(self, experience, weights): tf.compat.v2.summary.experimental.set_step(self.train_step_counter) (time_steps, actions, old_act_log_probs, returns, normalized_advantages, old_action_distribution_parameters, masked_weights, old_value_predictions) = self._process_experience_weights( experience, weights) if self._compute_value_and_advantage_in_train: processed_experience = self._preprocess(experience) else: processed_experience = experience batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0] # Loss tensors across batches will be aggregated for summaries. policy_gradient_losses = [] value_estimation_losses = [] l2_regularization_losses = [] entropy_regularization_losses = [] kl_penalty_losses = [] loss_info = None variables_to_train = list( object_identity.ObjectIdentitySet(self._actor_net.trainable_weights + self._value_net.trainable_weights)) # Sort to ensure tensors on different processes end up in the same order. variables_to_train = sorted(variables_to_train, key=lambda x: x.name) for _ in range(self._num_epochs): # Name scoping has been removed here, so debug_summaries are permanently # disabled; to be restored with proper scoping. debug_summaries = False with tf.GradientTape(persistent=True) as tape: loss_info = self.get_loss( time_steps, actions, old_act_log_probs, returns, normalized_advantages, old_action_distribution_parameters, masked_weights, self.train_step_counter, debug_summaries, old_value_predictions=old_value_predictions, training=True) grads = tape.gradient(loss_info.loss, variables_to_train) tf_sparse_utils.log_snr(tape, loss_info.extra.total_loss_per_sample, self.train_step_counter, variables_to_train) # BEGIN sparse training mask update # We use the latest set of gradients to update the masks for sparse # training. Note that we do this before gradient clipping.
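# A standalone sketch (illustrative, not the repo's mask_updaters code) of why
# pre-clipping gradients matter here: RigL-style updates grow the inactive
# connections with the largest gradient magnitude, so clipping the gradients
# first would distort that ranking.
import tensorflow as tf

def grow_candidates(grad, mask, k):
  # Score only inactive positions by |gradient|; active positions get -inf so
  # top_k never selects them.
  neg_inf = tf.fill(tf.shape(grad), -float('inf'))
  scores = tf.where(mask > 0, neg_inf, tf.abs(grad))
  _, idx = tf.math.top_k(tf.reshape(scores, [-1]), k=k)
  return idx  # flat indices of the k connections to activate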
def _mask_update_step(mask_updater, updater_name): mask_updater.set_validation_data(experience, weights) mask_updater.update(self.train_step_counter) with tf.name_scope('Drop_fraction/'): tf.summary.scalar( name=f'{updater_name}', data=mask_updater.last_drop_fraction) mask_update_step_actor = functools.partial( _mask_update_step, self._mask_updater_actor, 'actor') mask_update_step_value = functools.partial( _mask_update_step, self._mask_updater_value, 'value') tf_sparse_utils.log_sparsities(self._actor_net, 'actor') tf_sparse_utils.log_sparsities(self._value_net, 'value') tf_sparse_utils.log_total_params([self._actor_net, self._value_net]) if self._mask_updater_actor is not None: is_update_actor = self._mask_updater_actor.is_update_iter( self.train_step_counter) tf.cond(is_update_actor, mask_update_step_actor, lambda: None) if self._mask_updater_value is not None: is_update_value = self._mask_updater_value.is_update_iter( self.train_step_counter) tf.cond(is_update_value, mask_update_step_value, lambda: None) # END sparse training mask update if self._gradient_clipping > 0: grads, _ = tf.clip_by_global_norm(grads, self._gradient_clipping) # Tuple is used for py3, where zip is a generator producing values once. grads_and_vars = tuple(zip(grads, variables_to_train)) # If summarize_gradients, create functions for summarizing both # gradients and variables. if self._summarize_grads_and_vars and debug_summaries: eager_utils.add_gradients_summaries(grads_and_vars, self.train_step_counter) eager_utils.add_variables_summaries(grads_and_vars, self.train_step_counter) self._optimizer.apply_gradients(grads_and_vars) self.train_step_counter.assign_add(1) policy_gradient_losses.append(loss_info.extra.policy_gradient_loss) value_estimation_losses.append(loss_info.extra.value_estimation_loss) l2_regularization_losses.append(loss_info.extra.l2_regularization_loss) entropy_regularization_losses.append( loss_info.extra.entropy_regularization_loss) kl_penalty_losses.append(loss_info.extra.kl_penalty_loss) if self._initial_adaptive_kl_beta > 0: # After update epochs, update adaptive kl beta, then update observation # normalizer and reward normalizer. policy_state = self._collect_policy.get_initial_state(batch_size) # Compute the mean kl from previous action distribution. kl_divergence = self._kl_divergence( time_steps, old_action_distribution_parameters, self._collect_policy.distribution(time_steps, policy_state).action) self.update_adaptive_kl_beta(kl_divergence) if self.update_normalizers_in_train: self.update_observation_normalizer(time_steps.observation) self.update_reward_normalizer(processed_experience.reward) loss_info = tf.nest.map_structure(tf.identity, loss_info) # Make summaries for total loss averaged across all epochs. # The *_losses lists will have been populated by # calls to self.get_loss. Assumes all the losses have same length. 
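# A tiny standalone illustration (not repo code) of the epoch averaging below:
# tf.add_n sums a list of scalar tensors, so dividing by the list length gives
# the mean loss over the num_epochs optimizer passes.
import tensorflow as tf

epoch_losses = [tf.constant(0.5), tf.constant(0.7)]
mean_loss = tf.add_n(epoch_losses) / len(epoch_losses)  # 0.6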
with tf.name_scope('Losses/'): num_epochs = len(policy_gradient_losses) total_policy_gradient_loss = tf.add_n(policy_gradient_losses) / num_epochs total_value_estimation_loss = tf.add_n( value_estimation_losses) / num_epochs total_l2_regularization_loss = tf.add_n( l2_regularization_losses) / num_epochs total_entropy_regularization_loss = tf.add_n( entropy_regularization_losses) / num_epochs total_kl_penalty_loss = tf.add_n(kl_penalty_losses) / num_epochs tf.compat.v2.summary.scalar( name='policy_gradient_loss', data=total_policy_gradient_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='value_estimation_loss', data=total_value_estimation_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='l2_regularization_loss', data=total_l2_regularization_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='entropy_regularization_loss', data=total_entropy_regularization_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='kl_penalty_loss', data=total_kl_penalty_loss, step=self.train_step_counter) total_abs_loss = ( tf.abs(total_policy_gradient_loss) + tf.abs(total_value_estimation_loss) + tf.abs(total_entropy_regularization_loss) + tf.abs(total_l2_regularization_loss) + tf.abs(total_kl_penalty_loss)) tf.compat.v2.summary.scalar( name='total_abs_loss', data=total_abs_loss, step=self.train_step_counter) with tf.name_scope('LearningRate/'): learning_rate = ppo_utils.get_learning_rate(self._optimizer) tf.compat.v2.summary.scalar( name='learning_rate', data=learning_rate, step=self.train_step_counter) if self._summarize_grads_and_vars and not tf.config.list_logical_devices( 'TPU'): with tf.name_scope('Variables/'): all_vars = ( self._actor_net.trainable_weights + self._value_net.trainable_weights) for var in all_vars: tf.compat.v2.summary.histogram( name=var.name.replace(':', '_'), data=var, step=self.train_step_counter) return loss_info def get_loss(self, time_steps, actions, act_log_probs, returns, normalized_advantages, action_distribution_parameters, weights, train_step, debug_summaries, old_value_predictions = None, training = False): """Compute the loss and create the optimization op for one training epoch. All tensors should have a single batch dimension. Args: time_steps: A minibatch of TimeStep tuples. actions: A minibatch of actions. act_log_probs: A minibatch of action log probabilities (under the sampling policy). returns: A minibatch of per-timestep returns. normalized_advantages: A minibatch of normalized per-timestep advantages. action_distribution_parameters: Parameters of data-collecting action distribution. Needed for KL computation. weights: Optional scalar or element-wise (per-batch-entry) importance weights. Includes a mask for invalid timesteps. train_step: A train_step variable to increment for each train step. Typically the global_step. debug_summaries: True if debug summaries should be created. old_value_predictions: (Optional) The saved value predictions, used for calculating the value estimation loss when value clipping is performed. training: Whether this loss is being used for training. Returns: A tf_agent.LossInfo named tuple with the total_loss and all intermediate losses in the extra field, contained in a SparsePPOLossInfo named tuple. """ # Evaluate the current policy on timesteps.
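# Overview: the scalar losses computed below are summed into total_loss, and # their per-sample counterparts into total_loss_per_sample (consumed by # tf_sparse_utils.log_snr in _train). Scalar-only regularizers are spread # across the batch with tf.repeat; e.g. (hypothetical numbers) a scalar loss # of 2.0 with batch_size 4 becomes the per-sample vector [0.5, 0.5, 0.5, 0.5].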
# batch_size from time_steps batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0] policy_state = self._collect_policy.get_initial_state(batch_size) # We must use _distribution because the distribution API doesn't pass down # the training= kwarg. distribution_step = self._collect_policy._distribution( time_steps, policy_state, training=training) current_policy_distribution = distribution_step.action # Call all loss functions and add all loss values. (value_estimation_loss, value_estimation_loss_per_sample) = self.value_estimation_loss( time_steps=time_steps, returns=returns, old_value_predictions=old_value_predictions, weights=weights, debug_summaries=debug_summaries, training=training) (policy_gradient_loss, policy_gradient_loss_per_sample) = self.policy_gradient_loss( time_steps, actions, tf.stop_gradient(act_log_probs), tf.stop_gradient(normalized_advantages), current_policy_distribution, weights, debug_summaries=debug_summaries) if (self._policy_l2_reg > 0.0 or self._value_function_l2_reg > 0.0 or self._shared_vars_l2_reg > 0.0): l2_regularization_loss = self.l2_regularization_loss(debug_summaries) else: l2_regularization_loss = tf.zeros_like(policy_gradient_loss) l2_regularization_loss_per_sample = tf.repeat( l2_regularization_loss / tf.cast(batch_size, tf.float32), batch_size) if self._entropy_regularization > 0.0: (entropy_regularization_loss, entropy_regularization_loss_per_sample ) = self.entropy_regularization_loss(time_steps, current_policy_distribution, weights, debug_summaries) else: entropy_regularization_loss = tf.zeros_like(policy_gradient_loss) entropy_regularization_loss_per_sample = tf.repeat( tf.constant(0, dtype=tf.float32), batch_size) if self._initial_adaptive_kl_beta == 0: kl_penalty_loss = tf.zeros_like(policy_gradient_loss) else: kl_penalty_loss = self.kl_penalty_loss(time_steps, action_distribution_parameters, current_policy_distribution, weights, debug_summaries) kl_penalty_loss_per_sample = tf.repeat( kl_penalty_loss / tf.cast(batch_size, tf.float32), batch_size) total_loss = ( policy_gradient_loss + value_estimation_loss + l2_regularization_loss + entropy_regularization_loss + kl_penalty_loss) total_loss_per_sample = ( policy_gradient_loss_per_sample + value_estimation_loss_per_sample + l2_regularization_loss_per_sample + entropy_regularization_loss_per_sample + kl_penalty_loss_per_sample) return tf_agent.LossInfo( total_loss, SparsePPOLossInfo( policy_gradient_loss=policy_gradient_loss, value_estimation_loss=value_estimation_loss, l2_regularization_loss=l2_regularization_loss, entropy_regularization_loss=entropy_regularization_loss, kl_penalty_loss=kl_penalty_loss, total_loss_per_sample=total_loss_per_sample )) def value_estimation_loss(self, time_steps, returns, weights, old_value_predictions = None, debug_summaries = False, training = False): """Computes the value estimation loss for actor-critic training. All tensors should have a single batch dimension. Args: time_steps: A batch of timesteps. returns: Per-timestep returns for value function to predict. (Should come from TD-lambda computation.) weights: Optional scalar or element-wise (per-batch-entry) importance weights. Includes a mask for invalid timesteps. old_value_predictions: (Optional) The saved value predictions from policy_info, required when self._value_clipping > 0. debug_summaries: True if debug summaries should be created. training: Whether this loss is going to be used for training. Returns: value_estimation_loss: A scalar value_estimation_loss loss. 
Raises: ValueError: If old_value_predictions was not passed in, but value clipping was performed. """ observation = time_steps.observation if debug_summaries and not tf.config.list_logical_devices('TPU'): observation_list = tf.nest.flatten(observation) show_observation_index = len(observation_list) != 1 for i, single_observation in enumerate(observation_list): observation_name = ('observations_{}'.format(i) if show_observation_index else 'observations') tf.compat.v2.summary.histogram( name=observation_name, data=single_observation, step=self.train_step_counter) batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0] value_state = self._collect_policy.get_initial_value_state(batch_size) value_preds, _ = self._collect_policy.apply_value_network( time_steps.observation, time_steps.step_type, value_state=value_state, training=training) value_estimation_error = tf.math.squared_difference(returns, value_preds) if self._value_clipping > 0: if old_value_predictions is None: raise ValueError( 'old_value_predictions is None but needed for value clipping.') clipped_value_preds = old_value_predictions + tf.clip_by_value( value_preds - old_value_predictions, -self._value_clipping, self._value_clipping) clipped_value_estimation_error = tf.math.squared_difference( returns, clipped_value_preds) value_estimation_error = tf.maximum(value_estimation_error, clipped_value_estimation_error) if self._aggregate_losses_across_replicas: value_estimation_loss = ( common.aggregate_losses( per_example_loss=value_estimation_error, sample_weight=weights).total_loss * self._value_pred_loss_coef) else: value_estimation_loss = tf.math.reduce_mean( value_estimation_error * weights) * self._value_pred_loss_coef value_estimation_loss_per_sample = tf.reduce_mean(value_estimation_error, axis=0) if debug_summaries: tf.compat.v2.summary.scalar( name='value_pred_avg', data=tf.reduce_mean(input_tensor=value_preds), step=self.train_step_counter) tf.compat.v2.summary.scalar( name='value_actual_avg', data=tf.reduce_mean(input_tensor=returns), step=self.train_step_counter) tf.compat.v2.summary.scalar( name='value_estimation_loss', data=value_estimation_loss, step=self.train_step_counter) if not tf.config.list_logical_devices('TPU'): tf.compat.v2.summary.histogram( name='value_preds', data=value_preds, step=self.train_step_counter) tf.compat.v2.summary.histogram( name='value_estimation_error', data=value_estimation_error, step=self.train_step_counter) if self._check_numerics: value_estimation_loss = tf.debugging.check_numerics( value_estimation_loss, 'value_estimation_loss') value_estimation_loss_per_sample = tf.debugging.check_numerics( value_estimation_loss_per_sample, 'value_estimation_loss_per_sample') return value_estimation_loss, value_estimation_loss_per_sample def policy_gradient_loss( self, time_steps, actions, sample_action_log_probs, advantages, current_policy_distribution, weights, debug_summaries = False): """Create tensor for policy gradient loss. All tensors should have a single batch dimension. Args: time_steps: TimeSteps with observations for each timestep. actions: Tensor of actions for timesteps, aligned on index. sample_action_log_probs: Tensor of sample probability of each action. advantages: Tensor of advantage estimate for each timestep, aligned on index. Works better when advantage estimates are normalized. current_policy_distribution: The policy distribution, evaluated on all time_steps. weights: Optional scalar or element-wise (per-batch-entry) importance weights. 
Includes a mask for invalid timesteps. debug_summaries: True if debug summaries should be created. Returns: policy_gradient_loss: A tensor that will contain policy gradient loss for the on-policy experience. """ nest_utils.assert_same_structure(time_steps, self.time_step_spec) action_log_prob = common.log_probability(current_policy_distribution, actions, self._action_spec) action_log_prob = tf.cast(action_log_prob, tf.float32) if self._log_prob_clipping > 0.0: action_log_prob = tf.clip_by_value(action_log_prob, -self._log_prob_clipping, self._log_prob_clipping) if self._check_numerics: action_log_prob = tf.debugging.check_numerics(action_log_prob, 'action_log_prob') # Prepare both clipped and unclipped importance ratios. importance_ratio = tf.exp(action_log_prob - sample_action_log_probs) importance_ratio_clipped = tf.clip_by_value( importance_ratio, 1 - self._importance_ratio_clipping, 1 + self._importance_ratio_clipping) if self._check_numerics: importance_ratio = tf.debugging.check_numerics(importance_ratio, 'importance_ratio') if self._importance_ratio_clipping > 0.0: importance_ratio_clipped = tf.debugging.check_numerics( importance_ratio_clipped, 'importance_ratio_clipped') # Pessimistically choose the minimum objective value for clipped and # unclipped importance ratios. per_timestep_objective = importance_ratio * advantages per_timestep_objective_clipped = importance_ratio_clipped * advantages per_timestep_objective_min = tf.minimum(per_timestep_objective, per_timestep_objective_clipped) if self._importance_ratio_clipping > 0.0: policy_gradient_loss = -per_timestep_objective_min else: policy_gradient_loss = -per_timestep_objective policy_gradient_loss_per_sample = tf.reduce_mean(policy_gradient_loss, axis=0) if self._aggregate_losses_across_replicas: policy_gradient_loss = common.aggregate_losses( per_example_loss=policy_gradient_loss, sample_weight=weights).total_loss else: policy_gradient_loss = tf.math.reduce_mean(policy_gradient_loss * weights) if debug_summaries: if self._importance_ratio_clipping > 0.0: clip_fraction = tf.reduce_mean( input_tensor=tf.cast( tf.greater( tf.abs(importance_ratio - 1.0), self._importance_ratio_clipping), tf.float32)) tf.compat.v2.summary.scalar( name='clip_fraction', data=clip_fraction, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='importance_ratio_mean', data=tf.reduce_mean(input_tensor=importance_ratio), step=self.train_step_counter) entropy = common.entropy(current_policy_distribution, self.action_spec) tf.compat.v2.summary.scalar( name='policy_entropy_mean', data=tf.reduce_mean(input_tensor=entropy), step=self.train_step_counter) if not tf.config.list_logical_devices('TPU'): tf.compat.v2.summary.histogram( name='action_log_prob', data=action_log_prob, step=self.train_step_counter) tf.compat.v2.summary.histogram( name='action_log_prob_sample', data=sample_action_log_probs, step=self.train_step_counter) tf.compat.v2.summary.histogram( name='importance_ratio', data=importance_ratio, step=self.train_step_counter) tf.compat.v2.summary.histogram( name='importance_ratio_clipped', data=importance_ratio_clipped, step=self.train_step_counter) tf.compat.v2.summary.histogram( name='per_timestep_objective', data=per_timestep_objective, step=self.train_step_counter) tf.compat.v2.summary.histogram( name='per_timestep_objective_clipped', data=per_timestep_objective_clipped, step=self.train_step_counter) tf.compat.v2.summary.histogram( name='per_timestep_objective_min', data=per_timestep_objective_min, step=self.train_step_counter) 
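# Worked example (hypothetical numbers): with importance_ratio_clipping = 0.2, # a ratio of 1.5 and advantage +1.0 give min(1.5 * 1.0, 1.2 * 1.0) = 1.2, # capping the gain from large ratios; with advantage -1.0 the unclipped term # -1.5 is kept, so the minimum is pessimistic in both directions.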
tf.compat.v2.summary.histogram( name='policy_entropy', data=entropy, step=self.train_step_counter) for i, (single_action, single_distribution) in enumerate( zip( tf.nest.flatten(self.action_spec), tf.nest.flatten(current_policy_distribution))): # Categorical distribution (used for discrete actions) doesn't have a # mean. distribution_index = '_{}'.format(i) if i > 0 else '' if not tensor_spec.is_discrete(single_action): tf.compat.v2.summary.histogram( name='actions_distribution_mean' + distribution_index, data=single_distribution.mean(), step=self.train_step_counter) tf.compat.v2.summary.histogram( name='actions_distribution_stddev' + distribution_index, data=single_distribution.stddev(), step=self.train_step_counter) tf.compat.v2.summary.histogram( name='policy_gradient_loss', data=policy_gradient_loss, step=self.train_step_counter) if self._check_numerics: policy_gradient_loss = tf.debugging.check_numerics( policy_gradient_loss, 'policy_gradient_loss') policy_gradient_loss_per_sample = tf.debugging.check_numerics( policy_gradient_loss_per_sample, 'policy_gradient_loss_per_sample') return policy_gradient_loss, policy_gradient_loss_per_sample def entropy_regularization_loss( self, time_steps, current_policy_distribution, weights, debug_summaries = False): """Create regularization loss tensor based on agent parameters.""" if self._entropy_regularization > 0: nest_utils.assert_same_structure(time_steps, self.time_step_spec) with tf.name_scope('entropy_regularization'): entropy = tf.cast( common.entropy(current_policy_distribution, self.action_spec), tf.float32) if self._aggregate_losses_across_replicas: entropy_reg_loss = common.aggregate_losses( per_example_loss=-entropy, sample_weight=weights).total_loss * self._entropy_regularization else: entropy_reg_loss = ( tf.math.reduce_mean(-entropy * weights) * self._entropy_regularization) if self._check_numerics: entropy_reg_loss = tf.debugging.check_numerics( entropy_reg_loss, 'entropy_reg_loss') if debug_summaries and not tf.config.list_logical_devices('TPU'): tf.compat.v2.summary.histogram( name='entropy_reg_loss', data=entropy_reg_loss, step=self.train_step_counter) else: raise ValueError('This is not allowed, this is handled at loss level.') entropy_reg_loss_per_sample = -entropy if self._check_numerics: entropy_reg_loss_per_sample = tf.debugging.check_numerics( entropy_reg_loss_per_sample, 'entropy_reg_loss_per_sample') return entropy_reg_loss, entropy_reg_loss_per_sample class ReverbFixedLengthSequenceObserver(reverb_utils.ReverbAddTrajectoryObserver ): """Reverb fixed length sequence observer. This is a specialized observer similar to ReverbAddTrajectoryObserver but each sequence contains a fixed number of steps and can span multiple episodes. This implementation is consistent with (Schulman, 17). **Note**: Counting of steps in drivers does not include boundary steps. To guarantee only 1 item is pushed to the replay when collecting n steps with a `sequence_length` of n make sure to set the `stride_length`. """ def __call__(self, trajectory): """Writes the trajectory into the underlying replay buffer. Allows trajectory to be a flattened trajectory. No batch dimension allowed. Args: trajectory: The trajectory to be written which could be (possibly nested) trajectory object or a flattened version of a trajectory. It assumes there is *no* batch dimension. 
""" self._writer.append(trajectory) self._cached_steps += 1 self._write_cached_steps() @gin.configurable def train_eval( root_dir, env_name='HalfCheetah-v2', # Training params num_iterations=1600, actor_fc_layers=(64, 64), value_fc_layers=(64, 64), learning_rate=3e-4, collect_sequence_length=2048, minibatch_size=64, num_epochs=10, # Agent params importance_ratio_clipping=0.2, lambda_value=0.95, discount_factor=0.99, entropy_regularization=0., value_pred_loss_coef=0.5, use_gae=True, use_td_lambda_return=True, gradient_clipping=0.5, value_clipping=None, # Replay params reverb_port=None, replay_capacity=10000, # Others policy_save_interval=5000, summary_interval=1000, eval_interval=10000, eval_episodes=100, debug_summaries=False, summarize_grads_and_vars=False, train_mode_actor='dense', train_mode_value='dense', sparse_output_layer=True, weight_decay=0.0, width=1.0): """Trains and evaluates DQN.""" logging.info('Actor fc layer params: %s', actor_fc_layers) logging.info('Value fc layer params: %s', value_fc_layers) logging.info('Policy save interval: %s', policy_save_interval) logging.info('Eval interval: %s', eval_interval) logging.info('Environment name: %s', env_name) logging.info('Learning rate: %s', learning_rate) logging.info('Num iterations: %s', num_iterations) logging.info('Sparse output layer: %s', sparse_output_layer) logging.info('Train mode actor: %s', train_mode_actor) logging.info('Train mode value: %s', train_mode_value) logging.info('Width: %s', width) logging.info('Weight decay: %s', weight_decay) if FLAGS.is_mujoco: collect_env = suite_mujoco.load(env_name) eval_env = suite_mujoco.load(env_name) logging.info('Loaded Mujoco environment %s', env_name) elif FLAGS.is_classic: collect_env = suite_gym.load(env_name) eval_env = suite_gym.load(env_name) logging.info('Loaded Classic control environment %s', env_name) else: raise ValueError('Environment init for Atari not supported yet.') num_environments = 1 observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) observation_tensor_spec = tf.TensorSpec( dtype=tf.float32, shape=observation_tensor_spec.shape) train_step = train_utils.create_train_step() if FLAGS.is_classic: actor_net_constructor = sparse_ppo_discrete_actor_network.PPODiscreteActorNetwork else: actor_net_constructor = sparse_ppo_actor_network.PPOActorNetwork actor_net_builder = actor_net_constructor( is_sparse=train_mode_actor == 'sparse', sparse_output_layer=sparse_output_layer, weight_decay=0, width=width) actor_net = actor_net_builder.create_sequential_actor_net( actor_fc_layers, action_tensor_spec, input_dim=time_step_tensor_spec.observation.shape[0]) value_net = sparse_value_network.ValueNetwork( observation_tensor_spec, fc_layer_params=value_fc_layers, kernel_initializer=tf.keras.initializers.Orthogonal(), is_sparse=train_mode_value == 'sparse', sparse_output_layer=sparse_output_layer, weight_decay=0, width=width) logging.info('Train eval: weight decay %.5f.', weight_decay) current_iteration = tf.Variable(0, dtype=tf.int64) def learning_rate_fn(): # Linearly decay the learning rate. 
return learning_rate * (1 - current_iteration / num_iterations) agent = SparsePPOAgent( time_step_tensor_spec, action_tensor_spec, optimizer=tf.compat.v1.train.AdamOptimizer( learning_rate=learning_rate_fn, epsilon=1e-5), actor_net=actor_net, value_net=value_net, importance_ratio_clipping=importance_ratio_clipping, lambda_value=lambda_value, discount_factor=discount_factor, entropy_regularization=entropy_regularization, value_pred_loss_coef=value_pred_loss_coef, policy_l2_reg=weight_decay, value_function_l2_reg=weight_decay, shared_vars_l2_reg=weight_decay, # This is a legacy argument for the number of times we repeat the data # inside the train function, incompatible with minibatch learning. # We set the epoch number from the replay buffer and tf.data instead. num_epochs=1, use_gae=use_gae, use_td_lambda_return=use_td_lambda_return, gradient_clipping=gradient_clipping, value_clipping=value_clipping, compute_value_and_advantage_in_train=False, # Skips updating normalizers in the agent, as it's handled in the learner. update_normalizers_in_train=False, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step) agent.initialize() reverb_server = reverb.Server( [ reverb.Table( # Replay buffer storing experience for training. name='training_table', sampler=reverb.selectors.Fifo(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), max_size=replay_capacity, max_times_sampled=1, ), reverb.Table( # Replay buffer storing experience for normalization. name='normalization_table', sampler=reverb.selectors.Fifo(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1), max_size=replay_capacity, max_times_sampled=1, ) ], port=reverb_port) # Create the replay buffers. reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=collect_sequence_length, table_name='training_table', server_address='localhost:{}'.format(reverb_server.port), # The only collected sequence is used to populate the batches. max_cycle_length=1, num_workers_per_iterator=1, max_samples_per_stream=1, rate_limiter_timeout_ms=1000) reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=collect_sequence_length, table_name='normalization_table', server_address='localhost:{}'.format(reverb_server.port), # The only collected sequence is used to populate the batches.
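# (These iterator settings mirror the training buffer above; this second # buffer reads normalization_table so the learner can update the observation # and reward normalizers from the same collected sequences.)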
max_cycle_length=1, num_workers_per_iterator=1, max_samples_per_stream=1, rate_limiter_timeout_ms=1000) rb_observer = ReverbFixedLengthSequenceObserver( reverb_replay_train.py_client, ['training_table', 'normalization_table'], sequence_length=collect_sequence_length, stride_length=collect_sequence_length) saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) collect_env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={ triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric }), triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval), ] def training_dataset_fn(): return reverb_replay_train.as_dataset( sample_batch_size=num_environments, sequence_preprocess_fn=agent.preprocess_sequence) def normalization_dataset_fn(): return reverb_replay_normalization.as_dataset( sample_batch_size=num_environments, sequence_preprocess_fn=agent.preprocess_sequence) agent_learner = ppo_learner.PPOLearner( root_dir, train_step, agent, experience_dataset_fn=training_dataset_fn, normalization_dataset_fn=normalization_dataset_fn, num_samples=1, summary_interval=10, num_epochs=num_epochs, minibatch_size=minibatch_size, shuffle_buffer_size=collect_sequence_length, triggers=learning_triggers) tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor( collect_env, collect_policy, train_step, steps_per_run=collect_sequence_length, observers=[rb_observer, collect_env_step_metric], metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric], reference_metrics=[collect_env_step_metric], summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), summary_interval=summary_interval) eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( agent.policy, use_tf_function=True) average_returns = [] if eval_interval: logging.info('Initial evaluation.') eval_actor = actor.Actor( eval_env, eval_greedy_policy, train_step, metrics=actor.eval_metrics(eval_episodes), reference_metrics=[collect_env_step_metric], summary_dir=os.path.join(root_dir, 'eval'), episodes_per_run=eval_episodes) eval_actor.run_and_log() for metric in eval_actor.metrics: if isinstance(metric, py_metrics.AverageReturnMetric): average_returns.append(metric._buffer.mean()) logging.info('Training on %s', env_name) last_eval_step = 0 for i in range(num_iterations): logging.info('collect_actor.run') collect_actor.run() # Reset the reverb observer to make sure the data collected is flushed and # written to the RB. # At this point, there are a small number of steps left in the cache because # the actor does not count a boundary step as a step, whereas it still gets # added to Reverb for training. We throw away those extra steps without # padding to align with the paper implementation, which never collects them # in the first place. logging.info('rb_observer.reset') rb_observer.reset(write_cached_steps=False) logging.info('reverb_replay_normalization.size: %d', reverb_replay_normalization.get_table_info().current_size) logging.info('reverb_replay_train.size: %d', reverb_replay_train.get_table_info().current_size) logging.info('agent_learner.run') agent_learner.run() logging.info('reverb_replay_train.clear') reverb_replay_train.clear() logging.info('reverb_replay_normalization.clear') reverb_replay_normalization.clear() current_iteration.assign_add(1) # Eval only if `eval_interval` has been set.
Then, eval if the current train # step is equal to or greater than the `last_eval_step` + `eval_interval` or # if this is the last iteration. This logic exists because agent_learner.run() # does not return after every train step. if (eval_interval and (agent_learner.train_step_numpy >= eval_interval + last_eval_step or i == num_iterations - 1)): logging.info('Evaluating.') eval_actor.run_and_log() last_eval_step = agent_learner.train_step_numpy for metric in eval_actor.metrics: if isinstance(metric, py_metrics.AverageReturnMetric): average_returns.append(metric._buffer.mean()) # Log last section of evaluation scores for the final metric. idx = int(FLAGS.average_last_fraction * len(average_returns)) avg_return = np.mean(average_returns[-idx:]) logging.info('Step %d, Average Return: %f', collect_env_step_metric.result(), avg_return) rb_observer.close() reverb_server.stop() def main(_): tf.config.experimental_run_functions_eagerly(False) logging.set_verbosity(logging.INFO) tf.enable_v2_behavior() gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings) logging.info('Gin bindings: %s', FLAGS.gin_bindings) train_eval( FLAGS.root_dir, reverb_port=FLAGS.reverb_port) if __name__ == '__main__': flags.mark_flag_as_required('root_dir') multiprocessing.handle_main(functools.partial(app.run, main)) ================================================ FILE: rigl/rl/tfagents/sac_train_eval.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. r"""Train and Eval SAC.
""" import functools import os from absl import app from absl import flags from absl import logging import gin import numpy as np import reverb from rigl.rigl_tf2 import mask_updaters from rigl.rl import sparse_utils from rigl.rl.tfagents import sparse_tanh_normal_projection_network from rigl.rl.tfagents import tf_sparse_utils import tensorflow as tf from tf_agents.agents import tf_agent from tf_agents.agents.sac import sac_agent from tf_agents.environments import suite_mujoco from tf_agents.keras_layers import inner_reshape from tf_agents.metrics import py_metrics from tf_agents.networks import nest_map from tf_agents.networks import sequential from tf_agents.policies import greedy_policy from tf_agents.policies import py_tf_eager_policy from tf_agents.policies import random_py_policy from tf_agents.replay_buffers import reverb_replay_buffer from tf_agents.replay_buffers import reverb_utils from tf_agents.train import actor from tf_agents.train import learner from tf_agents.train import triggers from tf_agents.train.utils import spec_utils from tf_agents.train.utils import strategy_utils from tf_agents.train.utils import train_utils from tf_agents.utils import common from tf_agents.utils import object_identity FLAGS = flags.FLAGS flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), 'Root directory for writing logs/summaries/checkpoints.') flags.DEFINE_integer( 'reverb_port', None, 'Port for reverb server, if None, use a randomly chosen unused port.') flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.') flags.DEFINE_multi_string('gin_bindings', [], 'Gin binding parameters.') # Env params flags.DEFINE_bool('is_atari', False, 'Whether the env is an atari game.') flags.DEFINE_bool('is_mujoco', False, 'Whether the env is a mujoco game.') flags.DEFINE_bool('is_classic', False, 'Whether the env is a classic control game.') flags.DEFINE_float( 'average_last_fraction', 0.1, 'Tells what fraction latest evaluation scores are averaged. This is used' ' to reduce variance.') dense = functools.partial( tf.keras.layers.Dense, activation=tf.keras.activations.relu, kernel_initializer='glorot_uniform') def create_fc_layers(layer_units, width=1.0, weight_decay=0): layers = [ dense(tf_sparse_utils.scale_width(num_units, width=width), kernel_regularizer=tf.keras.regularizers.L2(weight_decay)) for num_units in layer_units ] return layers def create_identity_layer(): return tf.keras.layers.Lambda(lambda x: x) def create_sequential_critic_network(obs_fc_layer_units, action_fc_layer_units, joint_fc_layer_units, input_dim, is_sparse = False, width = 1.0, weight_decay = 0.0, sparse_output_layer = True): """Create a sequential critic network.""" # Split the inputs into observations and actions. def split_inputs(inputs): return {'observation': inputs[0], 'action': inputs[1]} # Create an observation network layers. obs_network_layers = ( create_fc_layers(obs_fc_layer_units, width=width, weight_decay=weight_decay) if obs_fc_layer_units else None) # Create an action network layers. action_network_layers = ( create_fc_layers(action_fc_layer_units, width=width, weight_decay=weight_decay) if action_fc_layer_units else None) # Create a joint network layers. joint_network_layers = ( create_fc_layers(joint_fc_layer_units, width=width, weight_decay=weight_decay) if joint_fc_layer_units else None) # Final layer. 
value_layer = tf.keras.layers.Dense( 1, kernel_initializer='glorot_uniform', kernel_regularizer=tf.keras.regularizers.L2(weight_decay)) layer_list = [obs_network_layers, action_network_layers, joint_network_layers] if is_sparse: # We need to process all layers together to distribute sparsities for # pruning. all_layers = [] for layers in layer_list: if layers is not None: all_layers += layers if sparse_output_layer: all_layers.append(value_layer) new_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim) value_layer = new_layers[-1] new_layers = new_layers[:-1] else: new_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim) # Split the layers back into their own groups. c_index = 0 new_layer_list = [] for layers in layer_list: if layers is None: new_layer_list.append(None) else: new_layer_list.append(new_layers[c_index:c_index + len(layers)]) c_index += len(layers) layer_list = new_layer_list # Convert layer_list to sequential or identity lambdas: module_list = [create_identity_layer() if layers is None else sequential.Sequential(layers) for layers in layer_list] obs_network, action_network, joint_network = module_list return sequential.Sequential([ tf.keras.layers.Lambda(split_inputs), nest_map.NestMap({ 'observation': obs_network, 'action': action_network }), nest_map.NestFlatten(), tf.keras.layers.Concatenate(), joint_network, value_layer, inner_reshape.InnerReshape(current_shape=[1], new_shape=[]) ], name='sequential_critic') class _TanhNormalProjectionNetworkWrapper( sparse_tanh_normal_projection_network.SparseTanhNormalProjectionNetwork): """Wrapper to pass predefined `outer_rank` to underlying projection net.""" def __init__(self, sample_spec, predefined_outer_rank=1, weight_decay=0.0): super(_TanhNormalProjectionNetworkWrapper, self).__init__( sample_spec=sample_spec, weight_decay=weight_decay) self.predefined_outer_rank = predefined_outer_rank def call(self, inputs, network_state=(), **kwargs): kwargs['outer_rank'] = self.predefined_outer_rank if 'step_type' in kwargs: del kwargs['step_type'] return super(_TanhNormalProjectionNetworkWrapper, self).call(inputs, **kwargs) def create_sequential_actor_network(actor_fc_layers, action_tensor_spec, input_dim, is_sparse = False, width = 1.0, weight_decay = 0.0, sparse_output_layer = True): """Create a sequential actor network.""" def tile_as_nest(non_nested_output): return tf.nest.map_structure(lambda _: non_nested_output, action_tensor_spec) dense_layers = [ dense(tf_sparse_utils.scale_width(num_units, width=width), kernel_regularizer=tf.keras.regularizers.L2(weight_decay)) for num_units in actor_fc_layers ] tanh_normal_projection_network_fn = functools.partial( _TanhNormalProjectionNetworkWrapper, weight_decay=weight_decay) last_layer = nest_map.NestMap( tf.nest.map_structure(tanh_normal_projection_network_fn, action_tensor_spec)) if is_sparse: if sparse_output_layer: dense_layers.append(last_layer.layers[0]._projection_layer) new_layers = tf_sparse_utils.wrap_all_layers(dense_layers, input_dim) dense_layers = new_layers[:-1] last_layer.layers[0]._projection_layer = new_layers[-1] else: dense_layers = tf_sparse_utils.wrap_all_layers(dense_layers, input_dim) return sequential.Sequential( dense_layers + [tf.keras.layers.Lambda(tile_as_nest)] + [last_layer]) @gin.configurable class SparseSacAgent(sac_agent.SacAgent): """Wrapped SacAgent that supports sparse training.""" def __init__(self, time_step_spec, action_spec, *args, actor_sparsity=None, critic_sparsity=None, **kwargs): super().__init__(time_step_spec, action_spec, *args,
**kwargs) # Pruning layer requires the pruning_step to be >1 during forward pass. tf_sparse_utils.update_prune_step( self._critic_network_1, self.train_step_counter + 1) tf_sparse_utils.update_prune_step( self._critic_network_2, self.train_step_counter + 1) tf_sparse_utils.update_prune_step( self._actor_network, self.train_step_counter + 1) if critic_sparsity is not None: _ = sparse_utils.init_masks(self._critic_network_1, sparsity=critic_sparsity) _ = sparse_utils.init_masks(self._critic_network_2, sparsity=critic_sparsity) else: # Uses init_mask.sparsity value. Either the default or set via gin. _ = sparse_utils.init_masks(self._critic_network_1) _ = sparse_utils.init_masks(self._critic_network_2) if actor_sparsity is not None: _ = sparse_utils.init_masks(self._actor_network, sparsity=actor_sparsity) else: _ = sparse_utils.init_masks(self._actor_network) net_observation_spec = time_step_spec.observation critic_spec = (net_observation_spec, action_spec) self._target_critic_network_1 = ( common.maybe_copy_target_network_with_checks( self._critic_network_1, None, input_spec=critic_spec, name='TargetCriticNetwork1')) self._target_critic_network_2 = ( common.maybe_copy_target_network_with_checks( self._critic_network_2, None, input_spec=critic_spec, name='TargetCriticNetwork2')) def critic_loss_fn(experience, weights): # The following is just to fit the existing API. transition = self._as_transition(experience) time_steps, policy_steps, next_time_steps = transition actions = policy_steps.action return self._critic_loss_weight * self.critic_loss( time_steps, actions, next_time_steps, td_errors_loss_fn=self._td_errors_loss_fn, gamma=self._gamma, reward_scale_factor=self._reward_scale_factor, weights=weights, training=True) def actor_loss_fn(experience, weights): # The following is just to fit the existing API. transition = self._as_transition(experience) time_steps, _, _ = transition return self._actor_loss_weight*self.actor_loss( time_steps, weights=weights, training=True) # Create the mask updaters if they don't exist. self._mask_updater_critic_1 = mask_updaters.get_mask_updater( self._critic_network_1, self._critic_optimizer, critic_loss_fn) self._mask_updater_critic_2 = mask_updaters.get_mask_updater( self._critic_network_2, self._critic_optimizer, critic_loss_fn) self._mask_updater_actor = mask_updaters.get_mask_updater( self._actor_network, self._actor_optimizer, actor_loss_fn) def _train(self, experience, weights): """Returns a train op to update the agent's networks. This method trains with the provided batched experience. Args: experience: A time-stacked trajectory object. weights: Optional scalar or elementwise (per-batch-entry) importance weights. Returns: A train_op. Raises: ValueError: If optimizers are None and no default value was provided to the constructor.
""" tf.summary.experimental.set_step(self.train_step_counter) transition = self._as_transition(experience) time_steps, policy_steps, next_time_steps = transition actions = policy_steps.action trainable_critic_variables = list(object_identity.ObjectIdentitySet( self._critic_network_1.trainable_variables + self._critic_network_2.trainable_variables)) with tf.GradientTape(watch_accessed_variables=False) as tape: assert trainable_critic_variables, ('No trainable critic variables to ' 'optimize.') tape.watch(trainable_critic_variables) critic_loss = self._critic_loss_weight*self.critic_loss( time_steps, actions, next_time_steps, td_errors_loss_fn=self._td_errors_loss_fn, gamma=self._gamma, reward_scale_factor=self._reward_scale_factor, weights=weights, training=True) tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.') critic_grads = tape.gradient(critic_loss, trainable_critic_variables) self._apply_gradients(critic_grads, trainable_critic_variables, self._critic_optimizer) trainable_actor_variables = self._actor_network.trainable_variables with tf.GradientTape(watch_accessed_variables=False) as tape: assert trainable_actor_variables, ('No trainable actor variables to ' 'optimize.') tape.watch(trainable_actor_variables) actor_loss = self._actor_loss_weight*self.actor_loss( time_steps, weights=weights, training=True) tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.') actor_grads = tape.gradient(actor_loss, trainable_actor_variables) self._apply_gradients(actor_grads, trainable_actor_variables, self._actor_optimizer) # BEGIN sparse training mask update # We use the lastest set of gradients to update the masks for sparse # training. Note, we do this before gradient clipping. # Define helper methods. def _mask_update_step(mask_updater, updater_name): mask_updater.set_validation_data(experience, weights) mask_updater.update(self.train_step_counter) with tf.name_scope('Drop_fraction/'): tf.summary.scalar( name=f'{updater_name}', data=mask_updater.last_drop_fraction) mask_update_step_critic_1 = functools.partial(_mask_update_step, self._mask_updater_critic_1, 'critic_1') mask_update_step_critic_2 = functools.partial(_mask_update_step, self._mask_updater_critic_2, 'critic_2') mask_update_step_actor = functools.partial(_mask_update_step, self._mask_updater_actor, 'actor') # Log sparsities every 1000 train steps. 
def _log_sparsities(): tf_sparse_utils.log_sparsities(self._critic_network_1, 'critic_1') tf_sparse_utils.log_sparsities(self._critic_network_2, 'critic_2') tf_sparse_utils.log_sparsities(self._actor_network, 'actor') tf_sparse_utils.log_total_params( [self._critic_network_1, self._critic_network_2, self._actor_network]) tf.cond(self.train_step_counter % 1000 == 0, _log_sparsities, lambda: None) # Update critics if self._mask_updater_critic_1 is not None: is_update_critic_1 = self._mask_updater_critic_1.is_update_iter( self.train_step_counter) tf.cond(is_update_critic_1, mask_update_step_critic_1, lambda: None) if self._mask_updater_critic_2 is not None: is_update_critic_2 = self._mask_updater_critic_2.is_update_iter( self.train_step_counter) tf.cond(is_update_critic_2, mask_update_step_critic_2, lambda: None) # Update actor if self._mask_updater_actor is not None: is_update_actor = self._mask_updater_actor.is_update_iter( self.train_step_counter) tf.cond(is_update_actor, mask_update_step_actor, lambda: None) # END sparse training mask update alpha_variable = [self._log_alpha] with tf.GradientTape(watch_accessed_variables=False) as tape: assert alpha_variable, 'No alpha variable to optimize.' tape.watch(alpha_variable) alpha_loss = self._alpha_loss_weight * self.alpha_loss( time_steps, weights=weights, training=True) tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.') alpha_grads = tape.gradient(alpha_loss, alpha_variable) self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer) with tf.name_scope('Losses'): tf.compat.v2.summary.scalar( name='critic_loss', data=critic_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='actor_loss', data=actor_loss, step=self.train_step_counter) tf.compat.v2.summary.scalar( name='alpha_loss', data=alpha_loss, step=self.train_step_counter) self.train_step_counter.assign_add(1) self._update_target() total_loss = critic_loss + actor_loss + alpha_loss extra = sac_agent.SacLossInfo( critic_loss=critic_loss, actor_loss=actor_loss, alpha_loss=alpha_loss) return tf_agent.LossInfo(loss=total_loss, extra=extra) @gin.configurable def train_eval( root_dir, strategy, env_name='HalfCheetah-v2', # Training params initial_collect_steps=10000, num_iterations=1000000, actor_fc_layers=(256, 256), critic_obs_fc_layers=None, critic_action_fc_layers=None, critic_joint_fc_layers=(256, 256), # Agent params batch_size=256, actor_learning_rate=3e-4, critic_learning_rate=3e-4, alpha_learning_rate=3e-4, gamma=0.99, target_update_tau=0.005, target_update_period=1, reward_scale_factor=0.1, # Replay params reverb_port=None, replay_capacity=1000000, # Others policy_save_interval=10000, replay_buffer_save_interval=100000, eval_interval=10000, eval_episodes=30, debug_summaries=False, summarize_grads_and_vars=False, sparse_output_layer = False, width = 1.0, train_mode_actor = 'dense', train_mode_value = 'dense', weight_decay = 0.0, actor_critic_sparsities_str = '', actor_critic_widths_str = ''): """Trains and evaluates SAC.""" assert FLAGS.is_mujoco if actor_critic_widths_str: actor_critic_widths = [float(s) for s in actor_critic_widths_str.split('_')] width_actor = actor_critic_widths[0] width_value = actor_critic_widths[1] else: width_actor = width width_value = width if actor_critic_sparsities_str: actor_critic_sparsities = [ float(s) for s in actor_critic_sparsities_str.split('_') ] else: # init_mask.sparsity value will be used. Either the default or set via gin. 
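# Format illustration (hypothetical value): actor_critic_sparsities_str = # '0.9_0.95' would yield an actor sparsity of 0.9 and a critic sparsity of # 0.95; actor_critic_widths_str uses the same underscore-separated # actor-then-critic order.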
actor_critic_sparsities = [None, None] logging.info('Training SAC on: %s', env_name) logging.info('SAC params: train mode actor: %s', train_mode_actor) logging.info('SAC params: train mode value: %s', train_mode_value) logging.info('SAC params: sparse_output_layer: %s', sparse_output_layer) logging.info('SAC params: width: %s', width) logging.info('SAC params: actor_critic_widths_str: %s', actor_critic_widths_str) logging.info('SAC params: width_actor: %s', width_actor) logging.info('SAC params: width_value: %s', width_value) logging.info('SAC params: weight_decay: %s', weight_decay) logging.info('SAC params: actor_critic_sparsities_str %s type %s', actor_critic_sparsities_str, type(actor_critic_sparsities_str)) logging.info('SAC params: actor_sparsity: %s', actor_critic_sparsities[0]) logging.info('SAC params: critic_sparsity: %s', actor_critic_sparsities[1]) collect_env = suite_mujoco.load(env_name) eval_env = suite_mujoco.load(env_name) _, action_tensor_spec, time_step_tensor_spec = ( spec_utils.get_tensor_specs(collect_env)) actor_net = create_sequential_actor_network( actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec, input_dim=time_step_tensor_spec.observation.shape[0], is_sparse=(train_mode_actor == 'sparse'), width=width_actor, weight_decay=weight_decay, sparse_output_layer=sparse_output_layer) critic_input_dim = ( action_tensor_spec.shape[0] + time_step_tensor_spec.observation.shape[0]) critic_net = create_sequential_critic_network( obs_fc_layer_units=critic_obs_fc_layers, action_fc_layer_units=critic_action_fc_layers, joint_fc_layer_units=critic_joint_fc_layers, input_dim=critic_input_dim, is_sparse=(train_mode_value == 'sparse'), width=width_value, weight_decay=weight_decay, sparse_output_layer=sparse_output_layer) with strategy.scope(): train_step = train_utils.create_train_step() agent = SparseSacAgent( time_step_spec=time_step_tensor_spec, action_spec=action_tensor_spec, actor_sparsity=actor_critic_sparsities[0], critic_sparsity=actor_critic_sparsities[1], actor_network=actor_net, critic_network=critic_net, actor_optimizer=tf.keras.optimizers.Adam( learning_rate=actor_learning_rate), critic_optimizer=tf.keras.optimizers.Adam( learning_rate=critic_learning_rate), alpha_optimizer=tf.keras.optimizers.Adam( learning_rate=alpha_learning_rate), target_update_tau=target_update_tau, target_update_period=target_update_period, td_errors_loss_fn=tf.math.squared_difference, gamma=gamma, reward_scale_factor=reward_scale_factor, gradient_clipping=None, debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars, train_step_counter=train_step) agent.initialize() table_name = 'uniform_table' table = reverb.Table( table_name, max_size=replay_capacity, sampler=reverb.selectors.Uniform(), remover=reverb.selectors.Fifo(), rate_limiter=reverb.rate_limiters.MinSize(1)) reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR, learner.REPLAY_BUFFER_CHECKPOINT_DIR) reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer( path=reverb_checkpoint_dir) reverb_server = reverb.Server([table], port=reverb_port, checkpointer=reverb_checkpointer) reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( agent.collect_data_spec, sequence_length=2, table_name=table_name, local_server=reverb_server) rb_observer = reverb_utils.ReverbAddTrajectoryObserver( reverb_replay.py_client, table_name, sequence_length=2, stride_length=1) def experience_dataset_fn(): return reverb_replay.as_dataset( sample_batch_size=batch_size, num_steps=2).prefetch(50) 
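# Illustration: with sequence_length=2 and stride_length=1 in the observer # above, consecutive steps form overlapping two-step windows, e.g. steps # [s0, s1, s2] yield the transitions (s0, s1) and (s1, s2); sampling with # num_steps=2 matches that window size.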
saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) env_step_metric = py_metrics.EnvironmentSteps() learning_triggers = [ triggers.PolicySavedModelTrigger( saved_model_dir, agent, train_step, interval=policy_save_interval, metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}), triggers.ReverbCheckpointTrigger( train_step, interval=replay_buffer_save_interval, reverb_client=reverb_replay.py_client), triggers.StepPerSecondLogTrigger(train_step, interval=1000), ] agent_learner = learner.Learner( root_dir, train_step, agent, experience_dataset_fn, triggers=learning_triggers, strategy=strategy) random_policy = random_py_policy.RandomPyPolicy( collect_env.time_step_spec(), collect_env.action_spec()) initial_collect_actor = actor.Actor( collect_env, random_policy, train_step, steps_per_run=initial_collect_steps, observers=[rb_observer]) logging.info('Doing initial collect.') initial_collect_actor.run() tf_collect_policy = agent.collect_policy collect_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_collect_policy, use_tf_function=True) collect_actor = actor.Actor( collect_env, collect_policy, train_step, steps_per_run=1, metrics=actor.collect_metrics(10), summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), observers=[rb_observer, env_step_metric]) tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy) eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( tf_greedy_policy, use_tf_function=True) eval_actor = actor.Actor( eval_env, eval_greedy_policy, train_step, episodes_per_run=eval_episodes, metrics=actor.eval_metrics(eval_episodes), summary_dir=os.path.join(root_dir, 'eval'), ) average_returns = [] if eval_interval: logging.info('Evaluating.') eval_actor.run_and_log() for metric in eval_actor.metrics: if isinstance(metric, py_metrics.AverageReturnMetric): average_returns.append(metric._buffer.mean()) logging.info('Training.') for _ in range(num_iterations): collect_actor.run() agent_learner.run(iterations=1) if eval_interval and agent_learner.train_step_numpy % eval_interval == 0: logging.info('Evaluating.') eval_actor.run_and_log() for metric in eval_actor.metrics: if isinstance(metric, py_metrics.AverageReturnMetric): average_returns.append(metric._buffer.mean()) # Log last section of evaluation scores for the final metric. idx = int(FLAGS.average_last_fraction * len(average_returns)) avg_return = np.mean(average_returns[-idx:]) logging.info('Step %d, Average Return: %f', env_step_metric.result(), avg_return) rb_observer.close() reverb_server.stop() def main(_): tf.config.run_functions_eagerly(False) logging.set_verbosity(logging.INFO) tf.compat.v1.enable_v2_behavior() strategy = strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu) gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings) logging.info('Gin bindings: %s', FLAGS.gin_bindings) logging.info('# Gin-Config:\n %s', gin.config.operative_config_str()) train_eval( FLAGS.root_dir, strategy=strategy, reverb_port=FLAGS.reverb_port) if __name__ == '__main__': flags.mark_flag_as_required('root_dir') app.run(main) ================================================ FILE: rigl/rl/tfagents/sparse_encoding_network.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Keras Encoding Network. Implements a network that will generate the following layers: [optional]: preprocessing_layers # preprocessing_layers [optional]: (Add | Concat(axis=-1) | ...) # preprocessing_combiner [optional]: Conv2D # conv_layer_params Flatten [optional]: Dense # fc_layer_params """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl import logging import gin from rigl.rl.tfagents import tf_sparse_utils from six.moves import zip import tensorflow as tf from tf_agents.keras_layers import permanent_variable_rate_dropout from tf_agents.networks import network from tf_agents.networks import utils from tf_agents.utils import nest_utils CONV_TYPE_2D = '2d' CONV_TYPE_1D = '1d' def _copy_layer(layer): """Create a copy of a Keras layer with identical parameters. The new layer will not share weights with the old one. Args: layer: An instance of `tf.keras.layers.Layer`. Returns: A new keras layer. Raises: TypeError: If `layer` is not a keras layer. ValueError: If `layer` cannot be correctly cloned. """ if not isinstance(layer, tf.keras.layers.Layer): raise TypeError('layer is not a keras layer: %s' % str(layer)) # pylint:disable=unidiomatic-typecheck if type(layer) == tf.compat.v1.keras.layers.DenseFeatures: raise ValueError('DenseFeatures V1 is not supported. ' 'Use tf.compat.v2.keras.layers.DenseFeatures instead.') if layer.built: logging.warning( 'Beware: Copying a layer that has already been built: \'%s\'. ' 'This can lead to subtle bugs because the original layer\'s weights ' 'will not be used in the copy.', layer.name) # Get a fresh copy so we don't modify an incoming layer in place. Weights # will not be shared. return type(layer).from_config(layer.get_config()) @gin.configurable class EncodingNetwork(network.Network): """Feed Forward network with CNN and FNN layers.""" def __init__(self, input_tensor_spec, preprocessing_layers=None, preprocessing_combiner=None, conv_layer_params=None, fc_layer_params=None, dropout_layer_params=None, activation_fn=tf.keras.activations.relu, weight_decay_params=None, kernel_initializer=None, batch_squash=True, dtype=tf.float32, name='EncodingNetwork', conv_type=CONV_TYPE_2D, width=1.0): """Creates an instance of `EncodingNetwork`. Network supports calls with shape outer_rank + input_tensor_spec.shape. Note outer_rank must be at least 1. For example an input tensor spec with shape `(2, 3)` will require inputs with at least a batch size, the input shape is `(?, 2, 3)`. Input preprocessing is possible via `preprocessing_layers` and `preprocessing_combiner` Layers. If the `preprocessing_layers` nest is shallower than `input_tensor_spec`, then the layers will get the subnests. 
For example, if: ```python input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5) preprocessing_layers = (Layer1(), Layer2()) ``` then preprocessing will call: ```python preprocessed = [preprocessing_layers[0](observations[0]), preprocessing_layers[1](observations[1])] ``` However, if ```python preprocessing_layers = ([Layer1() for _ in range(2)], [Layer2() for _ in range(5)]) ``` then preprocessing will call: ```python preprocessed = [ layer(obs) for layer, obs in zip(flatten(preprocessing_layers), flatten(observations)) ] ``` **NOTE** `preprocessing_layers` and `preprocessing_combiner` are not allowed to have already been built. This ensures calls to `network.copy()` in the future always have an unbuilt, fresh set of parameters. Furthermore, a shallow copy of the layers is always created by the Network, so the layer objects passed to the network are never modified. For more details of the semantics of `copy`, see the docstring of `tf_agents.networks.Network.copy`. Args: input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the input observations. preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer` representing preprocessing for the different observations. All of these layers must not be already built. preprocessing_combiner: (Optional.) A keras layer that takes a flat list of tensors and combines them. Good options include `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`. This layer must not be already built. conv_layer_params: Optional list of convolution layer parameters, where each item is either a length-three tuple indicating `(filters, kernel_size, stride)` or a length-four tuple indicating `(filters, kernel_size, stride, dilation_rate)`. fc_layer_params: Optional list of fully_connected parameters, where each item is the number of units in the layer. dropout_layer_params: Optional list of dropout layer parameters, each item is the fraction of input units to drop or a dictionary of parameters according to the keras.Dropout documentation. The additional parameter `permanent`, if set to True, allows applying dropout at inference for approximate Bayesian inference. The dropout layers are interleaved with the fully connected layers; there is a dropout layer after each fully connected layer, except if the entry in the list is None. This list must have the same length as fc_layer_params, or be None. activation_fn: Activation function, e.g. tf.keras.activations.relu. weight_decay_params: Optional list of weight decay parameters for the fully connected layers. kernel_initializer: Initializer to use for the kernels of the conv and dense layers. If none is provided, a default variance_scaling_initializer is used. batch_squash: If True, the outer_ranks of the observation are squashed into the batch dimension. This allows encoding networks to be used with observations with shape [BxTx...]. dtype: The dtype to use for the convolution and fully connected layers. name: A string representing the name of the network. conv_type: string, '1d' or '2d'. Convolution layers will be 1D or 2D, respectively. width: Scaling factor applied to the number of filters/units in each layer. Raises: ValueError: If any of `preprocessing_layers` is already built. ValueError: If `preprocessing_combiner` is already built. ValueError: If the number of dropout layer parameters does not match the number of fully connected layer parameters. ValueError: If conv_layer_params tuples do not have 3 or 4 elements each.
""" self._width = width flat_preprocessing_layers = None if (len(tf.nest.flatten(input_tensor_spec)) > 1 and preprocessing_combiner is None): raise ValueError( 'preprocessing_combiner layer is required when more than 1 ' 'input_tensor_spec is provided.') if preprocessing_combiner is not None: preprocessing_combiner = _copy_layer(preprocessing_combiner) if not kernel_initializer: kernel_initializer = tf.compat.v1.variance_scaling_initializer( scale=2.0, mode='fan_in', distribution='truncated_normal') layers = [] if conv_layer_params: if conv_type == '2d': conv_layer_type = tf.keras.layers.Conv2D elif conv_type == '1d': conv_layer_type = tf.keras.layers.Conv1D else: raise ValueError('unsupported conv type of %s. Use 1d or 2d' % ( conv_type)) for config in conv_layer_params: if len(config) == 4: (filters, kernel_size, strides, dilation_rate) = config elif len(config) == 3: (filters, kernel_size, strides) = config dilation_rate = (1, 1) if conv_type == '2d' else (1,) else: raise ValueError( 'only 3 or 4 elements permitted in conv_layer_params tuples') kernel_regularizer = None # We use the first weight decay param for all conv layers. weight_decay = weight_decay_params[0] if weight_decay is not None: kernel_regularizer = tf.keras.regularizers.l2(weight_decay) filters = tf_sparse_utils.scale_width(filters, self._width) layers.append( conv_layer_type( filters=filters, kernel_size=kernel_size, strides=strides, dilation_rate=dilation_rate, activation=activation_fn, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, dtype=dtype)) layers.append(tf.keras.layers.Flatten()) if fc_layer_params: if dropout_layer_params is None: dropout_layer_params = [None] * len(fc_layer_params) else: if len(dropout_layer_params) != len(fc_layer_params): raise ValueError('Dropout and fully connected layer parameter lists' 'have different lengths (%d vs. %d.)' % (len(dropout_layer_params), len(fc_layer_params))) if weight_decay_params is None: weight_decay_params = [None] * len(fc_layer_params) else: if len(weight_decay_params) != len(fc_layer_params): raise ValueError('Weight decay and fully connected layer parameter ' 'lists have different lengths (%d vs. %d.)' % (len(weight_decay_params), len(fc_layer_params))) for num_units, dropout_params, weight_decay in zip( fc_layer_params, dropout_layer_params, weight_decay_params): kernel_regularizer = None if weight_decay is not None: kernel_regularizer = tf.keras.regularizers.l2(weight_decay) layers.append( tf.keras.layers.Dense( tf_sparse_utils.scale_width(num_units, self._width), activation=activation_fn, kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer, dtype=dtype)) if not isinstance(dropout_params, dict): dropout_params = {'rate': dropout_params} if dropout_params else None if dropout_params is not None: layers.append( permanent_variable_rate_dropout.PermanentVariableRateDropout( **dropout_params)) super(EncodingNetwork, self).__init__( input_tensor_spec=input_tensor_spec, state_spec=(), name=name) # Pull out the nest structure of the preprocessing layers. This avoids # saving the original kwarg layers as a class attribute which Keras would # then track. 
self._preprocessing_nest = tf.nest.map_structure(lambda l: None, preprocessing_layers) self._flat_preprocessing_layers = flat_preprocessing_layers self._preprocessing_combiner = preprocessing_combiner self._postprocessing_layers = layers self._batch_squash = batch_squash self.built = True # Allow access to self.variables def call(self, observation, step_type=None, network_state=(), training=False): del step_type # unused. if self._batch_squash: outer_rank = nest_utils.get_outer_rank( observation, self.input_tensor_spec) batch_squash = utils.BatchSquash(outer_rank) observation = tf.nest.map_structure(batch_squash.flatten, observation) if self._flat_preprocessing_layers is None: processed = observation else: raise ValueError('Flat preprocessing layers should be None.') states = processed if self._preprocessing_combiner is not None: states = self._preprocessing_combiner(states) for layer in self._postprocessing_layers: states = layer(states, training=training) if self._batch_squash: states = tf.nest.map_structure(batch_squash.unflatten, states) return states, network_state ================================================ FILE: rigl/rl/tfagents/sparse_ppo_actor_network.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
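Both PPO actor networks below squash unbounded network outputs into the action spec's bounds with a `tanh` transform. A standalone NumPy sketch of that mapping (illustrative only; `tanh_and_scale_to_spec` below is the in-repo version):

```python
import numpy as np


def tanh_and_scale(inputs, minimum, maximum):
  # Center plus half-range-scaled tanh: outputs stay strictly inside
  # (minimum, maximum).
  means = (maximum + minimum) / 2.0
  magnitudes = (maximum - minimum) / 2.0
  return means + magnitudes * np.tanh(inputs)


# For an action spec bounded in [-1, 3]: center 1.0, half-range 2.0.
print(tanh_and_scale(0.0, -1.0, 3.0))   # 1.0, the midpoint of the spec
print(tanh_and_scale(50.0, -1.0, 3.0))  # ~3.0, saturates near the maximum
```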
r"""Sequential Actor Network for PPO.""" import sys import numpy as np from rigl.rl.tfagents import tf_sparse_utils import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tf_agents.keras_layers import bias_layer from tf_agents.networks import nest_map from tf_agents.networks import sequential def tanh_and_scale_to_spec(inputs, spec): """Maps inputs with arbitrary range to range defined by spec using `tanh`.""" means = (spec.maximum + spec.minimum) / 2.0 magnitudes = (spec.maximum - spec.minimum) / 2.0 return means + magnitudes * tf.tanh(inputs) class PPOActorNetwork(): """Contains the actor network structure.""" def __init__(self, seed_stream_class=tfp.util.SeedStream, is_sparse=False, sparse_output_layer=False, weight_decay=0.0, width=1.0): self.seed_stream_class = seed_stream_class self._is_sparse = is_sparse self._sparse_output_layer = sparse_output_layer self._weight_decay = weight_decay self._width = width def create_sequential_actor_net(self, fc_layer_units, action_tensor_spec, input_dim, seed=None): """Helper method for creating the actor network.""" self._seed_stream = self.seed_stream_class( seed=seed, salt='tf_agents_sequential_layers') def _get_seed(): seed = self._seed_stream() if seed is not None: seed = seed % sys.maxsize return seed def create_dist(loc_and_scale): loc = loc_and_scale['loc'] loc = tanh_and_scale_to_spec(loc, action_tensor_spec) scale = loc_and_scale['scale'] scale = tf.math.softplus(scale) return tfp.distributions.MultivariateNormalDiag( loc=loc, scale_diag=scale, validate_args=True) def means_layers(): layer = tf.keras.layers.Dense( action_tensor_spec.shape.num_elements(), kernel_initializer=tf.keras.initializers.VarianceScaling( scale=0.1, seed=_get_seed()), kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay), name='means_projection_layer') return layer def std_layers(): std_bias_initializer_value = np.log(np.exp(0.35) - 1) return bias_layer.BiasLayer( bias_initializer=tf.constant_initializer( value=std_bias_initializer_value)) def no_op_layers(): return tf.keras.layers.Lambda(lambda x: x) def dense_layer(num_units): layer = tf.keras.layers.Dense( tf_sparse_utils.scale_width(num_units, self._width), activation=tf.nn.tanh, kernel_initializer=tf.keras.initializers.Orthogonal(seed=_get_seed()), kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay), ) return layer all_layers = [dense_layer(n) for n in fc_layer_units] all_layers.append(means_layers()) if self._is_sparse: if self._sparse_output_layer: all_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim) else: new_layers = tf_sparse_utils.wrap_all_layers(all_layers[:-1], input_dim) all_layers = new_layers + all_layers[-1:] return sequential.Sequential( all_layers + [tf.keras.layers.Lambda( lambda x: {'loc': x, 'scale': tf.zeros_like(x)})] + [nest_map.NestMap({ 'loc': no_op_layers(), 'scale': std_layers(), })] + # Create the output distribution from the mean and standard deviation. [tf.keras.layers.Lambda(create_dist)]) ================================================ FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Sparse Discrete Sequential Actor Network for PPO."""

import functools
import sys

import numpy as np
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tf_agents.networks import sequential
from tf_agents.specs import distribution_spec
from tf_agents.specs import tensor_spec


def tanh_and_scale_to_spec(inputs, spec):
  """Maps inputs with arbitrary range to range defined by spec using `tanh`."""
  mean = (spec.maximum + spec.minimum) / 2.0
  magnitude = spec.maximum - spec.minimum
  return mean + (magnitude * tf.tanh(inputs)) / 2.0


class PPODiscreteActorNetwork():
  """Contains the actor network structure."""

  def __init__(self,
               seed_stream_class=tfp.util.SeedStream,
               is_sparse=False,
               sparse_output_layer=False,
               weight_decay=0,
               width=1.0):
    if is_sparse:
      raise ValueError('This functionality is not enabled. The '
                       'wrap_all_layers functionality needs to be '
                       'implemented.')
    self.seed_stream_class = seed_stream_class
    # Sparse params.
    self._is_sparse = is_sparse
    self._sparse_output_layer = sparse_output_layer
    self._width = width
    self._weight_decay = weight_decay

  def create_sequential_actor_net(self,
                                  fc_layer_units,
                                  action_tensor_spec,
                                  logits_init_output_factor=0.1,
                                  seed=None):
    """Helper method for creating the actor network."""
    self._seed_stream = self.seed_stream_class(
        seed=seed, salt='tf_agents_sequential_layers')

    # action_tensor_spec is a BoundedArraySpec, i.e. an array spec with
    # defined bounds. Maximum and minimum are arrays with the same shape as
    # the main array.
    unique_num_actions = np.unique(action_tensor_spec.maximum -
                                   action_tensor_spec.minimum + 1)
    if len(unique_num_actions) > 1 or np.any(unique_num_actions <= 0):
      raise ValueError('Bounds on discrete actions must be the same for all '
                       'dimensions and have at least 1 action. Projection '
                       'Network requires num_actions to be equal across '
                       'action dimensions. Implement a more general '
                       'categorical projection if you need more flexibility.')
    output_shape = action_tensor_spec.shape.concatenate(
        [int(unique_num_actions)])

    def _get_seed():
      seed = self._seed_stream()
      if seed is not None:
        seed = seed % sys.maxsize
      return seed

    def create_dist(logits):
      input_param_spec = {
          'logits':
              tensor_spec.TensorSpec(
                  shape=(1,) + output_shape, dtype=tf.float32)
      }
      dist_spec = distribution_spec.DistributionSpec(
          tfp.distributions.Categorical,
          input_param_spec,
          sample_spec=action_tensor_spec,
          dtype=action_tensor_spec.dtype)
      logits = tf.reshape(logits, [-1] + output_shape.as_list())
      return dist_spec.build_distribution(logits=logits)

    def dense_layer(num_units):
      dense = functools.partial(
          tf.keras.layers.Dense,
          activation=tf.nn.tanh,
          kernel_initializer=tf.keras.initializers.Orthogonal(
              seed=_get_seed()),
          kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay))
      layer = dense(tf_sparse_utils.scale_width(num_units, self._width))
      if self._is_sparse:
        return tf_sparse_utils.wrap_layer(layer)
      else:
        return layer

    output_layer = tf.keras.layers.Dense(
        output_shape.num_elements(),
        kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
            scale=logits_init_output_factor, seed=_get_seed()),
        kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay),
        bias_initializer=tf.keras.initializers.Zeros(),
        name='logits',
        dtype=tf.float32)
    if self._is_sparse and self._sparse_output_layer:
      output_layer = tf_sparse_utils.wrap_layer(output_layer)

    return sequential.Sequential(
        [dense_layer(num_units) for num_units in fc_layer_units] +
        [output_layer] +
        [tf.keras.layers.Lambda(create_dist)])


================================================
FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for sparse_ppo_discrete_actor_network."""

from absl import flags
from absl.testing import parameterized
from rigl.rl.tfagents import sparse_ppo_discrete_actor_network
import tensorflow as tf
from tf_agents.distributions import utils as distribution_utils
from tf_agents.specs import tensor_spec
from tf_agents.utils import test_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

FLAGS = flags.FLAGS


class DeterministicSeedStream(object):
  """A fake seed stream class that always generates a deterministic seed."""

  def __init__(self, seed, salt=''):
    del salt
    self._seed = seed

  def __call__(self):
    return self._seed


class PpoActorNetworkTest(parameterized.TestCase, test_utils.TestCase):

  def setUp(self):
    super(PpoActorNetworkTest, self).setUp()
    # Run in full eager mode in order to inspect the content of tensors.
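    # (tf.config.experimental_run_functions_eagerly is deprecated in newer TF
    # releases; tf.config.run_functions_eagerly is the non-experimental
    # replacement with the same effect.)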
tf.config.experimental_run_functions_eagerly(True) self.observation_tensor_spec = tf.TensorSpec(shape=[3], dtype=tf.float32) self.action_tensor_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3) def tearDown(self): tf.config.experimental_run_functions_eagerly(False) super(PpoActorNetworkTest, self).tearDown() def _init_network( self, is_sparse=False, sparse_output_layer=False, width=1.0, weight_decay=0): actor_net_lib = sparse_ppo_discrete_actor_network.PPODiscreteActorNetwork( is_sparse=is_sparse, sparse_output_layer=sparse_output_layer, width=width, weight_decay=weight_decay) actor_net_lib.seed_stream_class = DeterministicSeedStream return actor_net_lib.create_sequential_actor_net( fc_layer_units=(1,), action_tensor_spec=self.action_tensor_spec, seed=1) def test_no_mismatched_shape(self): actor_net = self._init_network() actor_output_spec = actor_net.create_variables(self.observation_tensor_spec) distribution_utils.assert_specs_are_compatible( actor_output_spec, self.action_tensor_spec, 'actor_network output spec does not match action spec') @parameterized.named_parameters( ('dense-output-F', False, False, (tf.keras.layers.Dense, tf.keras.layers.Dense)), ('dense-output-T', False, True, (tf.keras.layers.Dense, tf.keras.layers.Dense)), ('sparse-all', True, True, (pruning_wrapper.PruneLowMagnitude, pruning_wrapper.PruneLowMagnitude)), ('sparse-outp-dense', True, False, (pruning_wrapper.PruneLowMagnitude, tf.keras.layers.Dense)), ) def test_is_sparse(self, is_sparse, sparse_output_layer, expected_layers): expected_units = (1, 4) actor_net = self._init_network( is_sparse=is_sparse, sparse_output_layer=sparse_output_layer) for i, (expected_layer, exp_units) in enumerate( zip(expected_layers, expected_units)): layer = actor_net.layers[i] self.assertIsInstance(layer, expected_layer) if isinstance(layer, pruning_wrapper.PruneLowMagnitude): self.assertEqual(layer.layer.units, exp_units) else: self.assertEqual(layer.units, exp_units) def test_width_scaling(self): with self.subTest('dense'): actor_net = self._init_network(width=2.0) self.assertEqual(actor_net.layers[0].units, 2) self.assertEqual(actor_net.layers[1].units, 4) with self.subTest('sparse'): actor_net = self._init_network( is_sparse=True, sparse_output_layer=True, width=2.0) self.assertEqual(actor_net.layers[0].layer.units, 2) self.assertEqual(actor_net.layers[1].layer.units, 4) @parameterized.named_parameters( ('no-wd-d-d', False, False, 0), ('no-wd-s-d', True, False, 0), ('no-wd-s-s', True, True, 0), ('wd-d-d', False, False, 0.1), ('wd-s-d', True, False, 0.1), ('wd-s-s', True, True, 0.1)) def test_weight_decay(self, is_sparse, sparse_output_layer, expected_weight_decay): actor_net = self._init_network(is_sparse=is_sparse, sparse_output_layer=sparse_output_layer, weight_decay=expected_weight_decay) for i in range(2): layer = actor_net.layers[i] if isinstance(layer, pruning_wrapper.PruneLowMagnitude): l2_weight_decay = layer.layer.kernel_regularizer.get_config()['l2'] else: l2_weight_decay = layer.kernel_regularizer.get_config()['l2'] self.assertAlmostEqual(l2_weight_decay, expected_weight_decay) if __name__ == '__main__': tf.test.main() ================================================ FILE: rigl/rl/tfagents/sparse_tanh_normal_projection_network.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Project inputs to a tanh-squashed MultivariateNormalDiag distribution.

This network reproduces the Soft Actor-Critic reference implementation in:
https://github.com/rail-berkeley/softlearning/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Callable, Optional, Text

import gin
import tensorflow as tf
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.typing import types


@gin.configurable
class SparseTanhNormalProjectionNetwork(
    tanh_normal_projection_network.TanhNormalProjectionNetwork):
  """Generates a tanh-squashed MultivariateNormalDiag distribution.

  Note: Due to the nature of the `tanh` function, values near the spec bounds
  cannot be returned.
  """

  def __init__(self,
               sample_spec,
               activation_fn=None,
               std_transform=tf.exp,
               name='SparseTanhNormalProjectionNetwork',
               weight_decay=0.0):
    """Creates an instance of SparseTanhNormalProjectionNetwork.

    Args:
      sample_spec: A `tensor_spec.BoundedTensorSpec` detailing the shape and
        dtypes of samples pulled from the output distribution.
      activation_fn: Activation function to use in dense layer.
      std_transform: Transformation function to apply to the stddevs.
      name: A string representing name of the network.
      weight_decay: Weight decay for L2 regularization.
    """
    super(SparseTanhNormalProjectionNetwork, self).__init__(
        sample_spec=sample_spec,
        activation_fn=activation_fn,
        std_transform=std_transform,
        name=name)

    # We reinitialize the projection layer with L2 regularization; the layer
    # can optionally be sparsified by the surrounding training code.
    self._projection_layer = tf.keras.layers.Dense(
        sample_spec.shape.num_elements() * 2,
        activation=activation_fn,
        kernel_regularizer=tf.keras.regularizers.L2(weight_decay),
        name='projection_layer')


================================================
FILE: rigl/rl/tfagents/sparse_value_network.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample Keras Value Network.

Implements a network that will generate the following layers:

  [optional]: preprocessing_layers  # preprocessing_layers
  [optional]: (Add | Concat(axis=-1) | ...)
# preprocessing_combiner
  [optional]: Conv2D  # conv_layer_params
  Flatten
  [optional]: Dense  # fc_layer_params
  Dense -> 1  # Value output
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gin
from rigl.rl.tfagents import sparse_encoding_network
from rigl.rl.tfagents import tf_sparse_utils
import tensorflow as tf
from tf_agents.networks import network


@gin.configurable
class ValueNetwork(network.Network):
  """Feed Forward value network. Reduces to 1 value output per batch item."""

  def __init__(self,
               input_tensor_spec,
               preprocessing_combiner=None,
               conv_layer_params=None,
               fc_layer_params=(75, 40),
               dropout_layer_params=None,
               weight_decay=0.0,
               activation_fn=tf.keras.activations.relu,
               kernel_initializer=None,
               batch_squash=True,
               dtype=tf.float32,
               name='ValueNetwork',
               is_sparse=False,
               sparse_output_layer=False,
               width=1.0):
    """Creates an instance of `ValueNetwork`.

    Network supports calls with shape outer_rank + observation_spec.shape.
    Note outer_rank must be at least 1.

    Args:
      input_tensor_spec: A `tensor_spec.TensorSpec` or a tuple of specs
        representing the input observations.
      preprocessing_combiner: (Optional.) A keras layer that takes a flat list
        of tensors and combines them. Good options include
        `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`. This
        layer must not be already built. For more details see the
        documentation of `networks.EncodingNetwork`.
      conv_layer_params: Optional list of convolution layer parameters, where
        each item is a length-three tuple indicating (filters, kernel_size,
        stride).
      fc_layer_params: Optional list of fully connected parameters, where each
        item is the number of units in the layer.
      dropout_layer_params: Optional list of dropout layer parameters, where
        each item is either the fraction of input units to drop or a
        dictionary of parameters according to the keras.Dropout documentation.
        The additional parameter `permanent`, if set to True, allows applying
        dropout at inference time for approximate Bayesian inference. The
        dropout layers are interleaved with the fully connected layers; there
        is a dropout layer after each fully connected layer, except if the
        entry in the list is None. This list must have the same length as
        fc_layer_params, or be None.
      weight_decay: L2 weight decay regularization parameter.
      activation_fn: Activation function, e.g. tf.keras.activations.relu.
      kernel_initializer: Initializer to use for the kernels of the conv and
        dense layers. If none is provided, a default glorot_uniform
        initializer is used.
      batch_squash: If True, the outer_ranks of the observation are squashed
        into the batch dimension. This allows encoding networks to be used
        with observations of shape [BxTx...].
      dtype: The dtype to use for the convolution and fully connected layers.
      name: A string representing the name of the network.
      is_sparse: Whether the network is sparse.
      sparse_output_layer: Whether the output layer should be sparse. Only
        applied when is_sparse=True.
      width: Scaling factor to apply to the layers.

    Raises:
      ValueError: If input_tensor_spec is not an instance of network.InputSpec.
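    For example (hypothetical values, sketch only; assumes
    `from tf_agents.specs import tensor_spec`):

    ```python
    value_net = ValueNetwork(
        tensor_spec.TensorSpec([4], tf.float32),
        fc_layer_params=(64, 64),
        weight_decay=1e-4,
        is_sparse=True,             # wrap Dense layers with PruneLowMagnitude
        sparse_output_layer=False)  # keep the final scalar head dense
    ```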
""" super(ValueNetwork, self).__init__( input_tensor_spec=input_tensor_spec, state_spec=(), name=name) self._is_sparse = is_sparse self._sparse_output_layer = sparse_output_layer self._width = width if not kernel_initializer: kernel_initializer = tf.compat.v1.keras.initializers.glorot_uniform() self._encoder = sparse_encoding_network.EncodingNetwork( input_tensor_spec, preprocessing_layers=None, preprocessing_combiner=preprocessing_combiner, conv_layer_params=conv_layer_params, fc_layer_params=fc_layer_params, dropout_layer_params=dropout_layer_params, activation_fn=activation_fn, weight_decay_params=[weight_decay] * len(fc_layer_params), kernel_initializer=kernel_initializer, batch_squash=batch_squash, dtype=dtype, width=self._width) self._postprocessing_layers = tf.keras.layers.Dense( 1, activation=None, kernel_initializer=tf.random_uniform_initializer( minval=-0.03, maxval=0.03), kernel_regularizer=tf.keras.regularizers.L2(weight_decay)) if is_sparse: layers_to_wrap = [l for l in self._encoder._postprocessing_layers if tf_sparse_utils.is_valid_layer_to_wrap(l)] input_dim = input_tensor_spec.shape[0] if sparse_output_layer: layers_to_wrap.append(self._postprocessing_layers) wrapped_layers = tf_sparse_utils.wrap_all_layers( layers_to_wrap, input_dim) self._postprocessing_layers = wrapped_layers[-1] wrapped_layers = wrapped_layers[:-1] else: wrapped_layers = tf_sparse_utils.wrap_all_layers( layers_to_wrap, input_dim) # We need to recreate the original layer list after wrapping the layers. new_layer_list = [] i = 0 for unwrapped_layer in self._encoder._postprocessing_layers: if tf_sparse_utils.is_valid_layer_to_wrap(unwrapped_layer): new_layer_list.append(wrapped_layers[i]) i += 1 else: new_layer_list.append(unwrapped_layer) self._encoder._postprocessing_layers = new_layer_list def call(self, observation, step_type=None, network_state=(), training=False): state, network_state = self._encoder( observation, step_type=step_type, network_state=network_state, training=training) value = self._postprocessing_layers(state, training=training) return tf.squeeze(value, -1), network_state ================================================ FILE: rigl/rl/tfagents/tf_sparse_utils.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Utility functions for sparse tf agents training.""" import re from absl import logging import gin from rigl import sparse_utils as sparse_utils_rigl from rigl.rl import sparse_utils import tensorflow.compat.v2 as tf from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper PRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude _LAYER_TYPES_TO_WRAP = (tf.keras.layers.Dense, tf.keras.layers.Conv2D, tf.keras.layers.Conv1D) def log_total_params(networks): total_params = 0 for net in networks: total_net_params, _ = sparse_utils.get_total_params(net) total_params += total_net_params with tf.name_scope('Params/'): tf.compat.v2.summary.scalar('total', total_params) def scale_width(num_units, width): assert width > 0 return int(max(1, num_units * width)) @gin.configurable def wrap_all_layers(layers, input_dim, mode='constant', mask_init_method='erdos_renyi_kernel', initial_sparsity=0.0, final_sparsity=0.9, begin_step=200000, end_step=600000, frequency=10000): """Wraps a list of dense keras layers to be used by sparse training.""" # We only need to define static masks here, we will update them through # mask updater later. new_layers = [] if mode == 'constant': for layer in layers: schedule = pruning_schedule.ConstantSparsity( target_sparsity=0, begin_step=1000000000) new_layers.append(PRUNING_WRAPPER(layer, pruning_schedule=schedule)) elif mode == 'prune': logging.info('Pruning schedule: initial sparsity: %f', initial_sparsity) logging.info('Pruning schedule: mask_init_method: %s', mask_init_method) logging.info('Pruning schedule: final sparsity: %f', final_sparsity) logging.info('Pruning schedule: begin step: %f', begin_step) logging.info('Pruning schedule: end step: %f', end_step) logging.info('Pruning schedule: frequency: %f', frequency) # Create dummy masks to get layer-wise sparsities. This is because the # get_sparsities function expects mask variables to calculate the # sparsities. dummy_masks_dict = {} layer_input_dim = input_dim for layer in layers: mask = tf.Variable(tf.ones([layer_input_dim, layer.units]), trainable=False, name=f'dummymask_{layer.name}') layer_input_dim = layer.units dummy_masks_dict[layer.name] = mask # Get layer-wise sparsities. extract_name_fn = lambda x: re.findall('(.+):0', x)[0] reverse_dict = {v.name: k for k, v in dummy_masks_dict.items()} sparsity_dict = sparse_utils_rigl.get_sparsities( list(dummy_masks_dict.values()), mask_init_method, final_sparsity, custom_sparsity_map={}, extract_name_fn=extract_name_fn) # This dict will have {layer_name: layer_sparsity} renamed_sparsity_dict = {reverse_dict[k]: float(v) for k, v in sparsity_dict.items()} # Wrap layers with possibly non-uniform pruning schedule. for layer in layers: sparsity = renamed_sparsity_dict[layer.name] logging.info('Layer: %s, sparsity: %f', layer.name, sparsity) schedule = pruning_schedule.PolynomialDecay( initial_sparsity=initial_sparsity, final_sparsity=sparsity, begin_step=begin_step, end_step=end_step, frequency=frequency) new_layers.append(PRUNING_WRAPPER(layer, pruning_schedule=schedule)) return new_layers @gin.configurable def wrap_layer(layer, mode='constant', initial_sparsity=0.0, final_sparsity=0.9, begin_step=200000, end_step=600000, frequency=10000): """Wraps a keras layer to be used by sparse training.""" # We only need to define static masks here, we will update them through # mask updater later. 
if mode == 'constant': schedule = pruning_schedule.ConstantSparsity( target_sparsity=0, begin_step=1000000000) elif mode == 'prune': logging.info('Pruning schedule: initial sparsity: %f', initial_sparsity) logging.info('Pruning schedule: final sparsity: %f', final_sparsity) logging.info('Pruning schedule: begin step: %f', begin_step) logging.info('Pruning schedule: end step: %f', end_step) logging.info('Pruning schedule: frequency: %f', frequency) schedule = pruning_schedule.PolynomialDecay( initial_sparsity=initial_sparsity, final_sparsity=final_sparsity, begin_step=begin_step, end_step=end_step, frequency=frequency) return PRUNING_WRAPPER(layer, pruning_schedule=schedule) def is_valid_layer_to_wrap(layer): for layer_type in _LAYER_TYPES_TO_WRAP: if isinstance(layer, layer_type): return True return False @gin.configurable def log_sparsities(model, model_name='q_net', log_images=False): """Logs relevant sparsity stats to tensorboard.""" for layer in sparse_utils.get_all_pruning_layers(model): for _, mask, threshold in layer.pruning_vars: if log_images: reshaped_mask = tf.expand_dims(tf.expand_dims(mask, 0), -1) with tf.name_scope('Masks/'): tf.compat.v2.summary.image(f'{model_name}/{mask.name}', reshaped_mask) with tf.name_scope('Sparsity/'): sparsity = 1 - tf.reduce_mean(mask) tf.compat.v2.summary.scalar(f'{model_name}/{mask.name}', sparsity) with tf.name_scope('Threshold/'): tf.compat.v2.summary.scalar(f'{model_name}/{threshold.name}', threshold) total_params, nparam_dict = sparse_utils.get_total_params(model) with tf.name_scope('Params/'): tf.compat.v2.summary.scalar(f'{model_name}/total', total_params) for k, val in nparam_dict.items(): tf.compat.v2.summary.scalar(f'{model_name}/' + k, val) def update_prune_step(model, step): for layer in sparse_utils.get_all_pruning_layers(model): # Assign iteration count to the layer pruning_step. layer.pruning_step.assign(step) def flatten_list_of_vars(var_list): flat_vars = [tf.reshape(v, [-1]) for v in var_list] return tf.concat(flat_vars, axis=-1) @gin.configurable def log_snr(tape, loss, step, variables_to_train, freq=1000): """Given a gradient tape and loss, it logs signal-to-noise ratio.""" def true_fn(): grads_per_sample = tape.jacobian(loss, variables_to_train) list_of_snrs = [] for grad in grads_per_sample: if grad is not None: if isinstance(grad, tf.IndexedSlices): grad_values = grad.values else: grad_values = grad grad_mean = tf.math.reduce_mean(grad_values, axis=0) grad_std = tf.math.reduce_std(grad_values, axis=0) list_of_snrs.append(tf.abs(grad_mean / (grad_std + 1e-10))) snr_mean = tf.reduce_mean(flatten_list_of_vars(list_of_snrs)) snr_std = tf.math.reduce_std((flatten_list_of_vars(list_of_snrs))) with tf.name_scope('SNR/'): tf.compat.v2.summary.scalar(name='mean', data=snr_mean, step=step) tf.compat.v2.summary.scalar(name='std', data=snr_std, step=step) tf.cond(step % freq == 0, true_fn, lambda: None) ================================================ FILE: rigl/rl/train.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

r"""The entry point for training a sparse DQN agent."""

import os

from absl import app
from absl import flags
import gin
from rigl.rl import run_experiment
import tensorflow as tf

flags.DEFINE_string('base_dir', None,
                    'Base directory to host all required sub-directories.')
flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files.')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "DQNAgent.epsilon_train=0.1", '
    '"create_atari_environment.game_name=\'Pong\'").')

FLAGS = flags.FLAGS


def create_sparsetrain_runner(base_dir):
  assert base_dir is not None
  return run_experiment.SparseTrainRunner(base_dir)


def main(unused_argv):
  gin.parse_config_files_and_bindings(FLAGS.gin_files, FLAGS.gin_bindings)
  runner = create_sparsetrain_runner(FLAGS.base_dir)
  runner.run_experiment()

  logconfigfile_path = os.path.join(FLAGS.base_dir, 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())


if __name__ == '__main__':
  flags.mark_flag_as_required('base_dir')
  app.run(main)


================================================
FILE: rigl/sparse_optimizers.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
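A hypothetical TF1-style usage sketch for the optimizers in this file (argument names follow the docstrings below; the model must create its weights through the `model_pruning` masked layers so that `get_masks()` and `get_weights()` can find them):

```python
import tensorflow.compat.v1 as tf
from rigl import sparse_optimizers

base = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
opt = sparse_optimizers.SparseRigLOptimizer(
    base, begin_step=0, end_step=25000, frequency=100,
    drop_fraction=0.3, drop_fraction_anneal='cosine')
# train_op = opt.minimize(loss, global_step=tf.train.get_or_create_global_step())
```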
"""This module implements some common and new sparse training algorithms.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import re import numpy as np from rigl import sparse_optimizers_base as sparse_opt_base from rigl import sparse_utils from tensorflow.contrib.model_pruning.python import pruning from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.tpu.ops import tpu_ops from tensorflow.python.training import moving_averages from tensorflow.python.training import optimizer as tf_optimizer from tensorflow.python.training import training_util class PruningGetterTf1Mixin: """Tf1 model_pruning library specific variable retrieval.""" def get_weights(self): return pruning.get_weights() def get_masks(self): return pruning.get_masks() def get_masked_weights(self): return pruning.get_masked_weights() class SparseSETOptimizer(PruningGetterTf1Mixin, sparse_opt_base.SparseSETOptimizerBase): pass class SparseRigLOptimizer(PruningGetterTf1Mixin, sparse_opt_base.SparseRigLOptimizerBase): pass class SparseStaticOptimizer(SparseSETOptimizer): """Sparse optimizer that re-initializes weak connections during training. Attributes: optimizer: tf.train.Optimizer begin_step: int, first iteration where masks are updated. end_step: int, iteration after which no mask is updated. frequency: int, of mask update operations. drop_fraction: float, of connections to drop during each update. drop_fraction_anneal: str or None, if supplied used to anneal the drop fraction. use_locking: bool, passed to the super. grow_init: str, name of the method used to initialize new connections. momentum: float, for the exponentialy moving average. name: bool, passed to the super. """ def __init__(self, optimizer, begin_step, end_step, frequency, drop_fraction=0.1, drop_fraction_anneal='constant', use_locking=False, grow_init='zeros', name='SparseStaticOptimizer', stateless_seed_offset=0): super(SparseStaticOptimizer, self).__init__( optimizer, begin_step, end_step, frequency, drop_fraction=drop_fraction, drop_fraction_anneal=drop_fraction_anneal, grow_init=grow_init, use_locking=use_locking, name=name, stateless_seed_offset=stateless_seed_offset) def generic_mask_update(self, mask, weights, noise_std=1e-5): """True branch of the condition, updates the mask.""" # Ensure that the weights are masked. masked_weights = mask * weights score_drop = math_ops.abs(masked_weights) # Add noise for slight bit of randomness. score_drop += self._random_normal( score_drop.shape, stddev=noise_std, dtype=score_drop.dtype, seed=hash(weights.name + 'drop')) # Revive n_prune many connections using momentum. score_grow = mask return self._get_update_op( score_drop, score_grow, mask, weights, reinit_when_same=True) class SparseMomentumOptimizer(SparseSETOptimizer): """Sparse optimizer that grows connections with the expected gradients. A simplified implementation of Momentum based sparse optimizer. No redistribution of sparsity. 
  Original implementation:
  https://github.com/TimDettmers/sparse_learning/blob/master/mnist_cifar/main.py

  Attributes:
    optimizer: tf.train.Optimizer
    begin_step: int, first iteration where masks are updated.
    end_step: int, iteration after which no mask is updated.
    frequency: int, of mask update operations.
    drop_fraction: float, of connections to drop during each update.
    drop_fraction_anneal: str or None, if supplied used to anneal the drop
      fraction.
    use_locking: bool, passed to the super.
    grow_init: str, name of the method used to initialize new connections.
    momentum: float, for the exponentially moving average.
    use_tpu: bool, if true the masked_gradients are aggregated.
    name: str, passed to the super.
  """

  def __init__(self,
               optimizer,
               begin_step,
               end_step,
               frequency,
               drop_fraction=0.1,
               drop_fraction_anneal='constant',
               use_locking=False,
               grow_init='zeros',
               momentum=0.9,
               use_tpu=False,
               name='SparseMomentumOptimizer',
               stateless_seed_offset=0):
    super(SparseMomentumOptimizer, self).__init__(
        optimizer,
        begin_step,
        end_step,
        frequency,
        drop_fraction=drop_fraction,
        drop_fraction_anneal=drop_fraction_anneal,
        grow_init=grow_init,
        use_locking=use_locking,
        name=name,  # Pass the given name through instead of hard-coding it.
        stateless_seed_offset=stateless_seed_offset)
    self._ema_grads = moving_averages.ExponentialMovingAverage(decay=momentum)
    self._use_tpu = use_tpu

  def set_masked_grads(self, grads, weights):
    if self._use_tpu:
      grads = [tpu_ops.cross_replica_sum(g) for g in grads]
    self._masked_grads = grads
    # Using names since they are better to hash.
    self._weight2masked_grads = {w.name: m for w, m in zip(weights, grads)}

  def compute_gradients(self, loss, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    grads_and_vars = self._optimizer.compute_gradients(loss, **kwargs)
    # Need to update the EMA of the masked_weights. This is a bit hacky and
    # might not work as expected if the gradients are not applied after every
    # calculation. However, it should be fine if only the .minimize() call is
    # used.
    masked_grads_vars = self._optimizer.compute_gradients(
        loss, var_list=self.get_masked_weights())
    masked_grads = [g for g, _ in masked_grads_vars]
    self.set_masked_grads(masked_grads, self.get_weights())
    return grads_and_vars

  def _before_apply_gradients(self, grads_and_vars):
    """Updates momentum before updating the weights with gradient."""
    return self._ema_grads.apply(self._masked_grads)

  def generic_mask_update(self, mask, weights, noise_std=1e-5):
    """True branch of the condition, updates the mask."""
    # Ensure that the weights are masked.
    casted_mask = math_ops.cast(mask, dtypes.float32)
    masked_weights = casted_mask * weights
    score_drop = math_ops.abs(masked_weights)
    # Add noise for slight bit of randomness.
    score_drop += self._random_normal(
        score_drop.shape,
        stddev=noise_std,
        dtype=score_drop.dtype,
        seed=hash(weights.name + 'drop'))
    # Revive n_prune many connections using the momentum (EMA) of the masked
    # gradients.
    masked_grad = self._weight2masked_grads[weights.name]
    score_grow = math_ops.abs(self._ema_grads.average(masked_grad))
    return self._get_update_op(score_drop, score_grow, mask, weights)


class SparseSnipOptimizer(tf_optimizer.Optimizer):
  """Implementation of the SNIP pruning-at-initialization method.

  See https://arxiv.org/abs/1810.02340

  Attributes:
    optimizer: tf.train.Optimizer
    default_sparsity: float, between 0 and 1.
    mask_init_method: str, used to determine mask initializations.
    custom_sparsity_map: dict, mapping `key` to a sparsity value so that the
      mask whose name is '{key}/mask:0' is set to the corresponding sparsity.
    use_locking: bool, passed to the super.
    use_tpu: bool, if true the masked_gradients are aggregated.
    name: str, passed to the super.
  """

  def __init__(self,
               optimizer,
               default_sparsity,
               mask_init_method,
               custom_sparsity_map=None,
               use_locking=False,
               use_tpu=False,
               name='SparseSnipOptimizer'):
    super(SparseSnipOptimizer, self).__init__(use_locking, name)
    if not custom_sparsity_map:
      custom_sparsity_map = {}
    self._optimizer = optimizer
    self._use_tpu = use_tpu
    self._default_sparsity = default_sparsity
    self._mask_init_method = mask_init_method
    self._custom_sparsity_map = custom_sparsity_map
    self.is_snipped = variable_scope.get_variable(
        'is_snipped', initializer=lambda: False, trainable=False)

  def compute_gradients(self, loss, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    return self._optimizer.compute_gradients(loss, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Wraps the original apply_gradient of the optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """

    def apply_gradient_op():
      return self._optimizer.apply_gradients(
          grads_and_vars, global_step=global_step, name=name)

    maybe_reduce = lambda x: x
    if self._use_tpu:
      maybe_reduce = tpu_ops.cross_replica_sum
    grads_and_vars_dict = {
        re.findall('(.+)/weights:0', var.name)[0]: (maybe_reduce(grad), var)
        for grad, var in grads_and_vars
        if var.name.endswith('weights:0')
    }

    def snip_fn(mask, sparsity, dtype):
      """Creates a random sparse mask with deterministic sparsity.

      Args:
        mask: tf.Tensor, used to obtain correct corresponding gradient.
        sparsity: float, between 0 and 1.
        dtype: tf.dtype, type of the return value.

      Returns:
        tf.Tensor
      """
      del dtype
      var_name = sparse_utils.mask_extract_name_fn(mask.name)
      g, v = grads_and_vars_dict[var_name]
      score_drop = math_ops.abs(g * v)
      n_total = np.prod(score_drop.shape.as_list())
      n_prune = sparse_utils.get_n_zeros(n_total, sparsity)
      n_keep = n_total - n_prune

      # Sort the entire array since the k needs to be constant for TPU.
      _, sorted_indices = nn_ops.top_k(
          array_ops.reshape(score_drop, [-1]), k=n_total)
      sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
      # We will have zeros after having `n_keep` many ones.
      new_values = array_ops.where(
          math_ops.range(n_total) < n_keep,
          array_ops.ones_like(sorted_indices, dtype=mask.dtype),
          array_ops.zeros_like(sorted_indices, dtype=mask.dtype))
      new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                      new_values.shape)
      return array_ops.reshape(new_mask, mask.shape)

    def snip_op():
      all_masks = pruning.get_masks()
      assigner = sparse_utils.get_mask_init_fn(
          all_masks,
          self._mask_init_method,
          self._default_sparsity,
          self._custom_sparsity_map,
          mask_fn=snip_fn)
      with ops.control_dependencies([assigner]):
        assign_op = state_ops.assign(
            self.is_snipped, True, name='assign_true_after_snipped')
      return assign_op

    maybe_snip_op = control_flow_ops.cond(
        math_ops.logical_and(
            math_ops.equal(global_step, 0),
            math_ops.logical_not(self.is_snipped)), snip_op,
        apply_gradient_op)

    return maybe_snip_op


class SparseDNWOptimizer(tf_optimizer.Optimizer):
  """Implementation of the DNW optimizer.
  See https://arxiv.org/pdf/1906.00586.pdf

  This optimizer ensures the mask is updated at every iteration, according to
  the current set of weights. It uses the dense gradient to update the
  weights.

  Attributes:
    optimizer: tf.train.Optimizer
    default_sparsity: float, between 0 and 1.
    mask_init_method: str, used to determine mask initializations.
    custom_sparsity_map: dict, mapping `key` to a sparsity value so that the
      mask whose name is '{key}/mask:0' is set to the corresponding sparsity.
    use_tpu: bool, if true the masked_gradients are aggregated.
    use_locking: bool, passed to the super.
    name: str, passed to the super.
  """

  def __init__(self,
               optimizer,
               default_sparsity,
               mask_init_method,
               custom_sparsity_map=None,
               use_tpu=False,
               use_locking=False,
               name='SparseDNWOptimizer'):
    super(SparseDNWOptimizer, self).__init__(use_locking, name)
    self._optimizer = optimizer
    self._use_tpu = use_tpu
    self._default_sparsity = default_sparsity
    self._mask_init_method = mask_init_method
    self._custom_sparsity_map = custom_sparsity_map

  def compute_gradients(self, loss, var_list=None, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    # Replace masked variables with masked_weights so that the gradient is
    # dense and not masked.
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    var_list = self.replace_with_masked_weights(var_list)
    grads_and_vars = self._optimizer.compute_gradients(
        loss, var_list=var_list, **kwargs)
    return self.replace_masked_weights(grads_and_vars)

  def replace_with_masked_weights(self, var_list):
    """Replaces masked variables with masked weights."""
    weight2masked_weights = {
        w.name: mw
        for w, mw in zip(self.get_weights(), self.get_masked_weights())
    }
    updated_var_list = [weight2masked_weights.get(w.name, w) for w in var_list]
    return updated_var_list

  def replace_masked_weights(self, grads_and_vars):
    """Replaces masked weight tensors with weight variables."""
    masked_weights2weight = {
        mw.name: w
        for w, mw in zip(self.get_weights(), self.get_masked_weights())
    }
    updated_grads_and_vars = [
        (g, masked_weights2weight.get(w.name, w)) for g, w in grads_and_vars
    ]
    return updated_grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Wraps the original apply_gradient of the optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    optimizer_update = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
    vars_dict = {
        re.findall('(.+)/weights:0', var.name)[0]: var
        for var in self.get_weights()
    }

    def dnw_fn(mask, sparsity, dtype):
      """Creates a mask with smallest magnitudes with deterministic sparsity.

      Args:
        mask: tf.Tensor, used to obtain correct corresponding gradient.
        sparsity: float, between 0 and 1.
        dtype: tf.dtype, type of the return value.

      Returns:
        tf.Tensor
      """
      del dtype
      var_name = sparse_utils.mask_extract_name_fn(mask.name)
      v = vars_dict[var_name]
      score_drop = math_ops.abs(v)
      n_total = np.prod(score_drop.shape.as_list())
      n_prune = sparse_utils.get_n_zeros(n_total, sparsity)
      n_keep = n_total - n_prune

      # Sort the entire array since the k needs to be constant for TPU.
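      # (top_k with k == n_total is a full argsort of the scores; the first
      # n_keep sorted indices receive ones and the rest zeros, so the
      # resulting mask has an exact, shape-static sparsity, which TPU
      # compilation requires.)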
_, sorted_indices = nn_ops.top_k( array_ops.reshape(score_drop, [-1]), k=n_total) sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1) # We will have zeros after having `n_keep` many ones. new_values = array_ops.where( math_ops.range(n_total) < n_keep, array_ops.ones_like(sorted_indices, dtype=mask.dtype), array_ops.zeros_like(sorted_indices, dtype=mask.dtype)) new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values, new_values.shape) return array_ops.reshape(new_mask, mask.shape) with ops.control_dependencies([optimizer_update]): all_masks = self.get_masks() mask_update_op = sparse_utils.get_mask_init_fn( all_masks, self._mask_init_method, self._default_sparsity, self._custom_sparsity_map, mask_fn=dnw_fn) return mask_update_op def get_weights(self): return pruning.get_weights() def get_masks(self): return pruning.get_masks() def get_masked_weights(self): return pruning.get_masked_weights() ================================================ FILE: rigl/sparse_optimizers_base.py ================================================ # coding=utf-8 # Copyright 2022 RigL Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """This module implements some common and new sparse training algorithms.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import re import six from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.tpu.ops import tpu_ops from tensorflow.python.training import learning_rate_decay from tensorflow.python.training import optimizer as tf_optimizer from tensorflow.python.training import training_util def extract_number(token): """Strips the number from the end of the token if it exists. Args: token: str, s or s_d where d is a number: a float or int. `foo_.5`, `foo_foo.5`, `foo_0.5`, `foo_4` are all valid strings. Returns: float, d if exists otherwise 1. """ regexp = re.compile(r'.*_(\d*\.?\d*)$') if regexp.search(token): return float(regexp.search(token).group(1)) else: return 1. class SparseSETOptimizerBase(tf_optimizer.Optimizer): """Implementation of dynamic sparsity optimizers. Implementation of SET. See https://www.nature.com/articles/s41467-018-04316-3 This optimizer wraps a regular optimizer and performs updates on the masks according to schedule given. Attributes: optimizer: tf.train.Optimizer begin_step: int, first iteration where masks are updated. end_step: int, iteration after which no mask is updated. frequency: int, of mask update operations. drop_fraction: float, of connections to drop during each update. 
    drop_fraction_anneal: str or None, if supplied used to anneal the drop
      fraction.
    use_locking: bool, passed to the super.
    grow_init: str, name of the method used to initialize new connections.
    name: str, passed to the super.
    use_stateless: bool, if True stateless operations are used. This is
      important so that multi-worker jobs do not diverge.
    stateless_seed_offset: int, added to the seed of stateless operations. Use
      this to create randomness without divergence across workers.
  """

  def __init__(self,
               optimizer,
               begin_step,
               end_step,
               frequency,
               drop_fraction=0.1,
               drop_fraction_anneal='constant',
               use_locking=False,
               grow_init='zeros',
               name='SparseSETOptimizer',
               use_stateless=True,
               stateless_seed_offset=0):
    super(SparseSETOptimizerBase, self).__init__(use_locking, name)
    self._optimizer = optimizer
    self._grow_init = grow_init
    self._drop_fraction_anneal = drop_fraction_anneal
    self._drop_fraction_initial_value = ops.convert_to_tensor(
        float(drop_fraction),
        name='%s_drop_fraction' % self._drop_fraction_anneal)
    self._begin_step = ops.convert_to_tensor(begin_step, name='begin_step')
    self._end_step = ops.convert_to_tensor(end_step, name='end_step')
    self._frequency = ops.convert_to_tensor(frequency, name='frequency')
    self._frequency_val = frequency
    self._use_stateless = use_stateless
    self._stateless_seed_offset = stateless_seed_offset

  def compute_gradients(self, loss, **kwargs):
    """Wraps the compute gradient of passed optimizer."""
    return self._optimizer.compute_gradients(loss, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Wraps the original apply_gradient of the optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.
    """
    pre_op = self._before_apply_gradients(grads_and_vars)
    with ops.control_dependencies([pre_op]):
      optimizer_update = self._optimizer.apply_gradients(
          grads_and_vars, global_step=global_step, name=name)
    # We get the default one after calling the super.apply_gradient(), since
    # we want to preserve original behavior of the optimizer: don't increment
    # anything if no global_step is passed. But we need the global step for
    # the mask_update.
    global_step = (
        global_step if global_step is not None else
        training_util.get_or_create_global_step())
    self._global_step = global_step
    with ops.control_dependencies([optimizer_update]):
      return self.cond_mask_update_op(global_step, control_flow_ops.no_op)

  def _before_apply_gradients(self, grads_and_vars):
    """Called before applying gradients."""
    return control_flow_ops.no_op('before_apply_grad')

  def cond_mask_update_op(self, global_step, false_branch):
    """Creates the conditional mask update operation.

    All masks are updated when it is an update iteration (checked by
    self.is_mask_update_iter()).

    Args:
      global_step: tf.Variable, current training iteration.
      false_branch: function, called when it is not a mask update iteration.

    Returns:
      conditional update operation
    """
    # Initializing to -freq so that last_update_step + freq = 0. This enables
    # early mask_updates.
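    # Example: with frequency=100, last_update_step starts at -100, so the
    # check `last_update_step + frequency <= global_step` already holds at
    # global_step=0 (subject to begin_step).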
last_update_step = variable_scope.get_variable( 'last_mask_update_step', [], initializer=init_ops.constant_initializer( -self._frequency_val, dtype=global_step.dtype), trainable=False, dtype=global_step.dtype) def mask_update_op(): update_ops = [] for mask, weights in zip(self.get_masks(), self.get_weights()): update_ops.append(self.generic_mask_update(mask, weights)) with ops.control_dependencies(update_ops): assign_op = state_ops.assign( last_update_step, global_step, name='last_mask_update_step_assign') with ops.control_dependencies([assign_op]): return control_flow_ops.no_op('mask_update') maybe_update = control_flow_ops.cond( self.is_mask_update_iter(global_step, last_update_step), mask_update_op, false_branch) return maybe_update def get_weights(self): raise NotImplementedError def get_masks(self): raise NotImplementedError def get_masked_weights(self): raise NotImplementedError def is_mask_update_iter(self, global_step, last_update_step): """Function for checking if the current step is a mask update step. It also creates the drop_fraction op and assigns it to the self object. Args: global_step: tf.Variable(int), current training step. last_update_step: tf.Variable(int), holding the last iteration the mask is updated. Used to determine whether current iteration is a mask update step. Returns: bool, whether the current iteration is a mask_update step. """ gs_dtype = global_step.dtype self._begin_step = math_ops.cast(self._begin_step, gs_dtype) self._end_step = math_ops.cast(self._end_step, gs_dtype) self._frequency = math_ops.cast(self._frequency, gs_dtype) is_step_within_update_range = math_ops.logical_and( math_ops.greater_equal(global_step, self._begin_step), math_ops.logical_or( math_ops.less_equal(global_step, self._end_step), # If _end_step is negative, we never stop updating the mask. # In other words we update the mask with given frequency until the # training ends. math_ops.less(self._end_step, 0))) is_update_step = math_ops.less_equal( math_ops.add(last_update_step, self._frequency), global_step) is_mask_update_iter_op = math_ops.logical_and(is_step_within_update_range, is_update_step) self.drop_fraction = self.get_drop_fraction(global_step, is_mask_update_iter_op) return is_mask_update_iter_op def get_drop_fraction(self, global_step, is_mask_update_iter_op): """Returns a constant or annealing drop_fraction op.""" if self._drop_fraction_anneal == 'constant': drop_frac = self._drop_fraction_initial_value elif self._drop_fraction_anneal == 'cosine': decay_steps = self._end_step - self._begin_step drop_frac = learning_rate_decay.cosine_decay( self._drop_fraction_initial_value, global_step, decay_steps, name='cosine_drop_fraction') elif self._drop_fraction_anneal.startswith('exponential'): exponent = extract_number(self._drop_fraction_anneal) div_dtype = self._drop_fraction_initial_value.dtype power = math_ops.divide( math_ops.cast(global_step - self._begin_step, div_dtype), math_ops.cast(self._end_step - self._begin_step, div_dtype), ) drop_frac = math_ops.multiply( self._drop_fraction_initial_value, math_ops.pow(1 - power, exponent), name='%s_drop_fraction' % self._drop_fraction_anneal) else: raise ValueError('drop_fraction_anneal: %s is not valid' % self._drop_fraction_anneal) return array_ops.where(is_mask_update_iter_op, drop_frac, array_ops.zeros_like(drop_frac)) def generic_mask_update(self, mask, weights, noise_std=1e-5): """True branch of the condition, updates the mask.""" # Ensure that the weights are masked. 
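    # SET rule: drop scores are the (noise-perturbed) magnitudes of surviving
    # weights; the grow scores below are uniform random. Subclasses such as
    # SparseRigLOptimizerBase override this method to grow by gradient
    # magnitude instead.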
    masked_weights = mask * weights
    score_drop = math_ops.abs(masked_weights)
    # Add noise for a slight bit of randomness.
    score_drop += self._random_normal(
        score_drop.shape,
        stddev=noise_std,
        dtype=score_drop.dtype,
        seed=(hash(weights.name + 'drop')))
    # Randomly revive n_prune many connections from non-existing connections.
    score_grow = self._random_uniform(
        weights.shape, seed=hash(weights.name + 'grow'))
    return self._get_update_op(score_drop, score_grow, mask, weights)

  def _get_update_op(self,
                     score_drop,
                     score_grow,
                     mask,
                     weights,
                     reinit_when_same=False):
    """Prunes and grows connections; all tensors have the same shape."""
    old_dtype = mask.dtype
    mask_casted = math_ops.cast(mask, dtypes.float32)
    n_total = array_ops.size(score_drop)
    n_ones = math_ops.cast(math_ops.reduce_sum(mask_casted), dtype=dtypes.int32)
    n_prune = math_ops.cast(
        math_ops.cast(n_ones, dtype=dtypes.float32) * self.drop_fraction,
        dtypes.int32)
    n_keep = n_ones - n_prune
    # Sort the entire array since the k needs to be constant for TPU.
    _, sorted_indices = nn_ops.top_k(
        array_ops.reshape(score_drop, [-1]), k=n_total)
    sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
    # We will have zeros after having `n_keep` many ones.
    new_values = array_ops.where(
        math_ops.range(n_total) < n_keep,
        array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
        array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
    mask1 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                 new_values.shape)
    # Flatten the scores.
    score_grow = array_ops.reshape(score_grow, [-1])
    # Set scores of the enabled connections (ones) to min(s) - 1, so that they
    # have the lowest scores.
    score_grow_lifted = array_ops.where(
        math_ops.equal(mask1, 1),
        array_ops.ones_like(mask1) * (math_ops.reduce_min(score_grow) - 1),
        score_grow)
    _, sorted_indices = nn_ops.top_k(score_grow_lifted, k=n_total)
    sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
    new_values = array_ops.where(
        math_ops.range(n_total) < n_prune,
        array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
        array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
    mask2 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                                 new_values.shape)
    # Ensure masks are disjoint.
    assert_op = control_flow_ops.Assert(
        math_ops.equal(math_ops.reduce_sum(mask1 * mask2), 0.),
        [mask1, mask2])
    with ops.control_dependencies([assert_op]):
      # Let's set the weights of the grown connections.
      mask2_reshaped = array_ops.reshape(mask2, mask.shape)
      # Set the values of the new connections.
      grow_tensor = self.get_grow_tensor(weights, self._grow_init)
      if reinit_when_same:
        # If dropped and grown, we re-initialize.
        new_connections = math_ops.equal(mask2_reshaped, 1)
      else:
        new_connections = math_ops.logical_and(
            math_ops.equal(mask2_reshaped, 1), math_ops.equal(mask_casted, 0))
      new_weights = array_ops.where(new_connections, grow_tensor, weights)
      weights_update = state_ops.assign(weights, new_weights)
      # Ensure there is no momentum value for new connections.
      reset_op = self.reset_momentum(weights, new_connections)
      with ops.control_dependencies([weights_update, reset_op]):
        mask_combined = array_ops.reshape(mask1 + mask2, mask.shape)
        mask_combined = math_ops.cast(mask_combined, dtype=old_dtype)
        new_mask = state_ops.assign(mask, mask_combined)
    return new_mask
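Conceptually, `_get_update_op` above implements drop-lowest-magnitude / grow-highest-score while keeping the number of active connections fixed. A NumPy rendering of the same selection (editor's sketch; `set_update` is an illustrative helper, and it ignores the noise term, dtype casts and the weight/momentum re-initialization):

import numpy as np

def set_update(mask, weights, score_grow, drop_fraction):
  """Drop the lowest-|w| active connections, grow the highest-score ones."""
  n_ones = int(mask.sum())
  n_prune = int(n_ones * drop_fraction)
  n_keep = n_ones - n_prune
  score_drop = np.abs(mask * weights).ravel()
  mask1 = np.zeros(mask.size)
  mask1[np.argsort(-score_drop)[:n_keep]] = 1.
  # Active connections get the lowest possible grow score, as above.
  grow = score_grow.ravel()
  lifted = np.where(mask1 == 1, grow.min() - 1, grow)
  mask2 = np.zeros(mask.size)
  mask2[np.argsort(-lifted)[:n_prune]] = 1.
  assert (mask1 * mask2).sum() == 0  # disjoint, mirroring the Assert above
  return (mask1 + mask2).reshape(mask.shape)

Because n_keep + n_prune equals the original number of ones, the sparsity level is preserved across every update.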
  def reset_momentum(self, weights, new_connections):
    reset_ops = []
    for s_name in self._optimizer.get_slot_names():
      # For momentum variables, for example, we reset the aggregated values
      # to zero.
      optim_var = self._optimizer.get_slot(weights, s_name)
      new_values = array_ops.where(new_connections,
                                   array_ops.zeros_like(optim_var), optim_var)
      reset_ops.append(state_ops.assign(optim_var, new_values))
    return control_flow_ops.group(reset_ops)

  def get_grow_tensor(self, weights, method):
    """Different ways to initialize new connections.

    Args:
      weights: tf.Tensor or Variable.
      method: str, available options: 'zeros', 'random_normal',
        'random_uniform' and 'initial_value'.

    Returns:
      tf.Tensor same shape and type as weights.

    Raises:
      ValueError: when the method is not valid.
    """
    if not isinstance(method, six.string_types):
      raise ValueError('Grow-Init: %s is not a string' % method)
    if method == 'zeros':
      grow_tensor = array_ops.zeros_like(weights, dtype=weights.dtype)
    elif method.startswith('initial_dist'):
      original_shape = weights.initial_value.shape
      divisor = extract_number(method)
      grow_tensor = array_ops.reshape(
          random_ops.random_shuffle(
              array_ops.reshape(weights.initial_value, [-1])),
          original_shape) / divisor
    elif method.startswith('random_normal'):
      stddev = math_ops.reduce_std(weights)
      divisor = extract_number(method)
      grow_tensor = self._random_normal(
          weights.shape,
          stddev=stddev,
          dtype=weights.dtype,
          seed=hash(weights.name + 'grow_init_n')) / divisor
    elif method.startswith('random_uniform'):
      mean = math_ops.reduce_mean(math_ops.abs(weights))
      divisor = extract_number(method)
      grow_tensor = self._random_uniform(
          weights.shape,
          minval=-mean,
          maxval=mean,
          dtype=weights.dtype,
          seed=hash(weights.name + 'grow_init_u')) / divisor
    else:
      raise ValueError('Grow-Init: %s is not a valid option.' % method)
    return grow_tensor

  def _random_uniform(self, *args, **kwargs):
    if self._use_stateless:
      c_seed = self._stateless_seed_offset + kwargs['seed']
      kwargs['seed'] = math_ops.cast(
          array_ops.stack([c_seed, self._global_step]), dtypes.int32)
      return stateless_random_ops.stateless_random_uniform(*args, **kwargs)
    else:
      return random_ops.random_uniform(*args, **kwargs)

  def _random_normal(self, *args, **kwargs):
    if self._use_stateless:
      c_seed = self._stateless_seed_offset + kwargs['seed']
      kwargs['seed'] = math_ops.cast(
          array_ops.stack([c_seed, self._global_step]), dtypes.int32)
      return stateless_random_ops.stateless_random_normal(*args, **kwargs)
    else:
      return random_ops.random_normal(*args, **kwargs)


class SparseRigLOptimizerBase(SparseSETOptimizerBase):
  """Sparse optimizer that grows connections with the pre-removal gradients.

  Attributes:
    optimizer: tf.train.Optimizer
    begin_step: int, first iteration where masks are updated.
    end_step: int, iteration after which no mask is updated.
    frequency: int, of mask update operations.
    drop_fraction: float, of connections to drop during each update.
    drop_fraction_anneal: str or None, if supplied, used to anneal the drop
      fraction.
    use_locking: bool, passed to the super.
    grow_init: str, name of the method used to initialize new connections.
    initial_acc_scale: float, used to scale the gradient when initializing
      the momentum values of new connections. We hope this will improve
      training, compared to starting from 0 for the new connections. Set this
      to something between 0 and 1 / (1 - momentum). This is because in the
      current implementation of MomentumOptimizer, aggregated values converge
      to 1 / (1 - momentum) with constant gradients.
    use_tpu: bool, if True the masked_gradients are aggregated.
    name: str, passed to the super.
""" def __init__(self, optimizer, begin_step, end_step, frequency, drop_fraction=0.1, drop_fraction_anneal='constant', use_locking=False, grow_init='zeros', initial_acc_scale=0., use_tpu=False, name='SparseRigLOptimizer', stateless_seed_offset=0): super(SparseRigLOptimizerBase, self).__init__( optimizer, begin_step, end_step, frequency, drop_fraction=drop_fraction, drop_fraction_anneal=drop_fraction_anneal, grow_init=grow_init, use_locking=use_locking, name='SparseRigLOptimizer', stateless_seed_offset=stateless_seed_offset) self._initial_acc_scale = initial_acc_scale self._use_tpu = use_tpu def set_masked_grads(self, grads, weights): if self._use_tpu: grads = [tpu_ops.cross_replica_sum(g) for g in grads] self._masked_grads = grads # Using names since better to hash. self._weight2masked_grads = {w.name: m for w, m in zip(weights, grads)} def compute_gradients(self, loss, **kwargs): """Wraps the compute gradient of passed optimizer.""" grads_and_vars = self._optimizer.compute_gradients(loss, **kwargs) masked_grads_vars = self._optimizer.compute_gradients( loss, var_list=self.get_masked_weights()) masked_grads = [g for g, _ in masked_grads_vars] self.set_masked_grads(masked_grads, self.get_weights()) return grads_and_vars def apply_gradients(self, grads_and_vars, global_step=None, name=None): """Wraps the original apply_gradient of the optimizer. Args: grads_and_vars: List of (gradient, variable) pairs as returned by `compute_gradients()`. global_step: Optional `Variable` to increment by one after the variables have been updated. name: Optional name for the returned operation. Default to the name passed to the `Optimizer` constructor. Returns: An `Operation` that applies the specified gradients. If `global_step` was not None, that operation also increments `global_step`. """ pre_op = self._before_apply_gradients(grads_and_vars) with ops.control_dependencies([pre_op]): # Call this to create slots. _ = self._optimizer.apply_gradients( grads_and_vars, global_step=global_step, name=name) def apply_gradient_op(): optimizer_update = self._optimizer.apply_gradients( grads_and_vars, global_step=global_step, name=name) return optimizer_update # We get the default one after calling the super.apply_gradient(), since # we want to preserve original behavior of the optimizer: don't increment # anything if no global_step is passed. But we need the global step for # the mask_update. global_step = ( global_step if global_step is not None else training_util.get_or_create_global_step()) self._global_step = global_step return self.cond_mask_update_op(global_step, apply_gradient_op) def generic_mask_update(self, mask, weights, noise_std=1e-5): """True branch of the condition, updates the mask.""" # Ensure that the weights are masked. casted_mask = math_ops.cast(mask, dtype=dtypes.float32) masked_weights = casted_mask * weights score_drop = math_ops.abs(masked_weights) # Add noise for slight bit of randomness. score_drop += self._random_normal( score_drop.shape, stddev=noise_std, dtype=score_drop.dtype, seed=hash(weights.name + 'drop')) # Revive n_prune many connections using gradient. 
  def get_grow_tensor(self, weights, method):
    """Returns initialization for grown weights."""
    if method.startswith('grad_scale'):
      masked_grad = self._weight2masked_grads[weights.name]
      divisor = extract_number(method)
      grow_tensor = masked_grad / divisor
    elif method.startswith('grad_sign'):
      masked_grad_sign = math_ops.sign(self._weight2masked_grads[weights.name])
      divisor = extract_number(method)
      grow_tensor = masked_grad_sign / divisor
    else:
      grow_tensor = super(SparseRigLOptimizerBase,
                          self).get_grow_tensor(weights, method)
    return grow_tensor

  def reset_momentum(self, weights, new_connections):
    reset_ops = []
    for s_name in self._optimizer.get_slot_names():
      # For momentum variables, for example, we reset the aggregated values
      # to zero.
      optim_var = self._optimizer.get_slot(weights, s_name)
      accum_grad = (
          self._weight2masked_grads[weights.name] * self._initial_acc_scale)
      new_values = array_ops.where(new_connections, accum_grad, optim_var)
      reset_ops.append(state_ops.assign(optim_var, new_values))
    return control_flow_ops.group(reset_ops)


================================================
FILE: rigl/sparse_optimizers_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the sparse_optimizers file."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
from absl import flags
from absl.testing import parameterized
import numpy as np
from rigl import sparse_optimizers
from rigl import sparse_utils
import tensorflow.compat.v1 as tf  # tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers

FLAGS = flags.FLAGS


class SparseSETOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  def _setup_graph(self,
                   n_inp,
                   n_out,
                   drop_frac,
                   start_iter=1,
                   end_iter=4,
                   freq_iter=2):
    """Sets up a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.random.uniform((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    global_step = tf.train.get_or_create_global_step()
    weight = pruning.get_weights()[0]
    # There is one masked layer to be trained.
    mask = pruning.get_masks()[0]
    # Around half of the values of the mask are set to zero with `mask_update`.
    mask_update = tf.assign(
        mask,
        tf.constant(
            np.random.choice([0, 1], size=(n_inp, n_out), p=[1. / 2, 1. / 2]),
            dtype=tf.float32))
    loss = tf.reduce_mean(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run([mask_update])
    return sess, train_op, mask, weight, global_step

  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.2), (3, 5, 0.2))
  def testMaskNonUpdateIterations(self, n_inp, n_out, drop_frac):
    """Trains a layer for 5 iterations and checks the mask is kept intact.

    The mask should be updated only in iterations 1 and 3 (since start_iter=1,
    end_iter=4, freq_iter=2).

    Args:
      n_inp: int, number of input channels.
      n_out: int, number of output channels.
      drop_frac: float, passed to the sparse optimizer.
    """
    sess, train_op, mask, _, _ = self._setup_graph(
        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)
    expected_updates = [1, 3]
    # Running 5 times to make sure the mask is not updated after end_iter.
    for i in range(1, 6):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      if i not in expected_updates:
        self.assertAllEqual(c_mask, c_mask2)

  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.7), (30, 10, 0.9))
  def testUpdateIterations(self, n_inp, n_out, drop_frac):
    """Checks whether the mask is updated during the correct iterations.

    The mask should be updated only in iterations 1 and 3 (since start_iter=1,
    end_iter=4, freq_iter=2). The number of 1's in the mask should stay equal.

    Args:
      n_inp: int, number of input channels.
      n_out: int, number of output channels.
      drop_frac: float, passed to the sparse optimizer.
    """
    sess, train_op, mask, _, _ = self._setup_graph(
        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)
    expected_updates = [1, 3]
    # Running 4 times since last update is at 3.
    for i in range(1, 5):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      if i in expected_updates:
        # Number of ones (connections) should be same.
        self.assertEqual(c_mask.sum(), c_mask2.sum())
        # Assert there is some change in the mask.
        self.assertNotAllClose(c_mask, c_mask2)

  @parameterized.parameters((3, 7, 2), (1, 5, 3), (0, 4, 1))
  def testNoDrop(self, start_iter, end_iter, freq_iter):
    """Checks that when the drop fraction is 0, no update is made.

    Args:
      start_iter: int, start iteration for sparse training.
      end_iter: int, final iteration for sparse training.
      freq_iter: int, mask update frequency.
    """
    # Setting drop_fraction to 0; so there is nothing dropped, nothing changed.
    sess, train_op, mask, _, _ = self._setup_graph(
        3, 5, 0, start_iter=start_iter, end_iter=end_iter, freq_iter=freq_iter)
    for _ in range(end_iter + 2):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      self.assertAllEqual(c_mask, c_mask2)

  def testNewConnectionZeroInit(self):
    """Checks whether the new connections are correctly initialized to zeros."""
    end_iter = 4
    sess, train_op, mask, weight, _ = self._setup_graph(
        n_inp=3, n_out=5, drop_frac=0.5, start_iter=0, end_iter=end_iter,
        freq_iter=1)
    # Let's iterate until the mask updates are done.
    for _ in range(end_iter + 1):
      mask_tensor, = sess.run([mask])
      sess.run([train_op])
    new_mask_tensor, new_weight_tensor = sess.run([mask, weight])
    # Let's sum the values of the new connections.
    new_weights = new_weight_tensor[np.logical_and(mask_tensor == 0,
                                                   new_mask_tensor == 1)]
    self.assertTrue(np.all(new_weights == 0))

  @parameterized.parameters(
      itertools.product(((3, 7, 2), (5, 3), (1,)),
                        ('zeros', 'random_normal', 'random_uniform')))
  def testShapeOfGetGrowTensor(self, shape, init_type):
    """Checks whether the new tensor is created with the correct shape."""
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(
        optim, 0, 0, 1, use_stateless=False)
    weights = tf.random_uniform(shape)
    grow_tensor = sparse_optim.get_grow_tensor(weights, init_type)
    self.assertAllEqual(weights.shape, grow_tensor.shape)

  @parameterized.parameters(
      itertools.product((tf.float32, tf.float64),
                        ('zeros', 'random_normal', 'random_uniform')))
  def testDtypeOfGetGrowTensor(self, dtype, init_type):
    """Checks whether the new tensor is created with the correct data type."""
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(
        optim, 0, 0, 1, use_stateless=False)
    weights = tf.random_uniform((3, 4), dtype=dtype, maxval=5)
    grow_tensor = sparse_optim.get_grow_tensor(weights, init_type)
    self.assertEqual(grow_tensor.dtype, weights.dtype)

  @parameterized.parameters('ones', 'zero', None, 0)
  def testValueErrorOfGetGrowTensor(self, method):
    """Checks that invalid grow-init methods raise a ValueError."""
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseSETOptimizer(
        optim, 0, 0, 1, use_stateless=False)
    weights = tf.random_uniform((3, 4))
    with self.assertRaises(ValueError):
      sparse_optim.get_grow_tensor(weights, method)


class SparseStaticOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  def _setup_graph(self,
                   n_inp,
                   n_out,
                   drop_frac,
                   start_iter=1,
                   end_iter=4,
                   freq_iter=2):
    """Sets up a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseStaticOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.random.uniform((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    global_step = tf.train.get_or_create_global_step()
    weight = pruning.get_weights()[0]
    # There is one masked layer to be trained.
    mask = pruning.get_masks()[0]
    # Around half of the values of the mask are set to zero with `mask_update`.
    mask_update = tf.assign(
        mask,
        tf.constant(
            np.random.choice([0, 1], size=(n_inp, n_out), p=[1. / 2, 1. / 2]),
            dtype=tf.float32))
    loss = tf.reduce_mean(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run([mask_update])
    return sess, train_op, mask, weight, global_step

  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.2), (3, 5, 0.2))
  def testMaskStatic(self, n_inp, n_out, drop_frac):
    """Trains a layer for 5 iterations and checks the mask is kept intact.

    With the static optimizer the mask should stay the same at every
    iteration, including the mask update iterations.

    Args:
      n_inp: int, number of input channels.
      n_out: int, number of output channels.
      drop_frac: float, passed to the sparse optimizer.
    """
    sess, train_op, mask, _, _ = self._setup_graph(
        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)
    # Running 5 times to make sure the mask is not updated after end_iter.
    for _ in range(5):
      c_mask, = sess.run([mask])
      sess.run([train_op])
      c_mask2, = sess.run([mask])
      self.assertAllEqual(c_mask, c_mask2)


class SparseMomentumOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  def _setup_graph(self,
                   n_inp,
                   n_out,
                   drop_frac,
                   start_iter=1,
                   end_iter=4,
                   freq_iter=2,
                   momentum=0.5):
    """Sets up a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseMomentumOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac,
        momentum=momentum)
    x = tf.ones((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # Multiplying the output with a range of constants to have constant but
    # different gradients at the masked weights.
    y = y * tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape)
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    weight = pruning.get_weights()[0]
    masked_grad = sparse_optim._weight2masked_grads[weight.name]
    masked_grad_ema = sparse_optim._ema_grads.average(masked_grad)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    return sess, train_op, masked_grad_ema

  @parameterized.parameters((3, 4, 0.5), (5, 2, 0.), (2, 5, 1.))
  def testMomentumUpdate(self, n_inp, n_out, momentum):
    """Checks whether momentum is applied correctly."""
    sess, train_op, masked_grad_ema = self._setup_graph(
        n_inp, n_out, 0.5, start_iter=1, end_iter=4, freq_iter=2,
        momentum=momentum)
    # Running 6 times to make sure the momentum is always updated.
    current_momentum = np.zeros((n_inp, n_out))
    for _ in range(6):
      ema_masked_grad, = sess.run([masked_grad_ema])
      self.assertAllEqual(ema_masked_grad, current_momentum)
      sess.run([train_op])
      # This is since we multiply the output values with range(n_out).
      # Note the broadcast from n_out vector to (n_inp, n_out) matrix.
      current_momentum = (
          current_momentum * momentum + (1 - momentum) * np.arange(n_out))
    ema_masked_grad, = sess.run([masked_grad_ema])
    self.assertAllEqual(ema_masked_grad, current_momentum)
class SparseRigLOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  def _setup_graph(self,
                   n_inp,
                   n_out,
                   drop_frac,
                   start_iter=1,
                   end_iter=4,
                   freq_iter=2):
    """Sets up a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(1e-3)
    global_step = tf.train.get_or_create_global_step()
    sparse_optim = sparse_optimizers.SparseRigLOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.ones((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    # Multiplying the output with a range of constants to have constant but
    # different gradients at the masked weights. We also multiply the loss
    # with global_step to increase the gradient linearly with time.
    scale_vector = (
        tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape) *
        tf.cast(global_step, dtype=y.dtype))
    y = y * scale_vector
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    weight = pruning.get_weights()[0]
    expected_gradient = tf.broadcast_to(scale_vector, weight.shape)
    masked_grad = sparse_optim._weight2masked_grads[weight.name]
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    return sess, train_op, masked_grad, expected_gradient

  @parameterized.parameters((3, 4), (5, 2), (2, 5))
  def testMaskedGradientCalculation(self, n_inp, n_out):
    """Checks whether masked_grad is calculated after apply_gradients."""
    # No drop since we don't want to change the mask but check whether the
    # grad is calculated after the gradient step.
    sess, train_op, masked_grad, expected_gradient = self._setup_graph(
        n_inp, n_out, 0., start_iter=0, end_iter=3, freq_iter=1)
    # Mask updates and gradient steps alternate, so we iterate 6 times to get
    # 3 of each.
    for i in range(6):
      is_mask_update = i % 2 == 0
      if is_mask_update:
        expected_gradient_tensor, = sess.run([expected_gradient])
        _, masked_grad_tensor = sess.run([train_op, masked_grad])
        self.assertAllEqual(masked_grad_tensor, expected_gradient_tensor)
      else:
        sess.run([train_op])

  @parameterized.parameters(
      (3, 7, 2, [1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1]),
      (1, 5, 3, [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1]),
      (0, 4, 1, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]))
  def testApplyGradients(self, start_iter, end_iter, freq_iter,
                         is_incremented):
    """Checks apply_gradient is called in non mask update iterations."""
    sess, train_op, _, _ = self._setup_graph(
        3, 5, .5, start_iter=start_iter, end_iter=end_iter,
        freq_iter=freq_iter)
    global_step = tf.train.get_or_create_global_step()
    for one_if_incremented in is_incremented:
      before, = sess.run([global_step])
      sess.run([train_op])
      after, = sess.run([global_step])
      if one_if_incremented == 1:
        self.assertEqual(before + 1, after)
      else:
        # Mask update step.
        self.assertEqual(before, after)
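The `is_incremented` patterns above can be rederived without TensorFlow: on a mask-update step the conditional takes the mask branch, so the inner `apply_gradients` (and with it the global-step increment) is skipped. A pure-Python simulation reproducing the first parameter set (editor's sketch; `increments` is an illustrative helper):

def increments(n, begin, end, freq):
  gs, last, out = 0, -freq, []
  for _ in range(n):
    update = gs >= begin and (gs <= end or end < 0) and last + freq <= gs
    if update:
      last = gs            # mask update; global step unchanged
      out.append(0)
    else:
      gs += 1              # regular gradient step
      out.append(1)
  return out

assert increments(11, 3, 7, 2) == [1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1]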
class SparseSnipOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  def _setup_graph(self,
                   default_sparsity,
                   mask_init_method,
                   custom_sparsity_map,
                   n_inp=3,
                   n_out=5):
    """Sets up a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(1e-3)
    sparse_optim = sparse_optimizers.SparseSnipOptimizer(
        optim,
        default_sparsity,
        mask_init_method,
        custom_sparsity_map=custom_sparsity_map)
    inp_values = np.arange(1, n_inp + 1)
    scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5
    # The gradient is the outer product of the input and the output gradients.
    # Since the loss is the sample sum, the output gradient is equal to the
    # scale vector.
    expected_grads = np.outer(inp_values, scale_vector_values)
    x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)
    y = y * scale_vector
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    train_op = sparse_optim.minimize(loss, global_step)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    mask = pruning.get_masks()[0]
    weights = pruning.get_weights()[0]
    return sess, train_op, expected_grads, sparse_optim, mask, weights

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testSnipSparsity(self, n_inp, n_out, default_sparsity):
    """Checks that the mask reaches the target sparsity after the SNIP step."""
    sess, train_op, _, _, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    _ = sess.run([train_op])
    snipped_mask, = sess.run([mask])
    n_ones = np.sum(snipped_mask)
    n_zeros = snipped_mask.size - n_ones
    n_zeros_expected = sparse_utils.get_n_zeros(snipped_mask.size,
                                                default_sparsity)
    self.assertEqual(n_zeros, n_zeros_expected)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testGradientUsed(self, n_inp, n_out, default_sparsity):
    """Checks that connections with the highest saliency scores are kept."""
    sess, train_op, expected_grads, _, mask, weights = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    # Calculate sensitivity scores.
    weights, = sess.run([weights])
    expected_scores = np.abs(expected_grads * weights)
    _ = sess.run([train_op])
    snipped_mask, = sess.run([mask])
    kept_connection_scores = expected_scores[snipped_mask == 1]
    min_score_kept = np.min(kept_connection_scores)
    snipped_connection_scores = expected_scores[snipped_mask == 0]
    max_score_snipped = np.max(snipped_connection_scores)
    self.assertLessEqual(max_score_snipped, min_score_kept)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testInitialMaskIsDense(self, n_inp, n_out, default_sparsity):
    """Checks that the mask is dense before the SNIP step runs."""
    sess, _, _, _, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    mask_start, = sess.run([mask])
    self.assertEqual(np.sum(mask_start), mask_start.size)
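For reference, the one-shot selection rule these SNIP tests exercise, written in NumPy (editor's sketch; `snip_mask` is an illustrative helper, not part of the library):

import numpy as np

def snip_mask(weights, grads, sparsity):
  """One-shot SNIP selection: keep the highest |weight * grad| scores."""
  scores = np.abs(weights * grads).ravel()
  n_zeros = int(np.floor(sparsity * scores.size))
  mask = np.ones(scores.size)
  mask[np.argsort(scores)[:n_zeros]] = 0.
  return mask.reshape(weights.shape)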
  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testAfterSnipTraining(self, n_inp, n_out, default_sparsity):
    """Checks that the mask stays the same after the initial SNIP iteration."""
    sess, train_op, _, sparse_optim, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    global_step = tf.train.get_or_create_global_step()
    is_snip_iter = sess.run([train_op])
    self.assertTrue(is_snip_iter)
    # On other iterations the mask should stay the same. Let's do 3 more
    # iterations.
    for i in range(3):
      mask_before, c_iter = sess.run([mask, global_step])
      self.assertEqual(i, c_iter)
      is_snip_iter, is_snipped = sess.run([train_op, sparse_optim.is_snipped])
      self.assertTrue(is_snipped)
      self.assertFalse(is_snip_iter)
      mask_after, = sess.run([mask])
      self.assertAllEqual(mask_after, mask_before)


class SparseDNWOptimizerTest(tf.test.TestCase, parameterized.TestCase):

  def _setup_graph(self,
                   default_sparsity,
                   mask_init_method,
                   custom_sparsity_map,
                   n_inp=3,
                   n_out=5):
    """Sets up a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(1e-3)
    sparse_optim = sparse_optimizers.SparseDNWOptimizer(
        optim,
        default_sparsity,
        mask_init_method,
        custom_sparsity_map=custom_sparsity_map)
    inp_values = np.arange(1, n_inp + 1)
    scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5
    # The gradient is the outer product of the input and the output gradients.
    # Since the loss is the sample sum, the output gradient is equal to the
    # scale vector.
    expected_grads = np.outer(inp_values, scale_vector_values)
    x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)
    y = y * scale_vector
    loss = tf.reduce_sum(y)
    global_step = tf.train.get_or_create_global_step()
    grads_and_vars = sparse_optim.compute_gradients(loss)
    train_op = sparse_optim.apply_gradients(
        grads_and_vars, global_step=global_step)
    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    mask = pruning.get_masks()[0]
    weights = pruning.get_weights()[0]
    return (sess, train_op, (expected_grads, grads_and_vars), mask, weights)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testDNWSparsity(self, n_inp, n_out, default_sparsity):
    """Checks that the mask has the target sparsity after a training step."""
    sess, train_op, _, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    _ = sess.run([train_op])
    dnw_mask, = sess.run([mask])
    n_ones = np.sum(dnw_mask)
    n_zeros = dnw_mask.size - n_ones
    n_zeros_expected = sparse_utils.get_n_zeros(dnw_mask.size,
                                                default_sparsity)
    self.assertEqual(n_zeros, n_zeros_expected)
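The DNW invariant the surrounding tests check, in NumPy (editor's sketch; unlike SNIP's one-shot gradient-based scoring, this top-magnitude selection is recomputed at every training step):

import numpy as np

def dnw_mask(weights, sparsity):
  """Keep the top (1 - sparsity) fraction of weights by magnitude."""
  n_zeros = int(np.floor(sparsity * weights.size))
  mask = np.ones(weights.size)
  mask[np.argsort(np.abs(weights).ravel())[:n_zeros]] = 0.
  return mask.reshape(weights.shape)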
  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testWeightsUsed(self, n_inp, n_out, default_sparsity):
    """Checks that the largest-magnitude weights are the ones kept."""
    sess, train_op, _, mask, weights = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    # Calculate sensitivity scores.
    weights, = sess.run([weights])
    expected_scores = np.abs(weights)
    _ = sess.run([train_op])
    dnw_mask, = sess.run([mask])
    kept_connection_scores = expected_scores[dnw_mask == 1]
    min_score_kept = np.min(kept_connection_scores)
    dnw_mask_connection_scores = expected_scores[dnw_mask == 0]
    max_score_removed = np.max(dnw_mask_connection_scores)
    self.assertLessEqual(max_score_removed, min_score_kept)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testGradientIsDense(self, n_inp, n_out, default_sparsity):
    """Checks whether the calculated gradients are dense."""
    sess, _, grad_info, _, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    expected_grad, grads_and_vars = grad_info
    grad, = sess.run([grads_and_vars[0][0]])
    self.assertAllClose(expected_grad, grad)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testDNWUpdates(self, n_inp, n_out, default_sparsity):
    """Checks whether the mask is updated correctly."""
    sess, train_op, _, mask, weights = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    # On all iterations the mask should keep the largest-magnitude
    # connections.
    for _ in range(5):
      sess.run([train_op])
      mask_after, weights_after = sess.run([mask, weights])
      kept_connection_magnitudes = np.abs(weights_after[mask_after == 1])
      min_score_kept = np.min(kept_connection_magnitudes)
      removed_connection_magnitudes = np.abs(weights_after[mask_after == 0])
      max_score_removed = np.max(removed_connection_magnitudes)
      self.assertLessEqual(max_score_removed, min_score_kept)

  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))
  def testSparsityAfterDNWUpdates(self, n_inp, n_out, default_sparsity):
    """Checks whether the mask keeps the target sparsity across updates."""
    sess, train_op, _, mask, _ = self._setup_graph(
        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)
    # On all iterations the mask should keep the largest-magnitude
    # connections.
    for _ in range(5):
      sess.run([train_op])
      dnw_mask, = sess.run([mask])
      n_ones = np.sum(dnw_mask)
      n_zeros = dnw_mask.size - n_ones
      n_zeros_expected = sparse_utils.get_n_zeros(dnw_mask.size,
                                                  default_sparsity)
      self.assertEqual(n_zeros, n_zeros_expected)


if __name__ == '__main__':
  tf.test.main()


================================================
FILE: rigl/sparse_utils.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module has helper functions for the interpolation experiments."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import numpy as np
from rigl import str_sparsities
import tensorflow.compat.v1 as tf

from google_research.micronet_challenge import counting

DEFAULT_ERK_SCALE = 1.0


def mask_extract_name_fn(mask_name):
  return re.findall('(.+)/mask:0', mask_name)[0]


def get_n_zeros(size, sparsity):
  return int(np.floor(sparsity * size))
def calculate_sparsity(masks):
  dense_params = tf.constant(0.)
  sparse_params = tf.constant(0.)
  for mask in masks:
    dense_params += tf.cast(tf.size(mask), dtype=dense_params.dtype)
    sparse_params += tf.cast(tf.reduce_sum(mask), dtype=sparse_params.dtype)
  return 1. - sparse_params / dense_params


def get_mask_random_numpy(mask_shape, sparsity, random_state=None):
  """Creates a random sparse mask with deterministic sparsity.

  Args:
    mask_shape: list, used to obtain shape of the random mask.
    sparsity: float, between 0 and 1.
    random_state: np.random.RandomState, if given the shuffle call is made
      using the RandomState.

  Returns:
    numpy.ndarray
  """
  flat_ones = np.ones(mask_shape).flatten()
  n_zeros = get_n_zeros(flat_ones.size, sparsity)
  flat_ones[:n_zeros] = 0
  if random_state:
    random_state.shuffle(flat_ones)
  else:
    np.random.shuffle(flat_ones)
  new_mask = flat_ones.reshape(mask_shape)
  return new_mask


def get_mask_random(mask, sparsity, dtype, random_state=None):
  """Creates a random sparse mask with deterministic sparsity.

  Args:
    mask: tf.Tensor, used to obtain shape of the random mask.
    sparsity: float, between 0 and 1.
    dtype: tf.dtype, type of the return value.
    random_state: np.random.RandomState, if given the shuffle call is made
      using the RandomState.

  Returns:
    tf.Tensor
  """
  new_mask_numpy = get_mask_random_numpy(
      mask.shape.as_list(), sparsity, random_state=random_state)
  new_mask = tf.constant(new_mask_numpy, dtype=dtype)
  return new_mask


def get_sparsities_erdos_renyi(all_masks,
                               default_sparsity,
                               custom_sparsity_map,
                               include_kernel,
                               extract_name_fn=mask_extract_name_fn,
                               erk_power_scale=DEFAULT_ERK_SCALE):
  """Given the method, returns the sparsity of individual layers as a dict.

  It ensures that the non-custom layers have the same total parameter count
  as they would with uniform sparsities. In other words, for the layers which
  are not in the custom_sparsity_map the following equation should be
  satisfied.

  # eps * (p_1 * N_1 + p_2 * N_2) = (1 - default_sparsity) * (N_1 + N_2)

  Args:
    all_masks: list, of all mask Variables.
    default_sparsity: float, between 0 and 1.
    custom_sparsity_map: dict, key/value pairs where the mask whose name is
      '{key}/mask:0' is set to the corresponding sparsity value.
    include_kernel: bool, if True kernel dimensions are included in the
      scaling.
    extract_name_fn: function, extracts the variable name.
    erk_power_scale: float, if given used to take power of the ratio. Use
      scale < 1 to make the erdos_renyi softer.

  Returns:
    sparsities, dict where keys() are equal to all_masks and individual masks
    are mapped to their sparsities.
  """
  # We have to enforce custom sparsities and then find the correct scaling
  # factor.
  is_eps_valid = False
  #
  # The following loop will terminate worst case when all masks are in the
  # custom_sparsity_map. This should probably never happen though, since once
  # we have a single variable or more with the same constant, we have a valid
  # epsilon. Note that for each iteration we add at least one variable to the
  # custom_sparsity_map and therefore this while loop should terminate.
  dense_layers = set()
  while not is_eps_valid:
    # We will start with all layers and try to find the right epsilon. However
    # if any probability exceeds 1, we will make that layer dense and repeat
    # the process (finding epsilon) with the non-dense layers.
    # We want the total number of connections to be the same. Let's say we
    # have four layers with N_1, ..., N_4 parameters each. Let's say that
    # after some iterations the probability of some layers (3, 4) exceeded 1
    # and we therefore added them to the dense_layers set.
    # Those layers will not scale with erdos_renyi, however we need to count
    # them so that the target parameter count is achieved. See below.
    # eps * (p_1 * N_1 + p_2 * N_2) + (N_3 + N_4) =
    #    (1 - default_sparsity) * (N_1 + N_2 + N_3 + N_4)
    # eps * (p_1 * N_1 + p_2 * N_2) =
    #    (1 - default_sparsity) * (N_1 + N_2) - default_sparsity * (N_3 + N_4)
    # eps = rhs / (\sum_i p_i * N_i) = rhs / divisor.
    divisor = 0
    rhs = 0
    raw_probabilities = {}
    for mask in all_masks:
      var_name = extract_name_fn(mask.name)
      shape_list = mask.shape.as_list()
      n_param = np.prod(shape_list)
      n_zeros = get_n_zeros(n_param, default_sparsity)
      if var_name in dense_layers:
        # See `- default_sparsity * (N_3 + N_4)` part of the equation above.
        rhs -= n_zeros
      elif var_name in custom_sparsity_map:
        # We ignore custom_sparsities in erdos-renyi calculations.
        pass
      else:
        # Corresponds to `(1 - default_sparsity) * (N_1 + N_2)` part of the
        # equation above.
        n_ones = n_param - n_zeros
        rhs += n_ones
        # Erdos-Renyi probability: epsilon * (n_in + n_out) / (n_in * n_out).
        if include_kernel:
          raw_probabilities[mask.name] = (
              np.sum(shape_list) / np.prod(shape_list))**erk_power_scale
        else:
          n_in, n_out = shape_list[-2:]
          raw_probabilities[mask.name] = (n_in + n_out) / (n_in * n_out)
        # Note that raw_probabilities[mask] * n_param gives the individual
        # elements of the divisor.
        divisor += raw_probabilities[mask.name] * n_param
    # By multiplying individual probabilities with epsilon, we should get the
    # number of parameters per layer correctly.
    eps = rhs / divisor
    # If eps * raw_probabilities[mask.name] > 1, we set the sparsity of that
    # mask to 0., so it becomes part of the dense_layers set.
    max_prob = np.max(list(raw_probabilities.values()))
    max_prob_one = max_prob * eps
    if max_prob_one > 1:
      is_eps_valid = False
      for mask_name, mask_raw_prob in raw_probabilities.items():
        if mask_raw_prob == max_prob:
          var_name = extract_name_fn(mask_name)
          tf.logging.info('Sparsity of var: %s had to be set to 0.', var_name)
          dense_layers.add(var_name)
    else:
      is_eps_valid = True

  sparsities = {}
  # With the valid epsilon, we can set the sparsities of the remaining layers.
  for mask in all_masks:
    var_name = extract_name_fn(mask.name)
    shape_list = mask.shape.as_list()
    n_param = np.prod(shape_list)
    if var_name in custom_sparsity_map:
      sparsities[mask.name] = custom_sparsity_map[var_name]
      tf.logging.info('layer: %s has custom sparsity: %f', var_name,
                      sparsities[mask.name])
    elif var_name in dense_layers:
      sparsities[mask.name] = 0.
    else:
      probability_one = eps * raw_probabilities[mask.name]
      sparsities[mask.name] = 1. - probability_one
    tf.logging.info('layer: %s, shape: %s, sparsity: %f', var_name, mask.shape,
                    sparsities[mask.name])
  return sparsities
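A small worked instance of the equations above (editor's example, not part of the module): two layers, shapes 10x10 and 10x100, at default_sparsity 0.5.

import numpy as np

shapes = [(10, 10), (10, 100)]
default_sparsity = 0.5
# Erdos-Renyi probabilities (n_in + n_out) / (n_in * n_out):
probs = {s: (s[0] + s[1]) / (s[0] * s[1]) for s in shapes}      # 0.2 and 0.11
rhs = (1 - default_sparsity) * sum(np.prod(s) for s in shapes)  # 550
divisor = sum(probs[s] * np.prod(s) for s in shapes)            # 130
eps = rhs / divisor                                             # ~4.23
sparsities = {s: 1. - eps * probs[s] for s in shapes}
# The wide 10x100 layer ends up sparser (~0.53) than the square 10x10
# layer (~0.15), while the total kept-parameter budget still matches the
# uniform case (550 weights).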
""" sparsities = {} for mask in all_masks: var_name = extract_name_fn(mask.name) if var_name in custom_sparsity_map: sparsities[mask.name] = custom_sparsity_map[var_name] else: sparsities[mask.name] = default_sparsity return sparsities def get_sparsities_str(all_masks, default_sparsity): """Given the method, returns the sparsity of individual layers as a dict. Args: all_masks: list, of all mask Variables. default_sparsity: float, between 0 and 1. Returns: sparsities, dict of where keys() are equal to all_masks and individiual masks are mapped to the their sparsities. """ str_sparsities_parsed = str_sparsities.read_all() if default_sparsity in str_sparsities_parsed: sprsts = str_sparsities_parsed[default_sparsity] sparsities = {mask.name: sprsts[mask.name] for mask in all_masks} else: raise ValueError('sparsity: %f is not defined' % default_sparsity) return sparsities def get_sparsities(all_masks, method, default_sparsity, custom_sparsity_map, extract_name_fn=mask_extract_name_fn, erk_power_scale=DEFAULT_ERK_SCALE): """Given the method, returns the sparsity of individual layers as a dict. Args: all_masks: list, of all mask Variables. method: str, 'random' or 'erdos_renyi'. default_sparsity: float, between 0 and 1. custom_sparsity_map: dict, key/value pairs where the mask correspond whose name is '{key}/mask:0' is set to the corresponding sparsity value. extract_name_fn: function, extracts the variable name. erk_power_scale: float, passed to the erdos_renyi function. Returns: sparsities, dict of where keys() are equal to all_masks and individiual masks are mapped to the their sparsities. Raises: ValueError: when a key from custom_sparsity not found in all_masks. ValueError: when an invalid initialization option is given. """ # (1) Ensure all keys are valid and processed. keys_found = set() for mask in all_masks: var_name = extract_name_fn(mask.name) if var_name in custom_sparsity_map: keys_found.add(var_name) keys_given = set(custom_sparsity_map.keys()) if keys_found != keys_given: diff = keys_given - keys_found raise ValueError('No masks are found for the following names: %s' % str(diff)) if method in ('erdos_renyi', 'erdos_renyi_kernel'): include_kernel = method == 'erdos_renyi_kernel' sparsities = get_sparsities_erdos_renyi( all_masks, default_sparsity, custom_sparsity_map, include_kernel=include_kernel, extract_name_fn=extract_name_fn, erk_power_scale=erk_power_scale) elif method == 'random': sparsities = get_sparsities_uniform( all_masks, default_sparsity, custom_sparsity_map, extract_name_fn=extract_name_fn) elif method == 'str': sparsities = get_sparsities_str(all_masks, default_sparsity) else: raise ValueError('Method: %s is not valid mask initialization method' % method) return sparsities def get_mask_init_fn(all_masks, method, default_sparsity, custom_sparsity_map, mask_fn=get_mask_random, erk_power_scale=DEFAULT_ERK_SCALE, extract_name_fn=mask_extract_name_fn): """Returns a function for initializing masks randomly. Args: all_masks: list, of all masks to be updated. method: str, method to initialize the masks, passed to the sparse_utils.get_mask() function. default_sparsity: float, if 0 mask left intact, if greater than one, a fraction of the ones in each mask is flipped to 0. custom_sparsity_map: dict, sparsity of individual variables can be overridden here. Key should point to the correct variable name, and value should be in [0, 1]. mask_fn: function, to initialize masks with given sparsity. erk_power_scale: float, passed to get_sparsities. 
    extract_name_fn: function, used to grab names from the variable.

  Returns:
    An op to run after the init op; see `init_fn` of `tf.train.Scaffold`.

  Raises:
    ValueError: when there is no mask corresponding to a key in the
      custom_sparsity_map.
  """
  sparsities = get_sparsities(
      all_masks,
      method,
      default_sparsity,
      custom_sparsity_map,
      erk_power_scale=erk_power_scale,
      extract_name_fn=extract_name_fn)
  tf.logging.info('Per layer sparsities are like the following: %s',
                  str(sparsities))
  assign_ops = []
  for mask in all_masks:
    new_mask = mask_fn(mask, sparsities[mask.name], mask.dtype)
    assign_op = tf.assign(mask, new_mask)
    assign_ops.append(assign_op)
  return tf.group(assign_ops)
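A typical way to wire the returned op into a TF1 job (editor's sketch; it assumes masks created by tf.contrib.model_pruning's masked layers, as in the tests elsewhere in this repository, and `scaffold_init_fn` is an illustrative name):

import tensorflow.compat.v1 as tf
from rigl import sparse_utils
from tensorflow.contrib.model_pruning.python import pruning

mask_init_op = sparse_utils.get_mask_init_fn(
    pruning.get_masks(), method='erdos_renyi', default_sparsity=0.9,
    custom_sparsity_map={})

def scaffold_init_fn(scaffold, session):
  del scaffold  # unused
  session.run(mask_init_op)

scaffold = tf.train.Scaffold(init_fn=scaffold_init_fn)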
## Calculating flops and parameters using a list of Keras layers.


def _get_kernel(layer):
  """Given the Keras layer returns the weights."""
  if isinstance(layer, tf.keras.layers.DepthwiseConv2D):
    return layer.depthwise_kernel
  else:
    return layer.kernel


def get_stats(masked_layers,
              default_sparsity=0.8,
              method='erdos_renyi',
              custom_sparsities=None,
              is_debug=False,
              width=1.,
              first_layer_name='conv1',
              last_layer_name='conv_preds',
              param_size=32,
              erk_power_scale=DEFAULT_ERK_SCALE):
  """Given the Keras layers, returns the size and FLOPs of the model.

  Args:
    masked_layers: list, of tf.keras.Layer.
    default_sparsity: float, if 0 the masks are left intact; if greater than
      zero, a fraction of the ones in each mask is flipped to 0.
    method: str, passed to the `.get_sparsities()` functions.
    custom_sparsities: dict or None, sparsity of individual variables can be
      overridden here. Key should point to the correct variable name, and
      value should be in [0, 1].
    is_debug: bool, if True prints individual stats for given layers.
    width: float, multiplier for the individual layer widths.
    first_layer_name: str, to scale the width correctly.
    last_layer_name: str, to scale the width correctly.
    param_size: int, number of bits to represent a single parameter.
    erk_power_scale: float, passed to the get_sparsities function.

  Returns:
    total_flops, sum of multiply and add operations.
    total_param_bits, total bits to represent the model during the inference.
    real_sparsity, calculated independently omitting bias parameters.
  """
  if custom_sparsities is None:
    custom_sparsities = {}
  sparsities = get_sparsities([_get_kernel(l) for l in masked_layers], method,
                              default_sparsity, custom_sparsities, lambda a: a,
                              erk_power_scale=erk_power_scale)
  total_flops = 0
  total_param_bits = 0
  total_params = 0.
  n_zeros = 0.
  for layer in masked_layers:
    kernel = _get_kernel(layer)
    k_shape = kernel.shape.as_list()
    d_in, d_out = 2, 3
    # If fully connected, change the indices.
    if len(k_shape) == 2:
      d_in, d_out = 0, 1
    # The `k_shape[d_in] != 1` check skips depthwise kernels when scaling.
    if not kernel.name.startswith(first_layer_name) and k_shape[d_in] != 1:
      k_shape[d_in] = int(k_shape[d_in] * width)
    if not kernel.name.startswith(last_layer_name) and k_shape[d_out] != 1:
      k_shape[d_out] = int(k_shape[d_out] * width)
    if is_debug:
      print(kernel.name, layer.input_shape, k_shape, sparsities[kernel.name])
    if isinstance(layer, tf.keras.layers.Conv2D):
      layer_op = counting.Conv2D(layer.input_shape[1], k_shape, layer.strides,
                                 'same', True, 'relu')
    elif isinstance(layer, tf.keras.layers.DepthwiseConv2D):
      layer_op = counting.DepthWiseConv2D(layer.input_shape[1], k_shape,
                                          layer.strides, 'same', True, 'relu')
    elif isinstance(layer, tf.keras.layers.Dense):
      layer_op = counting.FullyConnected(k_shape, True, 'relu')
    else:
      raise ValueError('Unsupported layer type: %s' % type(layer))
    param_count, n_mults, n_adds = counting.count_ops(layer_op,
                                                      sparsities[kernel.name],
                                                      param_size)
    total_param_bits += param_count
    total_flops += n_mults + n_adds
    n_param = np.prod(k_shape)
    total_params += n_param
    n_zeros += int(n_param * sparsities[kernel.name])
  return total_flops, total_param_bits, n_zeros / total_params


================================================
FILE: rigl/sparse_utils_test.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the sparse_utils module.
""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized import numpy as np from rigl import sparse_utils import tensorflow.compat.v1 as tf class GetMaskRandomTest(tf.test.TestCase, parameterized.TestCase): def _setup_session(self): """Resets the graph and returns a fresh session.""" tf.reset_default_graph() sess = tf.Session() return sess @parameterized.parameters(((30, 40), 0.5), ((1, 2, 1, 4), 0.8), ((3,), 0.1)) def testMaskConnectionDeterminism(self, shape, sparsity): sess = self._setup_session() mask = tf.ones(shape) mask1 = sparse_utils.get_mask_random(mask, sparsity, tf.int32) mask2 = sparse_utils.get_mask_random(mask, sparsity, tf.int32) mask1_array, = sess.run([mask1]) mask2_array, = sess.run([mask2]) self.assertEqual(np.sum(mask1_array), np.sum(mask2_array)) @parameterized.parameters(((30, 4), 0.5, 60), ((1, 2, 1, 4), 0.8, 2), ((30,), 0.1, 27)) def testMaskFraction(self, shape, sparsity, expected_ones): sess = self._setup_session() mask = tf.ones(shape) mask1 = sparse_utils.get_mask_random(mask, sparsity, tf.int32) mask1_array, = sess.run([mask1]) self.assertEqual(np.sum(mask1_array), expected_ones) @parameterized.parameters(tf.int32, tf.float32, tf.int64, tf.float64) def testMaskDtype(self, dtype): _ = self._setup_session() mask = tf.ones((3, 2)) mask1 = sparse_utils.get_mask_random(mask, 0.5, dtype) self.assertEqual(mask1.dtype, dtype) class GetSparsitiesTest(tf.test.TestCase, parameterized.TestCase): def _setup_session(self): """Resets the graph and returns a fresh session.""" tf.reset_default_graph() sess = tf.Session() return sess @parameterized.parameters(0., 0.4, 0.9) def testSparsityDictRandom(self, default_sparsity): _ = self._setup_session() all_masks = [tf.get_variable(shape=(2, 3), name='var1/mask'), tf.get_variable(shape=(2, 3), name='var2/mask'), tf.get_variable(shape=(1, 1, 3), name='var3/mask')] custom_sparsity = {'var1': 0.8} sparsities = sparse_utils.get_sparsities( all_masks, 'random', default_sparsity, custom_sparsity) self.assertEqual(sparsities[all_masks[0].name], 0.8) self.assertEqual(sparsities[all_masks[1].name], default_sparsity) self.assertEqual(sparsities[all_masks[2].name], default_sparsity) @parameterized.parameters(0.1, 0.4, 0.9) def testSparsityDictErdosRenyiCustom(self, default_sparsity): _ = self._setup_session() all_masks = [tf.get_variable(shape=(2, 4), name='var1/mask'), tf.get_variable(shape=(2, 3), name='var2/mask'), tf.get_variable(shape=(1, 1, 3), name='var3/mask')] custom_sparsity = {'var3': 0.8} sparsities = sparse_utils.get_sparsities( all_masks, 'erdos_renyi', default_sparsity, custom_sparsity) self.assertEqual(sparsities[all_masks[2].name], 0.8) @parameterized.parameters(0.1, 0.4, 0.9) def testSparsityDictErdosRenyiError(self, default_sparsity): _ = self._setup_session() all_masks = [tf.get_variable(shape=(2, 4), name='var1/mask'), tf.get_variable(shape=(2, 3), name='var2/mask'), tf.get_variable(shape=(1, 1, 3), name='var3/mask')] custom_sparsity = {'var3': 0.8} sparsities = sparse_utils.get_sparsities( all_masks, 'erdos_renyi', default_sparsity, custom_sparsity) self.assertEqual(sparsities[all_masks[2].name], 0.8) @parameterized.parameters(((2, 3), (2, 3), 0.5), ((1, 1, 2, 3), (1, 1, 2, 3), 0.3), ((8, 6), (4, 3), 0.7), ((80, 4), (20, 20), 0.8), ((2, 6), (2, 3), 0.8)) def testSparsityDictErdosRenyiSparsitiesScale( self, shape1, shape2, default_sparsity): _ = self._setup_session() all_masks = [tf.get_variable(shape=shape1, name='var1/mask'), 
                 tf.get_variable(shape=shape2, name='var2/mask')]
    custom_sparsity = {}
    sparsities = sparse_utils.get_sparsities(
        all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)
    sparsity1 = sparsities[all_masks[0].name]
    size1 = np.prod(shape1)
    sparsity2 = sparsities[all_masks[1].name]
    size2 = np.prod(shape2)
    # Expected total number of zeros under uniform sparsities.
    expected_zeros_uniform = (
        sparse_utils.get_n_zeros(size1, default_sparsity) +
        sparse_utils.get_n_zeros(size2, default_sparsity))
    # Total number of zeros under the ErdosRenyi sparsities; the two should
    # be similar.
    expected_zeros_current = (
        sparse_utils.get_n_zeros(size1, sparsity1) +
        sparse_utils.get_n_zeros(size2, sparsity2))
    # Due to rounding we can have some difference. This is expected but should
    # be less than the number of rounding operations we make.
    diff = abs(expected_zeros_uniform - expected_zeros_current)
    tolerance = 2
    self.assertLessEqual(diff, tolerance)
    # Ensure that ErdosRenyi proportions are preserved.
    factor1 = (shape1[-1] + shape1[-2]) / float(shape1[-1] * shape1[-2])
    factor2 = (shape2[-1] + shape2[-2]) / float(shape2[-1] * shape2[-2])
    self.assertAlmostEqual((1 - sparsity1) / factor1,
                           (1 - sparsity2) / factor2)


if __name__ == '__main__':
  tf.test.main()


================================================
FILE: rigl/str_sparsities.py
================================================
# coding=utf-8
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Reads ResNet-50 sparsity distributions found by STR.
[STR]: https://arxiv.org/abs/2002.03231 """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import re REPORTED_SPARSITIES = """ Overall - Overall 25502912 4089284608 79.55 81.27 87.70 90.23 90.55 94.80 95.03 95.15 96.11 96.53 97.78 98.05 98.22 98.79 98.98 99.10 Layer 1 - conv1 9408 118013952 51.46 51.40 63.02 59.80 59.83 64.87 67.36 66.96 72.11 69.46 73.29 73.47 72.05 75.12 76.12 77.75 Layer 2 - layer1.0.conv1 4096 12845056 69.36 73.24 87.57 83.28 85.18 89.60 91.41 91.11 92.38 91.75 94.46 94.51 94.60 95.95 96.53 96.51 Layer 3 - layer1.0.conv2 36864 115605504 77.85 76.26 90.87 89.48 87.31 94.79 94.27 95.04 95.69 96.07 97.36 97.77 98.35 98.51 98.59 98.84 Layer 4 - layer1.0.conv3 16384 51380224 74.81 74.65 86.52 85.80 85.25 91.85 92.78 93.67 94.13 94.69 96.61 97.03 97.37 98.04 98.21 98.47 Layer 5 - layer1.0.downsample.0 16384 51380224 70.95 72.96 83.53 83.34 82.56 89.13 90.62 90.17 91.83 92.69 95.48 94.89 95.68 96.98 97.56 97.72 Layer 6 - layer1.1.conv1 16384 51380224 80.27 79.58 89.82 89.89 88.51 94.56 96.64 95.78 95.81 96.81 98.79 98.90 98.98 99.13 99.62 99.47 Layer 7 - layer1.1.conv2 36864 115605504 81.36 80.95 91.75 90.60 89.61 94.70 95.78 96.18 96.42 97.26 98.65 99.07 99.40 99.11 99.31 99.56 Layer 8 - layer1.1.conv3 16384 51380224 84.45 80.11 91.22 91.70 90.21 95.17 97.05 95.81 96.34 97.23 98.68 98.76 98.90 99.16 99.57 99.46 Layer 9 - layer1.2.conv1 16384 51380224 78.23 79.79 90.12 88.07 89.36 94.62 95.94 94.74 96.23 96.75 97.96 98.41 98.72 99.38 99.35 99.46 Layer 10 - layer1.2.conv2 36864 115605504 76.01 81.53 91.06 87.03 88.27 93.90 95.63 94.26 96.24 96.11 97.54 98.27 98.44 99.32 99.19 99.39 Layer 11 - layer1.2.conv3 16384 51380224 84.47 83.28 94.95 90.99 92.64 95.76 96.95 96.01 96.87 97.31 98.38 98.60 98.72 99.38 99.27 99.51 Layer 12 - layer2.0.conv1 32768 102760448 73.74 73.96 86.78 85.95 85.90 92.32 94.79 93.86 94.62 95.64 97.19 98.22 98.52 98.48 98.84 98.92 Layer 13 - layer2.0.conv2 147456 115605504 82.56 85.70 91.31 93.91 94.03 97.54 97.43 97.65 98.38 98.62 99.24 99.23 99.40 99.61 99.67 99.63 Layer 14 - layer2.0.conv3 65536 51380224 84.70 83.55 93.04 93.13 92.13 96.61 97.37 97.21 97.59 98.14 98.80 98.95 99.18 99.29 99.47 99.43 Layer 15 - layer2.0.downsample.0 131072 102760448 85.10 87.66 92.78 94.96 95.13 98.07 97.97 98.15 98.70 98.88 99.37 99.35 99.40 99.69 99.68 99.71 Layer 16 - layer2.1.conv1 65536 51380224 85.42 85.79 94.04 95.31 94.94 97.92 98.53 98.21 98.84 99.06 99.46 99.53 99.72 99.78 99.81 99.80 Layer 17 - layer2.1.conv2 147456 115605504 76.95 82.75 87.63 91.50 91.76 95.59 97.22 96.07 97.32 97.80 98.24 98.24 98.60 99.24 99.66 99.33 Layer 18 - layer2.1.conv3 65536 51380224 84.76 84.71 93.10 93.66 93.23 97.00 98.18 97.35 98.06 98.41 98.96 99.21 99.32 99.55 99.58 99.59 Layer 19 - layer2.2.conv1 65536 51380224 84.30 85.34 92.70 94.61 94.76 97.72 97.91 98.21 98.54 98.98 99.24 99.35 99.50 99.62 99.63 99.77 Layer 20 - layer2.2.conv2 147456 115605504 84.28 85.43 92.99 94.86 94.90 97.52 97.21 98.11 98.19 99.04 99.28 99.37 99.46 99.63 99.59 99.72 Layer 21 - layer2.2.conv3 65536 51380224 82.19 84.21 91.12 93.38 93.53 96.89 97.14 97.59 97.77 98.66 98.96 99.15 99.25 99.49 99.51 99.57 Layer 22 - layer2.3.conv1 65536 51380224 83.37 84.41 90.46 93.26 93.50 96.71 97.89 96.99 98.14 98.36 99.10 99.23 99.33 99.53 99.75 99.60 Layer 23 - layer2.3.conv2 147456 115605504 82.83 84.03 91.44 93.21 93.25 96.83 98.02 96.96 98.45 98.30 98.97 99.06 99.26 99.31 99.81 99.68 Layer 24 - layer2.3.conv3 65536 51380224 82.93 85.65 
Layer 25 - layer3.0.conv1 131072 102760448 76.63 77.98 85.99 88.85 88.60 94.26 95.07 94.97 96.21 96.59 97.75 98.04 98.30 98.72 99.11 99.06
Layer 26 - layer3.0.conv2 589824 115605504 87.35 88.68 94.39 96.14 96.19 98.51 98.77 98.72 99.11 99.23 99.53 99.59 99.64 99.73 99.80 99.81
Layer 27 - layer3.0.conv3 262144 51380224 81.22 83.22 90.58 93.19 93.05 96.82 97.38 97.32 97.98 98.28 98.88 99.03 99.16 99.39 99.55 99.53
Layer 28 - layer3.0.downsample.0 524288 102760448 89.75 90.99 96.05 97.20 97.16 98.96 99.21 99.20 99.50 99.58 99.78 99.82 99.86 99.91 99.94 99.93
Layer 29 - layer3.1.conv1 262144 51380224 85.88 87.35 93.43 95.36 96.12 98.64 98.77 98.87 99.22 99.33 99.64 99.67 99.72 99.82 99.88 99.84
Layer 30 - layer3.1.conv2 589824 115605504 85.06 86.24 92.74 95.06 95.30 98.09 98.28 98.36 98.75 99.08 99.46 99.48 99.54 99.69 99.76 99.76
Layer 31 - layer3.1.conv3 262144 51380224 84.34 86.79 92.15 94.84 94.90 97.75 98.15 98.11 98.56 98.94 99.30 99.36 99.45 99.65 99.79 99.70
Layer 32 - layer3.2.conv1 262144 51380224 87.51 89.15 94.15 96.77 96.46 98.81 98.83 98.96 99.19 99.44 99.67 99.71 99.74 99.82 99.85 99.89
Layer 33 - layer3.2.conv2 589824 115605504 87.15 88.67 94.09 95.59 96.14 98.86 98.69 98.91 99.21 99.20 99.64 99.72 99.76 99.85 99.84 99.90
Layer 34 - layer3.2.conv3 262144 51380224 84.86 86.90 92.40 94.99 94.99 98.19 98.19 98.42 98.76 98.97 99.42 99.56 99.62 99.76 99.75 99.88
Layer 35 - layer3.3.conv1 262144 51380224 86.62 89.46 94.06 96.08 95.88 98.70 98.71 98.77 99.01 99.27 99.58 99.66 99.69 99.83 99.87 99.87
Layer 36 - layer3.3.conv2 589824 115605504 86.52 87.97 93.56 96.10 96.11 98.70 98.82 98.89 99.19 99.31 99.68 99.73 99.77 99.88 99.87 99.93
Layer 37 - layer3.3.conv3 262144 51380224 84.19 86.81 92.32 94.94 94.91 98.20 98.37 98.43 98.82 99.00 99.51 99.57 99.64 99.81 99.81 99.87
Layer 38 - layer3.4.conv1 262144 51380224 85.85 88.40 93.55 95.49 95.86 98.35 98.44 98.55 98.79 98.96 99.54 99.59 99.60 99.82 99.86 99.87
Layer 39 - layer3.4.conv2 589824 115605504 85.96 87.38 93.27 95.66 95.63 98.41 98.58 98.56 99.19 99.26 99.64 99.69 99.67 99.87 99.90 99.92
Layer 40 - layer3.4.conv3 262144 51380224 83.45 85.76 91.75 94.49 94.35 97.67 98.09 97.99 98.65 98.94 99.49 99.52 99.48 99.77 99.86 99.85
Layer 41 - layer3.5.conv1 262144 51380224 83.33 85.77 91.79 95.09 94.24 97.46 97.89 97.92 98.71 98.90 99.35 99.52 99.58 99.76 99.79 99.83
Layer 42 - layer3.5.conv2 589824 115605504 84.98 86.67 92.48 94.92 95.13 97.88 98.14 98.32 98.91 99.00 99.44 99.58 99.69 99.80 99.83 99.87
Layer 43 - layer3.5.conv3 262144 51380224 79.78 82.23 89.39 93.14 92.76 96.59 97.04 97.30 98.10 98.41 99.03 99.25 99.44 99.61 99.71 99.75
Layer 44 - layer4.0.conv1 524288 102760448 77.83 79.61 87.11 90.32 90.64 95.39 95.84 95.92 97.17 97.35 98.36 98.60 98.83 99.20 99.37 99.42
Layer 45 - layer4.0.conv2 2359296 115605504 86.18 88.00 93.53 95.66 95.78 98.31 98.47 98.55 99.08 99.16 99.54 99.63 99.69 99.81 99.85 99.86
Layer 46 - layer4.0.conv3 1048576 51380224 78.43 80.48 87.85 91.14 91.27 96.00 96.40 96.47 97.53 97.92 98.81 99.00 99.15 99.45 99.57 99.61
Layer 47 - layer4.0.downsample.0 2097152 102760448 88.49 89.98 95.03 96.79 96.90 98.91 99.06 99.11 99.45 99.51 99.77 99.82 99.85 99.92 99.94 99.94
Layer 48 - layer4.1.conv1 1048576 51380224 82.07 84.02 90.34 93.69 93.72 97.15 97.56 97.76 98.45 98.75 99.27 99.36 99.54 99.67 99.76 99.80
Layer 49 - layer4.1.conv2 2359296 115605504 83.42 85.23 91.16 93.98 93.93 97.26 97.58 97.71 98.36 98.67 99.25 99.34 99.50 99.68 99.76 99.80
Layer 50 - layer4.1.conv3 1048576 51380224 78.08 79.96 86.66 90.48 90.22 95.22 95.76 95.89 96.88 97.65 98.70 98.85 99.13 99.45 99.58 99.66
Layer 51 - layer4.2.conv1 1048576 51380224 76.34 77.93 84.98 87.57 88.47 93.90 93.87 94.16 95.55 95.91 97.66 97.97 98.15 98.88 99.08 99.22
Layer 52 - layer4.2.conv2 2359296 115605504 73.57 74.97 82.32 84.37 86.01 91.92 91.66 92.22 94.02 94.16 96.65 97.13 97.29 98.44 98.74 99.00
Layer 53 - layer4.2.conv3 1048576 51380224 68.78 70.38 78.11 80.29 81.73 89.64 89.43 89.65 91.40 92.65 96.02 96.72 96.93 98.47 98.83 99.15
Layer 54 - fc 2048000 2048000 50.65 52.46 60.48 64.50 65.12 75.20 75.73 75.80 78.57 80.69 85.96 87.26 88.03 91.11 92.15 92.87"""


def _name_map_str(k):
  """Maps a reported (PyTorch-style) layer name to its mask variable name."""
  if k == 'conv1':
    new_key = 'initial_conv'
  elif k == 'fc':
    new_key = 'final_dense'
  else:
    if 'downsample' in k:
      group_id = re.search(r'layer(\d)\.0\.downsample\.0', k).group(1)
      new_key = ('bottleneck_projection_block_group_projection_block_group%s'
                 % group_id)
    else:
      res = re.search(r'layer(\d)\.(\d)\.conv(\d)', k)
      group_id, block_id, layer_id = (int(res.group(1)), int(res.group(2)),
                                      int(res.group(3)))
      if block_id == 0:
        new_key = 'bottleneck_%d_block_group_projection_block_group%d' % (
            layer_id, group_id)
      else:
        new_key = 'bottleneck_%d_block_group%d_%d_1' % (layer_id, group_id,
                                                        block_id)
  return 'resnet_model/%s/mask:0' % new_key


def read_all():
  """Parses REPORTED_SPARSITIES into per-layer sparsity distributions.

  Each data row lists a layer name, its parameter count, its FLOP count, and
  16 sparsity percentages, one for each overall sparsity in the Overall row.

  Returns:
    Dict mapping each overall sparsity (a fraction in [0, 1]) to a dict from
    mask variable names to that layer's sparsity fraction.
  """
  str_sparsities_parsed = collections.defaultdict(dict)
  for l in REPORTED_SPARSITIES.strip().split('\n'):
    l = l.split('-')[1].strip().split(' ')
    if l[0] == 'Overall':
      overall_sparsities = list(map(float, l[3:]))
    else:
      for i, ls in enumerate(l[3:]):
        # Overall sparsities are reported as percentages, so divide by 100 to
        # get a fraction in [0, 1] to use as the key.
        s = overall_sparsities[i] / 100
        new_key = _name_map_str(l[0])
        # Per-layer sparsities are percentages too, so divide by 100.
        str_sparsities_parsed[s][new_key] = float(ls) / 100.
  return str_sparsities_parsed
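

# A minimal usage sketch (illustrative; not part of the original module). The
# dict keys are produced by dividing the reported percentages by 100, so index
# with the same expression rather than a hand-typed float literal:
#
#   sparsities = read_all()
#   layer_sparsities = sparsities[90.55 / 100]  # the 90.55%-overall run
#   layer_sparsities['resnet_model/initial_conv/mask:0']  # -> 0.5983 (conv1)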
================================================
FILE: run.sh
================================================
#!/bin/bash
# Copyright 2022 RigL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e
set -x

virtualenv -p python3 env
source env/bin/activate

pip install -r rigl/requirements.txt
python -m rigl.sparse_optimizers_test
python -m rigl.sparse_utils_test
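
# Optional smoke test (illustrative; not part of the original script): check
# that the STR sparsity table parses into 16 overall-sparsity entries.
python -c "from rigl import str_sparsities; assert len(str_sparsities.read_all()) == 16"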