[
  {
    "path": "CONTRIBUTING.md",
    "content": "# How to Contribute\n\nWe'd love to accept your patches and contributions to this project.\n\n- If you want to contribute to the library, please check the `Issues` tab and\nfeel free to take on any problem/issue you find interesting.\n- If your issue has not been reported yet, please create a new one. It is\nimportant to discuss the problem/request before implementing the solution.\n- Reach us at rigl.authors@gmail.com any time!\n\n## Code reviews\n\nAll submissions, including submissions by project members, require review. We\nuse GitHub pull requests for this purpose. Consult\n[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more\ninformation on using pull requests.\n\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Rigging the Lottery: Making All Tickets Winners\n<img src=\"https://github.com/google-research/rigl/blob/master/imgs/flops8.jpg\" alt=\"80% Sparse Resnet-50\" width=\"45%\" align=\"middle\">\n\n**Paper**: [https://arxiv.org/abs/1911.11134](https://arxiv.org/abs/1911.11134)\n\n**15min Presentation** [[pml4dc](https://pml4dc.github.io/iclr2020/program/pml4dc_7.html)] [[icml](https://icml.cc/virtual/2020/paper/5808)]\n\n**ML Reproducibility Challenge 2020** [report](https://openreview.net/forum?id=riCIeP6LzEE)\n\n## Colabs for Calculating FLOPs of Sparse Models\n[MobileNet-v1](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb)\n\n[ResNet-50](https://github.com/google-research/rigl/blob/master/rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb)\n\n## Best Sparse Models\nParameters are floats, so each parameter is represented with 4 bytes. The\nuniform sparsity distribution keeps the first layer dense and therefore yields\na slightly larger model size and parameter count. ERK applies to all layers,\nexcept for the 99% sparse model, in which we set the first layer to be dense,\nsince otherwise we observe much worse performance.\n\n### Extended Training Results\nThe performance of RigL increases significantly with extended training. In this\nsection we extend the training of sparse models by 5x. Note that sparse models\nrequire far fewer FLOPs per training iteration, so most of the extended\ntraining runs cost fewer FLOPs than the baseline dense training.\n\nObserving this improving performance, we wanted to understand where the performance of sparse networks saturates. The longest training we ran was 100x the length of the original\n100-epoch ImageNet training. This training costs 5.8x the FLOPs of the original dense training, and the resulting 99% sparse ResNet-50 achieves an impressive 68.15% test accuracy (vs. 61.86% with 5x training).\n\n| S. 
Distribution |  Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt         |\n|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|\n| - (DENSE)       | 0         | 3.2e18         | 8.2e9           | 102.122                             | 76.8      | -            |\n| ERK             | 0.8       | 2.09x          | 0.42x           | 23.683                              | 77.17     | [link](https://storage.googleapis.com/gresearch/rigl/s80erk5x.tar.gz) |\n| Uniform         | 0.8       | 1.14x          | 0.23x           | 23.685                              | 76.71     | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform5x.tar.gz) |\n| ERK             | 0.9       | 1.23x          | 0.24x           | 13.499                              | 76.42     | [link](https://storage.googleapis.com/gresearch/rigl/s90erk5x.tar.gz) |\n| Uniform         | 0.9       | 0.66x          | 0.13x           | 13.532                              | 75.73     | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform5x.tar.gz) |\n| ERK             | 0.95      | 0.63x          | 0.12x           | 8.399                               | 74.63     | [link](https://storage.googleapis.com/gresearch/rigl/s95erk5x.tar.gz) |\n| Uniform         | 0.95      | 0.42x          | 0.08x           | 8.433                               | 73.22     | [link](https://storage.googleapis.com/gresearch/rigl/s95uniform5x.tar.gz) |\n| ERK             | 0.965     | 0.45x          | 0.09x           | 6.904                               | 72.77     | [link](https://storage.googleapis.com/gresearch/rigl/s965erk5x.tar.gz) |\n| Uniform         | 0.965     | 0.34x          | 0.07x           | 6.904                               | 71.31     | [link](https://storage.googleapis.com/gresearch/rigl/s965uniform5x.tar.gz) |\n| ERK             | 0.99      | 0.29x          | 0.05x           | 4.354             
       | 61.86     | [link](https://storage.googleapis.com/gresearch/rigl/s99erk5x.tar.gz) |\n| ERK             | 0.99  | 0.58x          | 0.05x           | 4.354                               | 63.89 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk10x.tar.gz) |\n| ERK             | 0.99  | 2.32x          | 0.05x           | 4.354                               | 66.94 | [link](https://storage.googleapis.com/gresearch/rigl/s99erk40x.tar.gz) |\n| ERK             | **0.99**  | 5.8x          | 0.05x           | 4.354                               | **68.15** | [link](https://storage.googleapis.com/gresearch/rigl/s99erk100x.tar.gz) |\n\nWe also ran extended training runs with MobileNet-v1. Again training up to 100x\nlonger, we were not able to saturate the performance; training longer\nconsistently achieved better results.\n\n| S. Distribution |  Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt         |\n|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|\n| - (DENSE)       | 0         | 4.5e17         | 1.14e9           | 16.864                            | 72.1      | -            |\n| ERK             | 0.89       | 1.39x         | 0.21x           | 2.392                             | 69.31     | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk10x.tar.gz) |\n| ERK             | 0.89       | 2.79x          | 0.21x         | 2.392                              | 70.63     | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_erk50x.tar.gz) |\n| Uniform         | 0.89       | 1.25x          | 0.09x           | 2.392                              | 69.28     | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform10x.tar.gz) |\n| Uniform         | 0.89       | 6.25x          | 0.09x           | 2.392                              | 70.25     | 
[link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform50x.tar.gz) |\n| Uniform         | 0.89       | 12.5x          | 0.09x           | 2.392                              | 70.59     | [link](https://storage.googleapis.com/gresearch/rigl/mbv1_s90_uniform100x.tar.gz) |\n\n\n### 1x Training Results\n\n| S. Distribution |  Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt         |\n|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|\n| ERK             | 0.8       | 0.42x          | 0.42x           | 23.683                              | 75.12     | [link](https://storage.googleapis.com/gresearch/rigl/s80erk1x.tar.gz) |\n| Uniform         | 0.8       | 0.23x          | 0.23x           | 23.685                              | 74.60     | [link](https://storage.googleapis.com/gresearch/rigl/s80uniform1x.tar.gz) |\n| ERK             | 0.9       | 0.24x          | 0.24x           | 13.499                              | 73.07     | [link](https://storage.googleapis.com/gresearch/rigl/s90erk1x.tar.gz) |\n| Uniform         | 0.9       | 0.13x          | 0.13x           | 13.532                              | 72.02     | [link](https://storage.googleapis.com/gresearch/rigl/s90uniform1x.tar.gz) |\n\n### Results w/o label smoothing\n\n| S. 
Distribution |  Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc | Ckpt         |\n|-----------------|-----------|----------------|-----------------|-------------------------------------|-----------|--------------|\n| ERK             | 0.8       | 0.42x          | 0.42x           | 23.683                              | 75.02     | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_1x.tar.gz) |\n| ERK             | 0.8       | 2.09x          | 0.42x           | 23.683                              | 76.17     | [link](https://storage.googleapis.com/gresearch/rigl/S80erk_nolabelsmooth_5x.tar.gz) |\n| ERK             | 0.9       | 0.24x          | 0.24x           | 13.499                              | 73.4     | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_1x.tar.gz) |\n| ERK             | 0.9       | 1.23x          | 0.24x           | 13.499                              | 75.9     | [link](https://storage.googleapis.com/gresearch/rigl/S90erk_nolabelsmooth_5x.tar.gz) |\n| ERK             | 0.95      | 0.13x          | 0.12x           | 8.399                              | 70.39     | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_1x.tar.gz) |\n| ERK             | 0.95      | 0.63x          | 0.12x           | 8.399                              | 74.36    | [link](https://storage.googleapis.com/gresearch/rigl/S95erk_nolabelsmooth_5x.tar.gz) |\n\n### Evaluating checkpoints\nDownload the checkpoints and run the evaluation on ERK checkpoints with the\nfollowing:\n\n```bash\npython imagenet_train_eval.py --mode=eval_once --output_dir=path/to/ckpt/folder \\\n    --eval_once_ckpt_prefix=model.ckpt-3200000 --use_folder_stub=False \\\n    --training_method=rigl --mask_init_method=erdos_renyi_kernel \\\n    --first_layer_sparsity=-1\n```\n\nWhen running checkpoints with a uniform sparsity distribution, use `--mask_init_method=random` and `--first_layer_sparsity=0`. 
Set \n`--model_architecture=mobilenet_v1` when evaluating MobileNet checkpoints.\n\n## Sparse Training Algorithms\nIn this repository we implement the following dynamic sparsity strategies:\n\n1.  [SET](https://www.nature.com/articles/s41467-018-04316-3): Implements Sparse\n    Evolutionary Training (SET), which replaces low-magnitude connections\n    randomly with new ones.\n\n2.  [SNFS](https://arxiv.org/abs/1907.04840): Implements momentum-based training\n    *without* sparsity re-distribution.\n\n3.  [RigL](https://arxiv.org/abs/1911.11134): Our method, RigL, removes a\n    fraction of connections based on weight magnitudes and activates new ones\n    using instantaneous gradient information.\n\nAnd the following one-shot pruning algorithm:\n\n1. [SNIP](https://arxiv.org/abs/1810.02340): Single-shot Network Pruning based\n  on connection sensitivity prunes the least salient connections before training.\n\nWe have code for the following settings:\n- [Imagenet2012](https://github.com/google-research/rigl/tree/master/rigl/imagenet_resnet):\n  TPU compatible code with Resnet-50 and MobileNet-v1/v2.\n- [CIFAR-10](https://github.com/google-research/rigl/tree/master/rigl/cifar_resnet)\n  with WideResNets.\n- [MNIST](https://github.com/google-research/rigl/tree/master/rigl/mnist) with\n  a 2-layer fully connected network.\n\n## Setup\nFirst clone this repo.\n```bash\ngit clone https://github.com/google-research/rigl.git\ncd rigl\n```\n\nWe use [Neurips 2019 MicroNet Challenge](https://micronet-challenge.github.io/)\ncode for counting the operations and size of our networks. Let's clone the\ngoogle-research repo and add the current folder to the Python path.\n```bash\ngit clone https://github.com/google-research/google-research.git\nmv google-research/ google_research/\nexport PYTHONPATH=$PYTHONPATH:$PWD\n```\n\nNow we can run some tests. The following script creates a virtual environment\nand installs the necessary libraries. 
Finally, it runs a few tests.\n```bash\nbash run.sh\n```\n\nWe need to activate the virtual environment before running an experiment. With\nthat, we are ready to run some trivial MNIST experiments.\n```bash\nsource env/bin/activate\n\npython rigl/mnist/mnist_train_eval.py\n```\n\nYou can load and verify the performance of the Resnet-50 checkpoints\nas follows.\n```bash\npython rigl/imagenet_resnet/imagenet_train_eval.py --mode=eval_once --training_method=baseline --eval_batch_size=100 --output_dir=/path/to/folder --eval_once_ckpt_prefix=s80_model.ckpt-1280000 --use_folder_stub=False\n```\n\nWe use the [Official TPU Code](https://github.com/tensorflow/tpu/tree/master/models/official/resnet)\nfor loading ImageNet data. First clone the\ntensorflow/tpu repo and then add the models/ folder to the Python path.\n```bash\ngit clone https://github.com/tensorflow/tpu.git\nexport PYTHONPATH=$PYTHONPATH:$PWD/tpu/models/\n```\n\n## Other Implementations\n- [Graphcore-TF-MNIST](https://github.com/graphcore/examples/tree/master/applications/tensorflow/dynamic_sparsity/mnist_rigl): with sparse matrix ops!\n- [Pytorch implementation](https://github.com/McCrearyD/rigl-torch) by Dyllan McCreary.\n- [Micrograd-Pure Python](https://evcu.github.io/ml/sparse-micrograd/): This is\na toy example with a pure-Python sparse implementation. Caution: very slow, but fun.\n\n## Citation\n```\n@incollection{rigl,\n author = {Evci, Utku and Gale, Trevor and Menick, Jacob and Castro, Pablo Samuel and Elsen, Erich},\n booktitle = {Proceedings of Machine Learning and Systems 2020},\n pages = {471--481},\n title = {Rigging the Lottery: Making All Tickets Winners},\n year = {2020}\n}\n```\n## Disclaimer\nThis is not an official Google product.\n"
  },
  {
    "path": "rigl/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"This repo contains code for training sparse neural networks.\"\"\"\nname = 'rigl'\n"
  },
  {
    "path": "rigl/cifar_resnet/data_helper.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Helper functions for CIFAR10 data input pipeline.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nimport tensorflow.compat.v1 as tf\n\nimport tensorflow_datasets as tfds\n\nIMG_SIZE = 32\n\n\ndef pad_input(x, crop_dim=4):\n  \"\"\"Concatenates sides of image with pixels cropped from the border of image.\n\n  Args:\n    x: Input image float32 tensor.\n    crop_dim: Number of pixels to crop from the edge of the image.\n      Cropped pixels are then concatenated to the original image.\n  Returns:\n    x: input image float32 tensor. 
Transformed by padding edges with cropped\n      pixels.\n  \"\"\"\n  x = tf.concat(\n      [x[:crop_dim, :, :][::-1], x, x[-crop_dim:, :, :][::-1]], axis=0)\n  x = tf.concat(\n      [x[:, :crop_dim, :][:, ::-1], x, x[:, -crop_dim:, :][:, ::-1]], axis=1)\n  return x\n\n\ndef preprocess_train(x, width, height):\n  \"\"\"Pre-processing applied to the training data set.\n\n  Args:\n    x: Input image float32 tensor.\n    width: int specifying intended width in pixels of image after preprocessing.\n    height: int specifying intended height in pixels of image after\n      preprocessing.\n  Returns:\n    x: transformed input with random crops, flips and reflection.\n  \"\"\"\n  x = pad_input(x, crop_dim=4)\n  x = tf.random_crop(x, [width, height, 3])\n  x = tf.image.random_flip_left_right(x)\n  return x\n\n\ndef input_fn(params):\n  \"\"\"Provides batches of CIFAR data.\n\n  Args:\n    params: A dictionary with a set of arguments, namely:\n      * batch_size (int32), specifies the number of data points in a batch\n      * data_split (string), designates train or eval\n      * data_directory (string), specifies the directory of the input dataset\n\n  Returns:\n    images: A float32 `Tensor` of size [batch_size, 32, 32, 3].\n    labels: An int32 `Tensor` of size [batch_size].\n  \"\"\"\n\n  def parse_serialized_example(record):\n    \"\"\"Parses a CIFAR10 example.\"\"\"\n    image = record['image']\n    label = tf.cast(record['label'], tf.int32)\n    image = tf.cast(image, tf.float32)\n    image = tf.image.per_image_standardization(image)\n    if data_split == 'train':\n      image = preprocess_train(image, IMG_SIZE, IMG_SIZE)\n    return image, label\n\n  data_split = params['data_split']\n  batch_size = params['batch_size']\n  if data_split == 'eval':\n    data_split = 'test'\n  dataset = tfds.load('cifar10:3.*.*', split=data_split)\n\n  # We only repeat and shuffle examples during training.\n  if data_split == 'train':\n    dataset = 
dataset.repeat().shuffle(buffer_size=50000)\n\n  # Deserialize records into tensors and apply pre-processing.\n  dataset = dataset.map(parse_serialized_example).prefetch(batch_size)\n\n  # Drop the remainder so that every batch has exactly batch_size examples and\n  # the fixed-size reshapes below succeed.\n  dataset = dataset.batch(batch_size, drop_remainder=True)\n\n  images_batch, labels_batch = tf.data.make_one_shot_iterator(\n      dataset).get_next()\n\n  return (tf.reshape(images_batch, [batch_size, IMG_SIZE, IMG_SIZE, 3]),\n          tf.reshape(labels_batch, [batch_size]))\n"
  },
  {
    "path": "rigl/cifar_resnet/data_helper_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Tests for the data_helper input pipeline and the training process.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nfrom absl import flags\nfrom absl import logging\nimport absl.testing.parameterized as parameterized\nfrom rigl.cifar_resnet import resnet_train_eval\nfrom rigl.cifar_resnet.data_helper import input_fn\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib.model_pruning.python import pruning\n\nFLAGS = flags.FLAGS\n\nBATCH_SIZE = 1\nNUM_IMAGES = 1\nJITTER_MULTIPLIER = 2\n\n\nclass DataHelperTest(tf.test.TestCase, parameterized.TestCase):\n\n  def get_next(self):\n    data_directory = FLAGS.data_directory\n    # we pass the updated eval and train string to the params dictionary.\n    params = {\n        'mode': 'test',\n        'data_split': 'eval',\n        'batch_size': BATCH_SIZE,\n        'data_directory': data_directory\n    }\n\n    test_inputs, test_labels = input_fn(params)\n\n    return test_inputs, test_labels\n\n  def testInputPipeline(self):\n\n    tf.reset_default_graph()\n    g = tf.Graph()\n    with g.as_default():\n      test_inputs, test_labels = self.get_next()\n\n      with self.test_session() as sess:\n        test_images_out, test_labels_out = sess.run([test_inputs, test_labels])\n        
self.assertAllEqual(test_images_out.shape, [BATCH_SIZE, 32, 32, 3])\n        self.assertAllEqual(test_labels_out.shape, [BATCH_SIZE])\n\n  @parameterized.parameters(\n      {\n          'training_method': 'baseline',\n      },\n      {\n          'training_method': 'threshold',\n      },\n      {\n          'training_method': 'rigl',\n      },\n  )\n  def testTrainingStep(self, training_method):\n\n    tf.reset_default_graph()\n    g = tf.Graph()\n    with g.as_default():\n\n      images, labels = self.get_next()\n\n      global_step, _, _, logits = resnet_train_eval.build_model(\n          mode='train',\n          images=images,\n          labels=labels,\n          training_method=training_method,\n          num_classes=FLAGS.num_classes,\n          depth=FLAGS.resnet_depth,\n          width=FLAGS.resnet_width)\n\n      tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)\n\n      total_loss = tf.losses.get_total_loss(add_regularization_losses=True)\n\n      learning_rate = 0.1\n\n      opt = tf.train.MomentumOptimizer(\n          learning_rate, momentum=FLAGS.momentum, use_nesterov=True)\n\n      if training_method in ['threshold']:\n        # Create a pruning object using the pruning hyperparameters\n        pruning_obj = pruning.Pruning()\n\n        logging.info('starting mask update op')\n        mask_update_op = pruning_obj.conditional_mask_update_op()\n\n      # Create the training op\n      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n      with tf.control_dependencies(update_ops):\n        train_op = opt.minimize(total_loss, global_step)\n\n      init_op = tf.global_variables_initializer()\n\n      with self.test_session() as sess:\n        # test that we can train successfully for 1 step\n        sess.run(init_op)\n        for _ in range(1):\n          sess.run(train_op)\n          if training_method in ['threshold']:\n            sess.run(mask_update_op)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "rigl/cifar_resnet/resnet_model.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Model implementation of wide resnet model.\n\nImplements masking layer if pruning method is selected.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom rigl.imagenet_resnet.pruning_layers import sparse_conv2d\nfrom rigl.imagenet_resnet.pruning_layers import sparse_fully_connected\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib import layers as contrib_layers\n_BN_EPS = 1e-5\n_BN_MOMENTUM = 0.9\n\n\nclass WideResNetModel(object):\n  \"\"\"Implements WideResNet model.\"\"\"\n\n  def __init__(self,\n               is_training,\n               regularizer=None,\n               data_format='channels_last',\n               pruning_method='baseline',\n               droprate=0.3,\n               prune_first_layer=True,\n               prune_last_layer=True):\n    \"\"\"WideResnet as described in https://arxiv.org/pdf/1605.07146.pdf.\n\n    Args:\n      is_training: Boolean, True during model training,\n        false for evaluation/inference.\n      regularizer: A regularization function (mapping variables to\n        regularization losses), or None.\n      data_format: A string that indicates whether the channels are the second\n        or last index in the matrix. 
'channels_first' or 'channels_last'.\n      pruning_method: str, 'threshold' or 'baseline'.\n      droprate: float, dropout rate applied to the activations.\n      prune_first_layer: bool, if True the first layer is pruned.\n      prune_last_layer: bool, if True the last layer is pruned.\n    \"\"\"\n    self._training = is_training\n    self._regularizer = regularizer\n    self._data_format = data_format\n    self._pruning_method = pruning_method\n    self._droprate = droprate\n    self._prune_first_layer = prune_first_layer\n    self._prune_last_layer = prune_last_layer\n    if data_format == 'channels_last':\n      self._channel_axis = -1\n    elif data_format == 'channels_first':\n      self._channel_axis = 1\n\n  def build(self, inputs, depth, width, num_classes, name=None):\n    \"\"\"Builds the model architecture.\n\n    The configuration of the resnet blocks requires that depth should be\n    6n+4 where n is the number of resnet blocks desired.\n\n    Args:\n      inputs: A 4D float tensor containing the model inputs.\n      depth: Number of convolutional layers in the network.\n      width: Widening factor for the residual blocks; the blocks use\n        16 * width, 32 * width and 64 * width filters.\n      num_classes: Positive integer number of possible classes.\n      name: Optional string, the name of the resulting op in the TF graph.\n\n    Returns:\n      A 2D float logits tensor of shape (batch_size, num_classes).\n    Raises:\n      ValueError: if depth is not of the form 6n+4.\n    \"\"\"\n\n    if (depth - 4) % 6 != 0:\n      raise ValueError('ResNet depth must be of the form 6n+4.')\n\n    resnet_blocks = (depth - 4) // 6\n    with tf.variable_scope(name, 'resnet_model'):\n\n      first_layer_technique = self._pruning_method\n      if not self._prune_first_layer:\n        first_layer_technique = 'baseline'\n      net = self._conv(\n          inputs,\n          'conv_1',\n          output_size=16,\n          sparsity_technique=first_layer_technique)\n      net = 
self._residual_block(\n          net, 'conv_2', 16 * width, subsample=False, blocks=resnet_blocks)\n\n      net = self._residual_block(\n          net, 'conv_3', 32 * width, subsample=True, blocks=resnet_blocks)\n      net = self._residual_block(\n          net, 'conv_4', 64 * width, subsample=True, blocks=resnet_blocks)\n\n      # Put the final BN and ReLU before the average pooling.\n      with tf.name_scope('Pooling'):\n        net = self._batch_norm(net)\n        net = tf.nn.relu(net)\n        net = tf.layers.average_pooling2d(\n            net, pool_size=8, strides=1, data_format=self._data_format)\n\n      net = contrib_layers.flatten(net)\n      last_layer_technique = self._pruning_method\n      if not self._prune_last_layer:\n        last_layer_technique = 'baseline'\n      net = self._dense(\n          net, num_classes, 'logits', sparsity_technique=last_layer_technique)\n    return net\n\n  def _batch_norm(self, net, name=None):\n    \"\"\"Adds batchnorm to the model.\n\n    Input gradients cannot be computed with fused batch norm; it causes a\n    recursive loop of tf.gradients calls. 
If regularizer is specified, fused batchnorm must\n    be set to False (default setting).\n\n    Args:\n      net: Pre-batch norm tensor activations.\n      name: Specified name for batch normalization layer.\n\n    Returns:\n      batch norm layer: Activations from the batch normalization layer.\n    \"\"\"\n    return tf.layers.batch_normalization(\n        inputs=net,\n        fused=False,\n        training=self._training,\n        axis=self._channel_axis,\n        momentum=_BN_MOMENTUM,\n        epsilon=_BN_EPS,\n        name=name)\n\n  def _dense(self, net, num_units, name=None, sparsity_technique='baseline'):\n    return sparse_fully_connected(\n        x=net,\n        units=num_units,\n        sparsity_technique=sparsity_technique,\n        kernel_regularizer=self._regularizer,\n        name=name)\n\n  def _conv(self,\n            net,\n            name,\n            output_size,\n            strides=(1, 1),\n            padding='SAME',\n            sparsity_technique='baseline'):\n    \"\"\"returns conv layer.\"\"\"\n    return sparse_conv2d(\n        x=net,\n        units=output_size,\n        activation=None,\n        kernel_size=[3, 3],\n        use_bias=False,\n        kernel_initializer=None,\n        kernel_regularizer=self._regularizer,\n        bias_initializer=None,\n        biases_regularizer=None,\n        sparsity_technique=sparsity_technique,\n        normalizer_fn=None,\n        strides=strides,\n        padding=padding,\n        data_format=self._data_format,\n        name=name)\n\n  def _residual_block(self, net, name, output_size, subsample, blocks):\n    \"\"\"Adds a residual block to the model.\"\"\"\n    with tf.name_scope(name):\n      for n in range(blocks):\n        with tf.name_scope('res_%d' % n):\n          # when subsample is true + first block a larger stride is used.\n          if subsample and n == 0:\n            strides = [2, 2]\n          else:\n            strides = [1, 1]\n\n          # Create the skip connection\n         
 skip = net\n          end_point = 'skip_%s' % name\n          net = self._batch_norm(net)\n          net = tf.nn.relu(net)\n          if net.get_shape()[3].value != output_size:\n            skip = sparse_conv2d(\n                x=net,\n                units=output_size,\n                activation=None,\n                kernel_size=[1, 1],\n                use_bias=False,\n                kernel_initializer=None,\n                kernel_regularizer=self._regularizer,\n                bias_initializer=None,\n                biases_regularizer=None,\n                sparsity_technique=self._pruning_method,\n                normalizer_fn=None,\n                strides=strides,\n                padding='VALID',\n                data_format=self._data_format,\n                name=end_point)\n\n          # Create residual\n          net = self._conv(\n              net,\n              '%s_%d_1' % (name, n),\n              output_size,\n              strides,\n              sparsity_technique=self._pruning_method)\n          net = self._batch_norm(net)\n          net = tf.nn.relu(net)\n          net = tf.keras.layers.Dropout(self._droprate)(net, self._training)\n          net = self._conv(\n              net,\n              '%s_%d_2' % (name, n),\n              output_size,\n              sparsity_technique=self._pruning_method)\n\n          # Combine the residual and the skip connection\n          net += skip\n    return net\n"
  },
  {
    "path": "rigl/cifar_resnet/resnet_train_eval.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"This script trains a ResNet model that implements various pruning methods.\n\nImplement pruning method during training:\n\nSpecify the pruning method to use using FLAGS.training_method\n- To train a model with no pruning, specify FLAGS.training_method='baseline'\n\nSpecify desired end sparsity using FLAGS.end_sparsity\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nfrom absl import flags\nfrom rigl import sparse_optimizers\nfrom rigl import sparse_utils\nfrom rigl.cifar_resnet.data_helper import input_fn\nfrom rigl.cifar_resnet.resnet_model import WideResNetModel\nfrom rigl.imagenet_resnet import utils\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.compat.v1 import estimator as tf_estimator\nfrom tensorflow.contrib import layers as contrib_layers\nfrom tensorflow.contrib import training as contrib_training\nfrom tensorflow.contrib.model_pruning.python import pruning\n\nflags.DEFINE_string('master', 'local',\n                    'BNS name of the TensorFlow runtime to use.')\nflags.DEFINE_integer('ps_task', 0,\n                     'Task id of the replica running the training.')\nflags.DEFINE_integer('keep_checkpoint_max', 5,\n                     'Number of checkpoints to save, set 0 for all.')\nflags.DEFINE_string('pruning_hparams', '',\n                
     'Comma separated list of pruning-related hyperparameters.')\nflags.DEFINE_string('train_dir', '/tmp/cifar10/',\n                    'Directory where event logs and checkpoints are written.')\nflags.DEFINE_string(\n    'load_mask_dir', '',\n    'Directory of a trained model from which to load only the mask.')\nflags.DEFINE_string(\n    'initial_value_checkpoint', '',\n    'Directory of a model from which to load only the parameters.')\nflags.DEFINE_integer(\n    'seed', default=0, help=('Sets the random seed.'))\nflags.DEFINE_float('momentum', 0.9, 'The momentum value.')\n# 250 Epochs\nflags.DEFINE_integer('max_steps', 97656, 'Number of steps to run.')\nflags.DEFINE_float('l2', 5e-4, 'Scale factor for L2 weight decay.')\nflags.DEFINE_integer('resnet_depth', 16, 'Number of core convolutional layers '\n                     'in the network.')\nflags.DEFINE_integer('resnet_width', 4, 'Width of the residual blocks.')\nflags.DEFINE_string(\n    'data_directory', '', 'Data directory where cifar10 records are stored.')\nflags.DEFINE_integer('num_classes', 10, 'Number of classes.')\nflags.DEFINE_integer('dataset_size', 50000, 'Size of training dataset.')\nflags.DEFINE_integer('batch_size', 128, 'Batch size.')\nflags.DEFINE_integer('checkpoint_steps', 5000, 'Specifies step interval for '\n                     'saving model checkpoints.')\nflags.DEFINE_integer(\n    'summaries_steps', 300, 'Specifies interval in steps for '\n    'saving model summaries.')\nflags.DEFINE_bool('per_class_metrics', True, 'Whether to add per-class '\n                  'performance summaries.')\nflags.DEFINE_enum('mode', 'train', ('train_and_eval', 'train', 'eval'),\n                  'String that specifies whether to run training, eval, or '\n                  'both.')\n\n# pruning flags\nflags.DEFINE_integer('sparsity_begin_step', 20000, 'Step to begin pruning at.')\nflags.DEFINE_integer('sparsity_end_step', 75000, 'Step to end pruning at.')\nflags.DEFINE_integer('pruning_frequency', 1000,\n                     'Step interval between 
pruning steps.')\nflags.DEFINE_float('end_sparsity', 0.9,\n                   'Target sparsity desired by end of training.')\nflags.DEFINE_enum(\n    'training_method', 'baseline',\n    ('scratch', 'set', 'baseline', 'momentum', 'rigl', 'static', 'snip',\n     'prune'),\n    'Method used for training sparse network. `scratch` means initial mask is '\n    'kept during training. `set` is for sparse evolutionary training and '\n    '`baseline` is for dense baseline.')\nflags.DEFINE_bool('prune_first_layer', False,\n                  'Whether or not to apply sparsification to the first layer')\nflags.DEFINE_bool('prune_last_layer', True,\n                  'Whether or not to apply sparsification to the last layer')\nflags.DEFINE_float('drop_fraction', 0.3,\n                   'When changing mask dynamically, this fraction decides how '\n                   'much of the connections are dropped and regrown at each '\n                   'mask update.')\nflags.DEFINE_string('drop_fraction_anneal', 'constant',\n                    'If not empty, the drop fraction is annealed during sparse'\n                    ' training. One of the following: `constant`, `cosine` or '\n                    '`exponential_(\\\\d*\\\\.?\\\\d*)$`. For example: '\n                    '`exponential_3`, `exponential_.3`, `exponential_0.3`. '\n                    'The number after `exponential` defines the exponent.')\nflags.DEFINE_string('grow_init', 'zeros',\n                    'Passed to the SparseInitializer, one of: zeros, '\n                    'initial_value, random_normal, random_uniform.')\nflags.DEFINE_float('s_momentum', 0.9,\n                   'Momentum value for the exponential moving average of '\n                   'gradients. 
Used when training_method=\"momentum\".')\nflags.DEFINE_float('rigl_acc_scale', 0.,\n                   'Used to scale initial accumulated gradients for new '\n                   'connections.')\nflags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.')\nflags.DEFINE_integer('maskupdate_end_step', 75000, 'Step to end mask updates.')\nflags.DEFINE_integer('maskupdate_frequency', 100,\n                     'Step interval between mask updates.')\nflags.DEFINE_string(\n    'mask_init_method',\n    default='random',\n    help='If not empty string and mask is not loaded from a checkpoint, '\n    'indicates the method used for mask initialization. One of the following: '\n    '`random`, `erdos_renyi`.')\nflags.DEFINE_float('training_steps_multiplier', 1.0,\n                   'Training schedule is shortened or extended with the '\n                   'multiplier, if it is not 1.')\n\nFLAGS = flags.FLAGS\nPARAM_SUFFIXES = ('gamma', 'beta', 'weights', 'biases')\nMASK_SUFFIX = 'mask'\nCLASSES = [\n    'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',\n    'ship', 'truck'\n]\n\n\ndef create_eval_metrics(labels, logits):\n  \"\"\"Creates the evaluation metrics for the model.\"\"\"\n\n  eval_metrics = {}\n  label_keys = CLASSES\n  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)\n  eval_metrics['eval_accuracy'] = tf.metrics.accuracy(\n      labels=labels, predictions=predictions)\n  if FLAGS.per_class_metrics:\n    with tf.name_scope('class_level_summaries') as scope:\n      for i in range(len(label_keys)):\n        labels = tf.cast(labels, tf.int64)\n        name = scope + '/' + label_keys[i]\n        eval_metrics[('class_level_summaries/precision/' +\n                      label_keys[i])] = tf.metrics.precision_at_k(\n                          labels=labels,\n                          predictions=logits,\n                          class_id=i,\n                          k=1,\n                          name=name)\n        
eval_metrics[('class_level_summaries/recall/' +\n                      label_keys[i])] = tf.metrics.recall_at_k(\n                          labels=labels,\n                          predictions=logits,\n                          class_id=i,\n                          k=1,\n                          name=name)\n  return eval_metrics\n\n\ndef train_fn(training_method, global_step, total_loss, train_dir, accuracy,\n             top_5_accuracy):\n  \"\"\"Builds the training op and summary hooks for the resnet model.\n\n  Args:\n   training_method: specifies the method used to sparsify networks.\n   global_step: the current step of training/eval.\n   total_loss: tensor float32 of the cross entropy + regularization losses.\n   train_dir: string specifying the directory where summaries are saved.\n   accuracy: tensor float32 batch classification accuracy.\n   top_5_accuracy: tensor float32 batch classification accuracy (top-5 classes).\n\n  Returns:\n    hooks: hooks for saving summaries at each training step.\n    eval_metrics: set to None during training.\n    train_op: the training op.\n  \"\"\"\n  # Learning rate roughly drops at every 30k steps.\n  boundaries = [30000, 60000, 90000]\n  if FLAGS.training_steps_multiplier != 1.0:\n    multiplier = FLAGS.training_steps_multiplier\n    boundaries = [int(x * multiplier) for x in boundaries]\n    tf.logging.info(\n        'Learning rate boundaries are updated with multiplier: %.2f',\n        multiplier)\n\n  learning_rate = tf.train.piecewise_constant(\n      global_step,\n      boundaries,\n      values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],\n      name='lr_schedule')\n\n  optimizer = tf.train.MomentumOptimizer(\n      learning_rate, momentum=FLAGS.momentum, use_nesterov=True)\n\n  if training_method == 'set':\n    # We override the train op to also update the mask.\n    optimizer = sparse_optimizers.SparseSETOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, 
grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal)\n  elif training_method == 'static':\n    # We override the train op to also update the mask.\n    optimizer = sparse_optimizers.SparseStaticOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal)\n  elif training_method == 'momentum':\n    # We override the train op to also update the mask.\n    optimizer = sparse_optimizers.SparseMomentumOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        grow_init=FLAGS.grow_init,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)\n  elif training_method == 'rigl':\n    # We override the train op to also update the mask.\n    optimizer = sparse_optimizers.SparseRigLOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency,\n        drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal,\n        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)\n  elif training_method == 'snip':\n    optimizer = sparse_optimizers.SparseSnipOptimizer(\n        optimizer, mask_init_method=FLAGS.mask_init_method,\n        default_sparsity=FLAGS.end_sparsity, use_tpu=False)\n  elif training_method in ('scratch', 'baseline', 'prune'):\n    pass\n  else:\n    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)\n  # Create the training op\n  update_ops = 
tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n  with tf.control_dependencies(update_ops):\n    train_op = optimizer.minimize(total_loss, global_step)\n\n  if training_method == 'prune':\n    # construct the necessary hparams string from the FLAGS\n    hparams_string = ('begin_pruning_step={0},'\n                      'sparsity_function_begin_step={0},'\n                      'end_pruning_step={1},'\n                      'sparsity_function_end_step={1},'\n                      'target_sparsity={2},'\n                      'pruning_frequency={3},'\n                      'threshold_decay=0,'\n                      'use_tpu={4}'.format(\n                          FLAGS.sparsity_begin_step,\n                          FLAGS.sparsity_end_step,\n                          FLAGS.end_sparsity,\n                          FLAGS.pruning_frequency,\n                          False,\n                      ))\n    # Parse pruning hyperparameters\n    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)\n\n    # Create a pruning object using the pruning hyperparameters\n    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)\n\n    tf.logging.info('starting mask update op')\n\n    # We override the train op to also update the mask.\n    with tf.control_dependencies([train_op]):\n      train_op = pruning_obj.conditional_mask_update_op()\n\n  masks = pruning.get_masks()\n  mask_metrics = utils.mask_summaries(masks)\n  for name, tensor in mask_metrics.items():\n    tf.summary.scalar(name, tensor)\n\n  tf.summary.scalar('learning_rate', learning_rate)\n  tf.summary.scalar('accuracy', accuracy)\n  tf.summary.scalar('total_loss', total_loss)\n  tf.summary.scalar('top_5_accuracy', top_5_accuracy)\n  # Logging drop_fraction if dynamic sparse training.\n  if training_method in ('set', 'momentum', 'rigl', 'static'):\n    tf.summary.scalar('drop_fraction', optimizer.drop_fraction)\n\n  summary_op = tf.summary.merge_all()\n  summary_hook = 
tf.train.SummarySaverHook(\n      save_secs=300, output_dir=train_dir, summary_op=summary_op)\n  hooks = [summary_hook]\n  eval_metrics = None\n\n  return hooks, eval_metrics, train_op\n\n\ndef build_model(mode,\n                images,\n                labels,\n                training_method='baseline',\n                num_classes=10,\n                depth=10,\n                width=4):\n  \"\"\"Build the wide ResNet model for training or eval.\n\n  If a regularizer is specified, a regularization term is added to the loss\n  function.\n\n  Args:\n    mode: String for whether training or evaluation is taking place.\n    images:  A 4D float32 tensor containing the model input images.\n    labels:  An int32 tensor of shape (batch_size,) containing the label\n      indices.\n    training_method: The method used to sparsify the network weights.\n    num_classes: The number of distinct labels in the dataset.\n    depth: Number of core convolutional layers in the network.\n    width: The width multiplier for the residual blocks.\n\n  Returns:\n    global_step: The current global step of training.\n    accuracy: A float32 scalar of the batch classification accuracy.\n    top_5_accuracy: A float32 scalar of the batch top-5 accuracy.\n    logits: A 2D float32 logits tensor of shape (batch_size, num_classes).\n  Raises:\n      ValueError: if depth is not of the form 6n+4.\n  \"\"\"\n  regularizer_term = tf.constant(FLAGS.l2, tf.float32)\n  kernel_regularizer = contrib_layers.l2_regularizer(scale=regularizer_term)\n\n  # depth should be 6n+4 where n is the desired number of resnet blocks\n  # if n=1,depth=10  n=3,depth=22, n=5,depth=34 n=7,depth=46\n  if (depth - 4) % 6 != 0:\n    raise ValueError('ResNet depth must be of the form 6n+4.')\n\n  if mode == 'train':\n    is_training = True\n  else:\n    is_training = False\n  # 
'threshold' would create layers with mask.\n  pruning_method = 'baseline' if training_method == 'baseline' else 'threshold'\n\n  model = WideResNetModel(\n      is_training=is_training,\n      regularizer=kernel_regularizer,\n      data_format='channels_last',\n      pruning_method=pruning_method,\n      prune_first_layer=FLAGS.prune_first_layer,\n      prune_last_layer=FLAGS.prune_last_layer)\n\n  logits = model.build(\n      images, depth=depth, width=width, num_classes=num_classes)\n\n  global_step = tf.train.get_or_create_global_step()\n\n  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)\n  accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32))\n\n  in_top_5 = tf.cast(\n      tf.nn.in_top_k(predictions=logits, targets=labels, k=5), tf.float32)\n\n  top_5_accuracy = tf.cast(tf.reduce_mean(in_top_5), tf.float32)\n\n  return global_step, accuracy, top_5_accuracy, logits\n\n\ndef wide_resnet_w_pruning(features, labels, mode, params):\n  \"\"\"The model_fn for ResNet wide with pruning.\n\n  Args:\n    features: A float32 batch of images.\n    labels: A int32 batch of labels.\n    mode: Specifies whether training or evaluation.\n    params: Dictionary of parameters passed to the model.\n\n  Returns:\n    A EstimatorSpec for the model\n\n  Raises:\n      ValueError: if mode is not recognized as train or eval.\n  \"\"\"\n\n  if isinstance(features, dict):\n    features = features['feature']\n\n  train_dir = params['train_dir']\n  training_method = params['training_method']\n\n  global_step, accuracy, top_5_accuracy, logits = build_model(\n      mode=mode,\n      images=features,\n      labels=labels,\n      training_method=training_method,\n      num_classes=FLAGS.num_classes,\n      depth=FLAGS.resnet_depth,\n      width=FLAGS.resnet_width)\n\n  if mode == tf_estimator.ModeKeys.PREDICT:\n    predictions = {\n        'classes': tf.argmax(logits, axis=1),\n        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')\n    }\n    
return tf_estimator.EstimatorSpec(\n        mode=mode,\n        predictions=predictions,\n        export_outputs={\n            'classify': tf_estimator.export.PredictOutput(predictions)\n        })\n\n  with tf.name_scope('computing_cross_entropy_loss'):\n    entropy_loss = tf.losses.sparse_softmax_cross_entropy(\n        labels=labels, logits=logits)\n    tf.summary.scalar('cross_entropy_loss', entropy_loss)\n\n  with tf.name_scope('computing_total_loss'):\n    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)\n\n  if mode == tf_estimator.ModeKeys.TRAIN:\n    hooks, eval_metrics, train_op = train_fn(training_method, global_step,\n                                             total_loss, train_dir, accuracy,\n                                             top_5_accuracy)\n  elif mode == tf_estimator.ModeKeys.EVAL:\n    hooks = None\n    train_op = None\n    with tf.name_scope('summaries'):\n      eval_metrics = create_eval_metrics(labels, logits)\n  else:\n    raise ValueError('mode not recognized as training or eval.')\n\n  # If given, load parameter values.\n  if FLAGS.initial_value_checkpoint:\n    tf.logging.info('Loading initial values from: %s',\n                    FLAGS.initial_value_checkpoint)\n    utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,\n                                          FLAGS.train_dir, PARAM_SUFFIXES)\n\n  # Load or randomly initialize masks.\n  if (FLAGS.load_mask_dir and\n      FLAGS.training_method not in ('snip', 'baseline', 'prune')):\n    # Init masks.\n    tf.logging.info('Loading masks from %s', FLAGS.load_mask_dir)\n    utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir, FLAGS.train_dir,\n                                          MASK_SUFFIX)\n    scaffold = tf.train.Scaffold()\n  elif (FLAGS.mask_init_method and\n        FLAGS.training_method not in ('snip', 'baseline', 'scratch', 'prune')):\n    tf.logging.info('Initializing masks using method: %s',\n                    
FLAGS.mask_init_method)\n    all_masks = pruning.get_masks()\n    assigner = sparse_utils.get_mask_init_fn(\n        all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, {})\n    def init_fn(scaffold, session):\n      \"\"\"Runs the mask initialization op once the session is created.\"\"\"\n      del scaffold  # Unused.\n      session.run(assigner)\n    scaffold = tf.train.Scaffold(init_fn=init_fn)\n  else:\n    assert FLAGS.training_method in ('snip', 'baseline', 'prune')\n    scaffold = None\n    tf.logging.info('No mask is set, starting dense.')\n\n  return tf_estimator.EstimatorSpec(\n      mode=mode,\n      training_hooks=hooks,\n      loss=total_loss,\n      train_op=train_op,\n      eval_metric_ops=eval_metrics,\n      scaffold=scaffold)\n\n\ndef main(argv):\n  del argv  # Unused.\n  tf.set_random_seed(FLAGS.seed)\n  if FLAGS.training_steps_multiplier != 1.0:\n    multiplier = FLAGS.training_steps_multiplier\n    FLAGS.max_steps = int(FLAGS.max_steps * multiplier)\n    FLAGS.maskupdate_begin_step = int(FLAGS.maskupdate_begin_step * multiplier)\n    FLAGS.maskupdate_end_step = int(FLAGS.maskupdate_end_step * multiplier)\n    FLAGS.sparsity_begin_step = int(FLAGS.sparsity_begin_step * multiplier)\n    FLAGS.sparsity_end_step = int(FLAGS.sparsity_end_step * multiplier)\n    tf.logging.info(\n        'Training schedule is updated with multiplier: %.2f', multiplier)\n  # Configures train directories based upon the hyperparameters used.\n  if FLAGS.training_method == 'prune':\n    folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity),\n                               str(FLAGS.sparsity_begin_step),\n                               str(FLAGS.sparsity_end_step),\n                               str(FLAGS.pruning_frequency))\n\n  elif FLAGS.training_method in ('set', 'momentum', 'rigl', 'static'):\n    folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity),\n                               str(FLAGS.maskupdate_begin_step),\n                   
            str(FLAGS.maskupdate_end_step),\n                               str(FLAGS.maskupdate_frequency))\n  elif FLAGS.training_method in ('baseline', 'snip', 'scratch'):\n    folder_stub = os.path.join(FLAGS.training_method, str(0.0), str(0.0),\n                               str(0.0), str(0.0))\n  else:\n    raise ValueError('Training method is not known: %s' % FLAGS.training_method)\n\n  train_dir = os.path.join(FLAGS.train_dir, folder_stub)\n\n  # We pass the updated eval and train strings to the params dictionary.\n  params = {}\n  params['train_dir'] = train_dir\n  params['data_split'] = FLAGS.mode\n  params['batch_size'] = FLAGS.batch_size\n  params['data_directory'] = FLAGS.data_directory\n  params['mode'] = FLAGS.mode\n  params['training_method'] = FLAGS.training_method\n\n  run_config = tf_estimator.RunConfig(\n      model_dir=train_dir,\n      keep_checkpoint_max=FLAGS.keep_checkpoint_max,\n      save_summary_steps=FLAGS.summaries_steps,\n      save_checkpoints_steps=FLAGS.checkpoint_steps,\n      log_step_count_steps=100)\n\n  classifier = tf_estimator.Estimator(\n      model_fn=wide_resnet_w_pruning,\n      model_dir=train_dir,\n      config=run_config,\n      params=params)\n\n  if FLAGS.mode == 'eval':\n    eval_steps = 10000 // FLAGS.batch_size\n    # Run evaluation when there's a new checkpoint\n    for ckpt in contrib_training.checkpoints_iterator(train_dir):\n      print('Starting to evaluate.')\n      try:\n        classifier.evaluate(\n            input_fn=input_fn,\n            steps=eval_steps,\n            checkpoint_path=ckpt,\n            name='eval')\n        # Terminate eval job when final checkpoint is reached\n        global_step = int(os.path.basename(ckpt).split('-')[1])\n        if global_step >= FLAGS.max_steps:\n          print('Evaluation finished after training step %d' % global_step)\n          break\n\n      except tf.errors.NotFoundError:\n        print('Checkpoint no longer exists, skipping checkpoint.')\n\n  else:\n    
print('Starting training...')\n    if FLAGS.mode == 'train':\n      classifier.train(input_fn=input_fn, max_steps=FLAGS.max_steps)\n\n\nif __name__ == '__main__':\n  tf.app.run(main)\n"
  },
  {
    "path": "rigl/experimental/jax/README.md",
    "content": "# Weight Symmetry Research Code\nThis code is mostly written by Yani Ioannou.\n\n## Experiment Summary\n\nThere are a number of experiment drivers defined in the base directory:\n\n### Experiment Types {#experiment-types}\n\nrandom_mask\n:   Random Variable Sparsity Masks\n:   This experiment generates random masks of a given type (see\n    [Mask Types](#mask-types)) within the *given a sparsity range*, and trains\n    the models, tracking mask statistics and training details. Masks are\n    generated with a random number of connections and randomly shuffled.\n\nshuffled_mask\n:   Random Fixed Sparsity Masks\n:   This experiment generates random masks of a given type (see\n    [Mask Types](#mask-types)) *of a fixed sparsity*, and trains the models,\n    tracking mask statistics and training details. Masks are generated with a\n    fixed number of connections and simply shuffled.\n\nfixed_param\n:   Train models with (approximately) fixed number of parameters, but varying\n    depth/width.\n:   Train models with (approximately) fixed number of parameters, but varying\n    depth/width, with shuffled mask (as in shuffled_mask driver), and only the\n    MNIST_FC model type.\n\nprune\n:   Simple Pruning/Training Driver\n:   This experiment trains a dense model pruning either iteratively or one-shot,\n    tracking mask statistics and training details.\n\ntrain\n:   Simple Training Driver (Without Masking/Pruning)\n:   This experiment simply trains a dense model, tracking mask statistics and\n    training details.\n\n### Mask Types {#mask-types}\n\nsymmetric\n:   Structured Mask.\n:   The mask is a structured\n\nrandom\n:   Unstructured Mask.\n:   The mask as a whole is a random mask of a given sparsity, with some neurons\n    having fewer/more connections than others.\n\nper-neuron\n:   Unstructured Mask.\n:   Each neuron has the same sparsity (# of masked connections), but is shuffled\n    randomly.\n\nper-neuron-no-input-ablation:\n:   Unstructured 
Mask.\n:   As with per-neuron, each neuron has the same sparsity, but randomly shuffled\n    connections. Additionally, at least one connection is maintained to each of\n    the input neurons (i.e. the input neurons are not effectively ablated),\n    although these connections are also randomly shuffled amongst the neurons of\n    a given layer.\n\n### Model Types {#model-types}\n\nMNIST_FC\n:   A small fully-connected model, accepting the number of neurons and depth as\n    parameters. No batch normalization, configurable drop-out rate (default: 0).\n\nMNIST_CNN\n:   A small convolutional model designed for MNIST, accepting the number of\n    filters for each layer and depth as parameters. Uses batch normalization and\n    a configurable drop-out rate (default: 0).\n\nCIFAR10_CNN\n:   A larger convolutional model designed for CIFAR10, accepting the number of\n    filters for each layer and depth as parameters. No batch normalization,\n    configurable drop-out rate (default: 0).\n\n### Dataset Types {#dataset-types}\n\nMNIST\n:   Wrapper of the TensorFlow Datasets (TFDS) MNIST dataset.\n\nCIFAR10\n:   Wrapper of the TensorFlow Datasets (TFDS) CIFAR10 dataset.\n\n## Running Experiments\n\n### Running on a Workstation\n\nTrain:\n\n```shell\npython -m weight_symmetry.${EXPERIMENT_TYPE}\n```\n\n## Result Processing/Analysis\n\n### Plotting Results from a JSON Summary File\n\nYou can convert the results from a JSON summary file to a pandas DataFrame for\nplotting/analysis using the example colab in `analysis/plot_summary_json.ipynb`.\n"
  },
  {
    "path": "rigl/experimental/jax/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"This module contains code for weight symmetry experiments.\"\"\"\nname = 'weight_symmetry'\n"
  },
  {
    "path": "rigl/experimental/jax/analysis/plot_summary_json.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"6iEEw5OwSlnz\"\n      },\n      \"source\": [\n        \"# Plot Results from an Experiment Summary JSON File\",\n        \"Licensed under the Apache License, Version 2.0\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"Eg6FmoCaTCHM\"\n      },\n      \"source\": [\n        \"## Parameters\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"ML0hUJMzYF0W\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from google.colab import files\\n\",\n        \"\\n\",\n        \"# Experiment summary filenames (one per experiment)\\n\",\n        \"SUMMARY_FILES = files.upload()\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"MHubbscQSLGm\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Labels to use for each of the summaries listed above (in the same order!)\\n\",\n        \"XID_LABELS=['structured', 'unstructured']  #@param\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"x0jDBWKdU_2A\"\n      },\n      \"source\": [\n        \"## Loading of JSON Summary/Conversion to Pandas Dataframe\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"Lz-HwS1tU-ie\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import json\\n\",\n        \"import pandas as pd\\n\",\n        \"import 
os\\n\",\n        \"\\n\",\n        \"from colabtools.interactive_widgets import ProgressIter\\n\",\n        \"\\n\",\n        \"dfs = []\\n\",\n        \"for i, summary_file in enumerate(SUMMARY_FILES):\\n\",\n        \"  with open(summary_file) as summary_file:\\n\",\n        \"    data = json.load(summary_file)\\n\",\n        \"    dataframe = pd.DataFrame.from_dict(data, orient='index')\\n\",\n        \"    dataframe['experiment_label'] = XID_LABELS[i]\\n\",\n        \"    dfs.append(dataframe)\\n\",\n        \"\\n\",\n        \"df=pd.concat(dfs)\\n\",\n        \"\\n\",\n        \"print('Loaded {} rows for experiment'.format(len(data)))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"DhO6oT1nVpTV\"\n      },\n      \"source\": [\n        \"## Measurements and Labels\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"XFRR3XrXVopB\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"DATA_LABELS={\\n\",\n        \"    'best_train_loss/test_accuracy': 'Test Accuracy  (of best train loss)',\\n\",\n        \"    'best_train_loss/train_accuracy': 'Train Accuracy  (of best train loss)',\\n\",\n        \"    'best_train_loss/test_avg_loss': 'Test Loss  (of best train loss)',\\n\",\n        \"    'best_train_loss/train_avg_loss': 'Train Loss  (of best train loss)',\\n\",\n        \"    'best_train_loss/step': 'Training Iterations  (of best train loss)',\\n\",\n        \"    'best_train_loss/cumulative_gradient_norm': 'Cumulative Gradient Norm. (of best train loss)',\\n\",\n        \"    'best_train_loss/vector_difference_norm': 'Vector Difference Norm. 
(of best train loss)',\\n\",\n        \"    'best_train_loss/cosine_distance': 'Cosine Similarity (of best train loss)',\\n\",\n        \"    'best_test_acc/test_accuracy': 'Test Accuracy (of best test acc.)',\\n\",\n        \"    'best_test_acc/train_accuracy': 'Train Accuracy (of best test acc.)',\\n\",\n        \"    'best_test_acc/test_avg_loss': 'Test Loss (of best test acc.)',\\n\",\n        \"    'best_test_acc/train_avg_loss': 'Train Loss (of best test acc.)',\\n\",\n        \"    'best_test_acc/step': 'Training Iterations (of best test acc.)',\\n\",\n        \"    'best_test_acc/cumulative_gradient_norm': 'Cumulative Gradient Norm. (of best Test Acc.)',\\n\",\n        \"    'best_test_acc/cosine_distance': 'Cosine Similarity (of best Test Acc.)',\\n\",\n        \"    'best_test_acc/vector_difference_norm': 'Vector Difference Norm. (of best Test Acc.)',\\n\",\n        \"    'mask/sparsity': 'Sparsity',\\n\",\n        \"    'mask/unique_neurons': '# Unique Neurons',\\n\",\n        \"    'mask/zeroed_neurons': '# Zeroed Neurons',\\n\",\n        \"    'mask/permutation_log10': 'log10(1 + Permutations)',\\n\",\n        \"    'mask/permutation_num_digits': 'Permutation # of Digits',\\n\",\n        \"    'mask/permutations': 'Permutation',\\n\",\n        \"    'mask/total_neurons': 'Total # of Neurons',\\n\",\n        \"    'propagated_mask/sparsity': 'Mask Sparsity',\\n\",\n        \"    'propagated_mask/unique_neurons': '# Unique Neurons (prop.)',\\n\",\n        \"    'propagated_mask/zeroed_neurons': '# Zeroed Neurons (prop.)',\\n\",\n        \"    'propagated_mask/permutation_log10': 'log10(1 + Permutations) (prop.)',\\n\",\n        \"    'propagated_mask/permutation_num_digits': 'Permutation # of Digits (prop.)',\\n\",\n        \"    'propagated_mask/permutations': 'Mask Permutations',\\n\",\n        \"    'propagated_mask/total_neurons': 'Total # of Neurons (prop.)',\\n\",\n        \"    'training/train_avg_loss': 'Train Loss',\\n\",\n        \"}\"\n      
]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"HAVkz8ZzV0Hd\"\n      },\n      \"source\": [\n        \"# Seaborn Plot Example\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"eoxoJH4gWHbb\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Choose the X/Y/Z labels from the parameter list above.\\n\",\n        \"X_LABEL='propagated_mask/sparsity'  #@param {type:\\\"string\\\"}\\n\",\n        \"Y_LABEL='best_train_loss/cumulative_gradient_norm' #@param {type:\\\"string\\\"}\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"pudAXLl1VzFl\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import seaborn as sns\\n\",\n        \"import matplotlib.pyplot as plt\\n\",\n        \"import numpy as np\\n\",\n        \"\\n\",\n        \"# Seaborn style - remove outer plot ticks, white plot background.\\n\",\n        \"np.set_printoptions(linewidth=128, precision=3, edgeitems=5)\\n\",\n        \"sns.set_style(\\\"whitegrid\\\")\\n\",\n        \"sns.color_palette(\\\"muted\\\")\\n\",\n        \"sns.set_context(\\\"paper\\\", font_scale=1, rc={\\n\",\n        \"    \\\"lines.linewidth\\\": 1.2,\\n\",\n        \"    \\\"xtick.major.size\\\": 0,\\n\",\n        \"    \\\"xtick.minor.size\\\": 0,\\n\",\n        \"    \\\"ytick.major.size\\\": 0,\\n\",\n        \"    \\\"ytick.minor.size\\\": 0\\n\",\n        \"})\\n\",\n        \"\\n\",\n        \"# Higher resolution plots\\n\",\n        \"%config InlineBackend.figure_format = 'retina'\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"lYUK9xi_aym3\"\n 
     },\n      \"source\": [\n        \"### Plot Raw Data Points\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"uWcT76L6Wbv6\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"\\n\",\n        \"plt.figure(figsize=(16,8))\\n\",\n        \"axis = sns.scatterplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', s=50, alpha=.5)\\n\",\n        \"axis.set_ylabel(DATA_LABELS[Y_LABEL])\\n\",\n        \"axis.set_xlabel(DATA_LABELS[X_LABEL])\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"Kws6tjfTa7h0\"\n      },\n      \"source\": [\n        \"### Plot Mean/StdDev\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"jR04tmMnaxjG\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"plt.figure(figsize=(16,8))\\n\",\n        \"axis = sns.lineplot(data=df, x=X_LABEL, y=Y_LABEL, hue='experiment_label', alpha=.5, ci=\\\"sd\\\", markers=True)\\n\",\n        \"axis.set_ylabel(DATA_LABELS[Y_LABEL])\\n\",\n        \"axis.set_xlabel(DATA_LABELS[X_LABEL])\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {},\n        \"colab_type\": \"code\",\n        \"id\": \"jyNFtKQfajiq\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Code to save output files for publication.\\n\",\n        \"PARAM_STR=X_LABEL.replace('/', '-')+'_'+Y_LABEL.replace('/', '-')\\n\",\n        \"\\n\",\n        \"OUT_FILE_PDF=f'/tmp/{PARAM_STR}.pdf'\\n\",\n        \"OUT_FILE_SVG=f'/tmp/{PARAM_STR}.svg'\\n\",\n        \"OUT_FILE_PNG=f'/tmp/{PARAM_STR}.png'\\n\",\n        \"\\n\",\n        
\"plt.savefig(OUT_FILE_PDF, pi=600)\\n\",\n        \"files.download(OUT_FILE_PDF)\\n\",\n        \"\\n\",\n        \"plt.savefig(OUT_FILE_SVG)\\n\",\n        \"files.download(OUT_FILE_SVG)\\n\",\n        \"\\n\",\n        \"plt.savefig(OUT_FILE_PNG)\\n\",\n        \"files.download(OUT_FILE_PNG)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [],\n      \"name\": \"plot_summary_json\",\n      \"provenance\": [\n        {\n          \"file_id\": \"1g2aTwv76XMrLfEwryfj_tGzNnvZWjIVl\",\n          \"timestamp\": 1600990155741\n        }\n      ]\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "rigl/experimental/jax/datasets/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "rigl/experimental/jax/datasets/cifar10.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"CIFAR10 Dataset.\n\nDataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)\nwith JAX/FLAX, by defining a bunch of wrappers, including preprocessing.\nIn this case, the CIFAR10 dataset.\n\"\"\"\nfrom typing import MutableMapping, Sequence\nfrom rigl.experimental.jax.datasets import dataset_base\nimport tensorflow.compat.v2 as tf\n\n\nclass CIFAR10Dataset(dataset_base.ImageDataset):\n  \"\"\"CIFAR10 dataset.\n\n  Attributes:\n      NAME: The Tensorflow Dataset's dataset name.\n  \"\"\"\n  NAME: str = 'cifar10'\n  # Computed from the training set by taking the per-channel mean/std-dev\n  # over sample, height and width axes of all training samples.\n  MEAN_RGB: Sequence[float] = [0.4914 * 255, 0.4822 * 255, 0.4465 * 255]\n  STDDEV_RGB: Sequence[float] = [0.2470 * 255, 0.2435 * 255, 0.2616 * 255]\n\n  def __init__(self,\n               batch_size,\n               batch_size_test,\n               shuffle_buffer_size = 1024,\n               seed = 42):\n    \"\"\"CIFAR10 dataset.\n\n    Args:\n        batch_size: The batch size to use for the training datasets.\n        batch_size_test: The batch size used for the test dataset.\n        shuffle_buffer_size: The buffer size to use for dataset shuffling.\n        seed: The random seed used to shuffle.\n\n    Returns:\n        Dataset: A dataset object.\n\n    Raises:\n        
ValueError: If the test dataset is not evenly divisible by the\n                    test batch size.\n    \"\"\"\n    # Pass seed by keyword: the base class takes prefetch_size before seed.\n    super().__init__(CIFAR10Dataset.NAME, batch_size, batch_size_test,\n                     shuffle_buffer_size, seed=seed)\n    if self.get_test_len() % batch_size_test != 0:\n      raise ValueError(\n          'Test data not evenly divisible by batch size: {} % {} != 0.'.format(\n              self.get_test_len(), batch_size_test))\n\n  def preprocess(\n      self, data):\n    \"\"\"Normalizes CIFAR10 images: `uint8` -> `float32`.\n\n    Args:\n      data: Data sample.\n\n    Returns:\n      Data after being normalized/transformed.\n    \"\"\"\n    data = super().preprocess(data)\n    mean_rgb = tf.constant(self.MEAN_RGB, shape=[1, 1, 3], dtype=tf.float32)\n    std_rgb = tf.constant(self.STDDEV_RGB, shape=[1, 1, 3], dtype=tf.float32)\n\n    data['image'] = (tf.cast(data['image'], tf.float32) - mean_rgb) / std_rgb\n    return data\n"
  },
  {
    "path": "rigl/experimental/jax/datasets/cifar10_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.datasets.cifar10.\"\"\"\nfrom absl.testing import absltest\nimport numpy as np\n\nfrom rigl.experimental.jax.datasets import cifar10\n\n\nclass CIFAR10DatasetTest(absltest.TestCase):\n  \"\"\"Test cases for CIFAR10 Dataset.\"\"\"\n\n  def setUp(self):\n    \"\"\"Common setup routines/variables for test cases.\"\"\"\n    super().setUp()\n    self._batch_size = 16\n    self._batch_size_test = 10\n    self._shuffle_buffer_size = 8\n\n    self._dataset = cifar10.CIFAR10Dataset(\n        self._batch_size,\n        batch_size_test=self._batch_size_test,\n        shuffle_buffer_size=self._shuffle_buffer_size)\n\n  def test_create_dataset(self):\n    \"\"\"Tests creation of dataset.\"\"\"\n    self.assertIsInstance(self._dataset, cifar10.CIFAR10Dataset)\n\n  def test_train_image_dims_content(self):\n    \"\"\"Tests dimensions and contents of test data.\"\"\"\n    iterator = self._dataset.get_train()\n    sample = next(iterator)\n    image, label = sample['image'], sample['label']\n\n    with self.subTest(name='DataShape'):\n      self.assertTupleEqual(image.shape, (self._batch_size, 32, 32, 3))\n\n    with self.subTest(name='DataType'):\n      self.assertTrue(np.issubdtype(image.dtype, float))\n\n    with self.subTest(name='DataValues'):\n      # Normalized by stddev., expect nothing to fall outside 3 stddev.\n      
self.assertTrue((image >= -3.).all() and (image <= 3.).all())\n\n    with self.subTest(name='LabelShape'):\n      self.assertLen(label, self._batch_size)\n\n    with self.subTest(name='LabelType'):\n      self.assertTrue(np.issubdtype(label.dtype, int))\n\n    with self.subTest(name='LabelValues'):\n      self.assertTrue((label >= 0).all() and\n                      (label < self._dataset.num_classes).all())\n\n  def test_test_image_dims_content(self):\n    \"\"\"Tests dimensions and contents of test data.\"\"\"\n    iterator = self._dataset.get_test()\n    sample = next(iterator)\n    image, label = sample['image'], sample['label']\n\n    with self.subTest(name='DataShape'):\n      self.assertTupleEqual(image.shape, (self._batch_size_test, 32, 32, 3))\n\n    with self.subTest(name='DataType'):\n      self.assertTrue(np.issubdtype(image.dtype, float))\n\n    with self.subTest(name='DataValues'):\n      # Normalized by stddev., expect nothing to fall outside 3 stddev.\n      self.assertTrue((image >= -3.).all() and (image <= 3.).all())\n\n    with self.subTest(name='LabelShape'):\n      self.assertLen(label, self._batch_size_test)\n\n    with self.subTest(name='LabelType'):\n      self.assertTrue(np.issubdtype(label.dtype, int))\n\n    with self.subTest(name='LabelValues'):\n      self.assertTrue((label >= 0).all() and\n                      (label < self._dataset.num_classes).all())\n\n  def test_train_data_length(self):\n    \"\"\"Tests length of training dataset.\"\"\"\n    total_count = 0\n    for batch in self._dataset.get_train():\n      total_count += len(batch['label'])\n\n    self.assertEqual(total_count, self._dataset.get_train_len())\n\n  def test_test_data_length(self):\n    \"\"\"Tests length of test dataset.\"\"\"\n    total_count = 0\n    for batch in self._dataset.get_test():\n      total_count += len(batch['label'])\n\n    self.assertEqual(total_count, self._dataset.get_test_len())\n\n  def test_dataset_nonevenly_divisible_batch_size(self):\n    
\"\"\"Tests non-evenly divisible test batch size.\"\"\"\n    with self.assertRaisesRegex(\n        ValueError, 'Test data not evenly divisible by batch size: .*'):\n      self._dataset = cifar10.CIFAR10Dataset(\n          self._batch_size, batch_size_test=101)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/datasets/dataset_base.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Dataset Classes.\n\nDataset abstraction/factory to allow us to easily use tensorflow datasets (TFDS)\nwith JAX/FLAX, by defining a bunch of wrappers, including preprocessing.\n\"\"\"\n\nimport abc\nfrom typing import MutableMapping, Optional\n\nimport tensorflow.compat.v2 as tf\nimport tensorflow_datasets as tfds\n\n\nclass Dataset(metaclass=abc.ABCMeta):\n  \"\"\"Base class for datasets.\n\n  Attributes:\n      DATAKEY: The key used for the data component of a Tensorflow Dataset\n        (TFDS) sample, e.g. 'image' for image datasets.\n      LABELKEY: The key used fot the label component of a Tensorflow Dataset\n        sample, i.e. 
'label'.\n      name: The TFDS name of the dataset.\n      batch_size: The batch size to use for the training dataset.\n      batch_size_test: The batch size to use for the test dataset.\n      num_classes: the number of supervised classes in the dataset.\n      shape: the shape of an input data array.\n  \"\"\"\n\n  DATAKEY: Optional[str] = None\n  LABELKEY: str = 'label'\n\n  def __init__(self,\n               name,\n               batch_size,\n               batch_size_test,\n               shuffle_buffer_size,\n               prefetch_size = 1,\n               seed = None):  # pytype: disable=annotation-type-mismatch\n    \"\"\"Base class for datasets.\n\n    Args:\n        name: The TFDS name of the dataset.\n        batch_size: The batch size to use for the training dataset.\n        batch_size_test: The batch size to use for the test dataset.\n        shuffle_buffer_size: The buffer size to use for dataset shuffling.\n        prefetch_size: The number of mini-batches to prefetch.\n        seed: The random seed used to shuffle.\n\n    Returns:\n        A Dataset object.\n    \"\"\"\n    super().__init__()\n    self.name = name\n    self.batch_size = batch_size\n    self.batch_size_test = batch_size_test\n    self._shuffle_buffer_size = shuffle_buffer_size\n    self._prefetch_size = prefetch_size\n\n    self._train_ds, self._train_info = tfds.load(\n        self.name,\n        split=tfds.Split.TRAIN,\n        data_dir=self._dataset_dir(),\n        with_info=True)\n    self._train_ds = self._train_ds.shuffle(\n        self._shuffle_buffer_size,\n        seed).map(self.preprocess).cache().map(self.augment).batch(\n            self.batch_size, drop_remainder=True).prefetch(self._prefetch_size)\n\n    self._test_ds, self._test_info = tfds.load(\n        self.name,\n        split=tfds.Split.TEST,\n        data_dir=self._dataset_dir(),\n        with_info=True)\n    self._test_ds = self._test_ds.map(self.preprocess).cache().batch(\n        
self.batch_size_test).prefetch(self._prefetch_size)\n\n    self.num_classes = self._train_info.features['label'].num_classes\n    self.shape = self._train_info.features['image'].shape\n\n  def _dataset_dir(self):\n    \"\"\"Returns the dataset path for the TFDS data.\"\"\"\n    return None\n\n  def get_train(self):\n    \"\"\"Returns the training dataset.\"\"\"\n    return iter(tfds.as_numpy(self._train_ds))\n\n  def get_train_len(self):\n    \"\"\"Returns the length of the training dataset.\"\"\"\n    return self._train_info.splits['train'].num_examples\n\n  def get_test(self):\n    \"\"\"Returns the test dataset.\"\"\"\n    return iter(tfds.as_numpy(self._test_ds))\n\n  def get_test_len(self):\n    \"\"\"Returns the length of the test dataset.\"\"\"\n    return self._test_info.splits['test'].num_examples\n\n  def preprocess(\n      self, data):\n    \"\"\"Preprocessing fn used by TFDS map for normalization.\n\n    This function is for transformations that can be cached, e.g.\n    normalization/whitening.\n\n    Args:\n      data: Data sample.\n\n    Returns:\n    Data after being normalized/transformed.\n    \"\"\"\n    return data\n\n  def augment(\n      self, data):\n    \"\"\"Preprocessing fn used by TFDS map for augmentation at training time.\n\n    This function is for transformations that should not be cached, e.g. 
random\n    augmentation that should change for every sample, and are only applied at\n    training time.\n\n    Args:\n      data: Data sample.\n\n    Returns:\n    Data after being augmented/transformed.\n    \"\"\"\n    return data\n\n\nclass ImageDataset(Dataset):\n  \"\"\"Base class for image datasets.\"\"\"\n\n  DATAKEY = 'image'\n\n  def preprocess(\n      self, data):\n    \"\"\"Preprocessing function used by TFDS map for normalization.\n\n    This function is for transformations that can be cached, e.g.\n    normalization/whitening.\n\n    Args:\n      data: Data sample.\n\n    Returns:\n      Data after being normalized/transformed.\n    \"\"\"\n    data = super().preprocess(data)\n    # Ensure we only provide the image and label, stripping out other keys.\n    return dict((key, val)\n                for key, val in data.items()\n                if key in [self.LABELKEY, self.DATAKEY])\n"
  },
  {
    "path": "rigl/experimental/jax/datasets/dataset_base_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.datasets.dataset_base.\"\"\"\nfrom absl.testing import absltest\n\nfrom rigl.experimental.jax.datasets import dataset_base\n\n\nclass DummyDataset(dataset_base.ImageDataset):\n  \"\"\"A dummy implementation of the abstract dataset class.\n\n  Attributes:\n      NAME: The Tensorflow Dataset's dataset name.\n  \"\"\"\n  NAME: str = 'mnist'\n\n  def __init__(self,\n               batch_size,\n               batch_size_test,\n               shuffle_buffer_size = 1024,\n               seed = 42):\n    \"\"\"Dummy MNIST dataset.\n\n    Args:\n        batch_size: The batch size to use for the training datasets.\n        batch_size_test: The batch size to used for the test dataset.\n        shuffle_buffer_size: The buffer size to use for dataset shuffling.\n        seed: The random seed used to shuffle.\n\n    Returns:\n        Dataset: A dataset object.\n    \"\"\"\n    super().__init__(DummyDataset.NAME, batch_size, batch_size_test,\n                     shuffle_buffer_size, seed)\n\n\nclass DummyDatasetTest(absltest.TestCase):\n  \"\"\"Test cases for dummy dataset.\"\"\"\n\n  def setUp(self):\n    \"\"\"Common setup routines/variables for test cases.\"\"\"\n    super().setUp()\n    self._batch_size = 16\n    self._batch_size_test = 10\n    self._shuffle_buffer_size = 8\n    self._dataset = DummyDataset(\n        
self._batch_size,\n        batch_size_test=self._batch_size_test,\n        shuffle_buffer_size=self._shuffle_buffer_size)\n\n  def test_create_dataset(self):\n    \"\"\"Tests creation of dataset.\"\"\"\n    self.assertIsInstance(self._dataset, DummyDataset)\n\n  def test_train_image_dims_content(self):\n    \"\"\"Tests dimensions and contents of training data.\"\"\"\n    iterator = iter(self._dataset.get_train())\n    sample = next(iterator)\n    image, label = sample['image'], sample['label']\n\n    with self.subTest(name='data_shape'):\n      self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1))\n\n    with self.subTest(name='data_values'):\n      self.assertTrue(((image >= 0) & (image <= 255)).all())\n\n    with self.subTest(name='label_shape'):\n      self.assertLen(label, self._batch_size)\n\n    with self.subTest(name='label_values'):\n      self.assertTrue(\n          ((label >= 0) & (label < self._dataset.num_classes)).all())\n\n  def test_test_image_dims_content(self):\n    \"\"\"Tests dimensions and contents of test data.\"\"\"\n    iterator = iter(self._dataset.get_test())\n    sample = next(iterator)\n    image, label = sample['image'], sample['label']\n\n    with self.subTest(name='data_shape'):\n      self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1))\n\n    with self.subTest(name='data_values'):\n      self.assertTrue(((image >= 0) & (image <= 255)).all())\n\n    with self.subTest(name='label_shape'):\n      self.assertLen(label, self._batch_size_test)\n\n    with self.subTest(name='label_values'):\n      self.assertTrue(\n          ((label >= 0) & (label < self._dataset.num_classes)).all())\n\n  def test_train_data_length(self):\n    \"\"\"Tests length of training dataset.\"\"\"\n    total_count = 0\n    for batch in self._dataset.get_train():\n      total_count += len(batch['label'])\n\n    self.assertEqual(total_count, self._dataset.get_train_len())\n\n  def test_test_data_length(self):\n    \"\"\"Tests length of test dataset.\"\"\"\n    total_count = 0\n    for batch in self._dataset.get_test():\n  
    total_count += len(batch['label'])\n\n    self.assertEqual(total_count, self._dataset.get_test_len())\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/datasets/dataset_factory.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Dataset Factory.\n\nDataset factory to allow us to easily use TensorFlow Datasets (TFDS)\nwith JAX/FLAX, by defining a set of wrappers, including preprocessing.\n\nAttributes:\n  DATASETS: A mapping of dataset names to the dataset classes that can be\n    created.\n\"\"\"\n\nfrom typing import Mapping, Type\nfrom rigl.experimental.jax.datasets import cifar10\nfrom rigl.experimental.jax.datasets import dataset_base\nfrom rigl.experimental.jax.datasets import mnist\n\n\nDATASETS: Mapping[str, Type[dataset_base.Dataset]] = {\n    'MNIST': mnist.MNISTDataset,\n    'CIFAR10': cifar10.CIFAR10Dataset,\n}\n\n\ndef create_dataset(name: str, *args, **kwargs) -> dataset_base.Dataset:\n  \"\"\"Creates a TensorFlow Datasets (TFDS) dataset.\n\n  Args:\n      name: The TFDS name of the dataset.\n      *args: Dataset arguments.\n      **kwargs: Dataset keyword arguments.\n\n  Returns:\n      Dataset: An abstracted dataset object.\n\n  Raises:\n      ValueError: If a dataset with the given name does not exist.\n  \"\"\"\n  if name not in DATASETS:\n    raise ValueError(f'No such dataset: {name}')\n  return DATASETS[name](*args, **kwargs)\n"
  },
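The factory above is a small registry pattern: a module-level mapping from dataset names to dataset classes, with `create_dataset` failing fast on unknown names. A minimal, self-contained sketch of the same pattern (the `Dataset`/`MNISTDataset` stand-ins here are hypothetical simplifications, not the repo's actual classes):

```python
# Sketch of the registry-based factory used by dataset_factory.py.
# The Dataset classes below are illustrative stand-ins.
from typing import Mapping, Type


class Dataset:
    """Base class standing in for dataset_base.Dataset."""


class MNISTDataset(Dataset):
    def __init__(self, batch_size, batch_size_test, shuffle_buffer_size=1024):
        self.batch_size = batch_size
        self.batch_size_test = batch_size_test
        self.shuffle_buffer_size = shuffle_buffer_size


# Name -> class registry; adding a dataset is a one-line change here.
DATASETS: Mapping[str, Type[Dataset]] = {'MNIST': MNISTDataset}


def create_dataset(name: str, *args, **kwargs) -> Dataset:
    # Unknown names fail fast, mirroring the ValueError in the real factory.
    if name not in DATASETS:
        raise ValueError(f'No such dataset: {name}')
    return DATASETS[name](*args, **kwargs)
```

The registry keeps call sites decoupled from concrete classes: tests can iterate `DATASETS` to exercise every supported dataset, as `dataset_factory_test.py` does.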
  {
    "path": "rigl/experimental/jax/datasets/dataset_factory_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.datasets.dataset_factory.\"\"\"\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport numpy as np\nfrom rigl.experimental.jax.datasets import dataset_base\nfrom rigl.experimental.jax.datasets import dataset_factory\n\n\nclass DatasetCommonTest(parameterized.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    self._batch_size = 32\n    self._batch_size_test = 10\n    self._shuffle_buffer_size = 128\n\n  def _create_dataset(self, dataset_name):\n    \"\"\"Helper function for creating a dataset.\"\"\"\n    return dataset_factory.create_dataset(\n        dataset_name,\n        self._batch_size,\n        self._batch_size_test,\n        shuffle_buffer_size=self._shuffle_buffer_size)\n\n  def test_dataset_supported(self):\n    \"\"\"Tests supported datasets.\"\"\"\n    for dataset_name in dataset_factory.DATASETS:\n      dataset = self._create_dataset(dataset_name)\n\n      self.assertIsInstance(dataset, dataset_base.Dataset)\n\n  @parameterized.parameters(*dataset_factory.DATASETS.keys())\n  def test_dataset_train_iterators(self, dataset_name):\n    \"\"\"Tests dataset's train iterator.\"\"\"\n    dataset = self._create_dataset(dataset_name)\n    sample = next(dataset.get_train())\n\n    with self.subTest(name='{}_sample'.format(dataset_name)):\n      self.assertNotEmpty(sample)\n\n    
with self.subTest(name='{}_label_type'.format(dataset_name)):\n      self.assertIsInstance(sample['label'], np.ndarray)\n\n    with self.subTest(name='{}_label_batch_size'.format(dataset_name)):\n      self.assertLen(sample['label'], self._batch_size)\n\n    with self.subTest(name='{}_image_type'.format(dataset_name)):\n      self.assertIsInstance(sample['image'], np.ndarray)\n\n    with self.subTest(name='{}_image_shape'.format(dataset_name)):\n      self.assertLen(sample['image'].shape, 4)\n\n    with self.subTest(name='{}_image_batch_size'.format(dataset_name)):\n      self.assertEqual(sample['image'].shape[0], self._batch_size)\n\n    with self.subTest(\n        name='{}_non_zero_image_dimensions'.format(dataset_name)):\n      self.assertGreater(sample['image'].shape[1], 1)\n\n  @parameterized.parameters(*dataset_factory.DATASETS.keys())\n  def test_dataset_test_iterators(self, dataset_name):\n    \"\"\"Tests dataset's test iterator.\"\"\"\n    dataset = self._create_dataset(dataset_name)\n    sample = next(dataset.get_test())\n\n    with self.subTest(name='{}_sample'.format(dataset_name)):\n      self.assertNotEmpty(sample)\n\n    with self.subTest(name='{}_label_type'.format(dataset_name)):\n      self.assertIsInstance(sample['label'], np.ndarray)\n\n    with self.subTest(name='{}_label_batch_size'.format(dataset_name)):\n      self.assertLen(sample['label'], self._batch_size_test)\n\n    with self.subTest(name='{}_image_type'.format(dataset_name)):\n      self.assertIsInstance(sample['image'], np.ndarray)\n\n    with self.subTest(name='{}_image_shape'.format(dataset_name)):\n      self.assertLen(sample['image'].shape, 4)\n\n    with self.subTest(name='{}_image_batch_size'.format(dataset_name)):\n      self.assertEqual(sample['image'].shape[0], self._batch_size_test)\n\n    with self.subTest(\n        name='{}_non_zero_image_dimensions'.format(dataset_name)):\n      self.assertGreater(sample['image'].shape[1], 1)\n\n  def test_dataset_unsupported(self):\n    
\"\"\"Tests unsupported datasets.\"\"\"\n    with self.assertRaisesRegex(ValueError, 'No such dataset: unsupported'):\n      self._create_dataset('unsupported')\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/datasets/mnist.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"MNIST Dataset.\n\nDataset abstraction/factory to allow us to easily use TensorFlow Datasets (TFDS)\nwith JAX/FLAX, by defining a set of wrappers, including preprocessing.\nIn this case, the MNIST dataset.\n\"\"\"\nfrom rigl.experimental.jax.datasets import dataset_base\nimport tensorflow.compat.v2 as tf\n\n\nclass MNISTDataset(dataset_base.ImageDataset):\n  \"\"\"MNIST dataset.\n\n  Attributes:\n      NAME: The TFDS name of the dataset.\n  \"\"\"\n  NAME: str = 'mnist'\n\n  def __init__(self,\n               batch_size,\n               batch_size_test,\n               shuffle_buffer_size = 1024,\n               seed = 42):\n    \"\"\"MNIST dataset.\n\n    Args:\n        batch_size: The batch size to use for the training dataset.\n        batch_size_test: The batch size to use for the test dataset.\n        shuffle_buffer_size: The buffer size to use for dataset shuffling.\n        seed: The random seed used to shuffle.\n    \"\"\"\n    super().__init__(MNISTDataset.NAME, batch_size, batch_size_test,\n                     shuffle_buffer_size, seed)\n\n  def preprocess(self, data):\n    \"\"\"Normalizes MNIST images: `uint8` -> `float32`.\n\n    Args:\n      data: Data sample.\n\n    Returns:\n      Data after being 
augmented/normalized/transformed.\n    \"\"\"\n    data = super().preprocess(data)\n    data['image'] = (tf.cast(data['image'], tf.float32) / 255.) - 0.5\n    return data\n"
  },
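The `preprocess` override maps `uint8` pixels in [0, 255] to `float32` values in [-0.5, 0.5] by dividing by 255 and shifting. A NumPy sketch of the same normalization (TensorFlow-free for illustration; the repo's version applies `tf.cast` inside the input pipeline):

```python
# NumPy sketch of MNISTDataset.preprocess's image normalization:
# uint8 pixels in [0, 255] -> float32 values in [-0.5, 0.5].
import numpy as np


def normalize(image: np.ndarray) -> np.ndarray:
    # Mirrors (tf.cast(image, tf.float32) / 255.) - 0.5 from mnist.py.
    return image.astype(np.float32) / 255.0 - 0.5
```

Centering the inputs around zero is a common preprocessing choice that tends to play well with the Kaiming-style initializers used by the models in this repo.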
  {
    "path": "rigl/experimental/jax/datasets/mnist_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.datasets.mnist.\"\"\"\nfrom absl.testing import absltest\nimport numpy as np\n\nfrom rigl.experimental.jax.datasets import mnist\n\n\nclass MNISTDatasetTest(absltest.TestCase):\n  \"\"\"Test cases for MNIST Dataset.\"\"\"\n\n  def setUp(self):\n    \"\"\"Common setup routines/variables for test cases.\"\"\"\n    super().setUp()\n    self._batch_size = 16\n    self._batch_size_test = 10\n    self._shuffle_buffer_size = 8\n\n    self._dataset = mnist.MNISTDataset(\n        self._batch_size,\n        batch_size_test=self._batch_size_test,\n        shuffle_buffer_size=self._shuffle_buffer_size)\n\n  def test_create_dataset(self):\n    \"\"\"Tests creation of dataset.\"\"\"\n    self.assertIsInstance(self._dataset, mnist.MNISTDataset)\n\n  def test_train_image_dims_content(self):\n    \"\"\"Tests dimensions and contents of train data.\"\"\"\n    iterator = self._dataset.get_train()\n    sample = next(iterator)\n    image, label = sample['image'], sample['label']\n\n    with self.subTest(name='data_shape'):\n      self.assertTupleEqual(image.shape, (self._batch_size, 28, 28, 1))\n\n    with self.subTest(name='data_values'):\n      self.assertTrue((image >= -0.5).all() and (image <= 0.5).all())\n\n    with self.subTest(name='data_type'):\n      self.assertTrue(np.issubdtype(image.dtype, float))\n\n    with 
self.subTest(name='label_shape'):\n      self.assertLen(label, self._batch_size)\n\n    with self.subTest(name='label_type'):\n      self.assertTrue(np.issubdtype(label.dtype, int))\n\n    with self.subTest(name='label_values'):\n      self.assertTrue((label >= 0).all() and\n                      (label < self._dataset.num_classes).all())\n\n  def test_test_image_dims_content(self):\n    \"\"\"Tests dimensions and contents of test data.\"\"\"\n    iterator = self._dataset.get_test()\n    sample = next(iterator)\n    image, label = sample['image'], sample['label']\n\n    with self.subTest(name='data_shape'):\n      self.assertTupleEqual(image.shape, (self._batch_size_test, 28, 28, 1))\n\n    with self.subTest(name='data_type'):\n      self.assertTrue(np.issubdtype(image.dtype, float))\n\n    # TODO: Find a better approach to testing with JAX arrays.\n    with self.subTest(name='data_values'):\n      self.assertTrue((image >= -0.5).all() and (image <= 0.5).all())\n\n    with self.subTest(name='label_shape'):\n      self.assertLen(label, self._batch_size_test)\n\n    with self.subTest(name='label_type'):\n      self.assertTrue(np.issubdtype(label.dtype, int))\n\n    with self.subTest(name='label_values'):\n      self.assertTrue((label >= 0).all() and\n                      (label < self._dataset.num_classes).all())\n\n  def test_train_data_length(self):\n    \"\"\"Tests length of training dataset.\"\"\"\n    total_count = 0\n    for batch in self._dataset.get_train():\n      total_count += len(batch['label'])\n\n    self.assertEqual(total_count, self._dataset.get_train_len())\n\n  def test_test_data_length(self):\n    \"\"\"Tests length of test dataset.\"\"\"\n    total_count = 0\n    for batch in self._dataset.get_test():\n      total_count += len(batch['label'])\n\n    self.assertEqual(total_count, self._dataset.get_test_len())\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/fixed_param.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Weight Symmetry: Train models with fixed param, but diff. depth and width.\"\"\"\nimport ast\nimport functools\nimport operator\nfrom os import path\nfrom typing import List, Sequence\nimport uuid\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nimport flax\nfrom flax.metrics import tensorboard\nfrom flax.training import lr_schedule\nimport jax\nimport jax.numpy as jnp\nfrom rigl.experimental.jax.datasets import dataset_factory\nfrom rigl.experimental.jax.models import mnist_fc\nfrom rigl.experimental.jax.models import model_factory\nfrom rigl.experimental.jax.pruning import masked\nfrom rigl.experimental.jax.pruning import symmetry\nfrom rigl.experimental.jax.training import training\nfrom rigl.experimental.jax.utils import utils\n  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))\n\n  logging.info('Saving experimental results to %s', experiment_dir)\n\n  host_count = jax.host_count()\n  local_device_count = jax.local_device_count()\n  logging.info('Device count: %d, host count: %d, local device count: %d',\n               jax.device_count(), host_count, local_device_count)\n\n  if jax.host_id() == 0:\n    summary_writer = tensorboard.SummaryWriter(experiment_dir)\n\n  dataset = dataset_factory.create_dataset(\n      FLAGS.dataset,\n      FLAGS.batch_size,\n      FLAGS.batch_size_test,\n      
shuffle_buffer_size=FLAGS.shuffle_buffer_size)\n\n  logging.info('Training %s on the %s dataset...', MODEL, FLAGS.dataset)\n\n  rng = jax.random.PRNGKey(FLAGS.random_seed)\n\n  input_shape = (1,) + dataset.shape\n\n  input_len = functools.reduce(operator.mul, dataset.shape)\n\n  features = mnist_fc.feature_dim_for_param(\n      input_len,\n      FLAGS.param_count,\n      FLAGS.depth)\n\n  logging.info('Model Configuration: %s', str(features))\n\n  base_model, _ = model_factory.create_model(\n      MODEL,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes,\n      features=features)\n\n  model_param_count = utils.count_param(base_model, ('kernel',))\n\n  logging.info(\n      'Model Config: param.: %d, depth: %d. max width: %d, min width: %d',\n      model_param_count, len(features), max(features), min(features))\n\n  logging.info('Generating random mask based on model')\n\n  # Re-initialize the RNG to maintain same training pattern (as in prune code).\n  mask_rng = jax.random.PRNGKey(FLAGS.random_seed)\n  mask = masked.shuffled_mask(\n      base_model,\n      rng=mask_rng,\n      sparsity=FLAGS.mask_sparsity)\n\n  if jax.host_id() == 0:\n    mask_stats = symmetry.get_mask_stats(mask)\n    logging.info('Mask stats: %s', str(mask_stats))\n\n    for label, value in mask_stats.items():\n      # This is needed because permutations (long int) can't be cast to float32.\n      try:\n        summary_writer.scalar(f'mask/{label}', value, 0)\n      except (OverflowError, ValueError):\n        summary_writer.text(f'mask/{label}', str(value), 0)\n        logging.error('Could not write mask/%s to tensorflow summary as float32'\n                      ', writing as string instead.', label)\n\n    if FLAGS.dump_json:\n      mask_stats['permutations'] = str(mask_stats['permutations'])\n      utils.dump_dict_json(\n          mask_stats, path.join(experiment_dir, 'mask_stats.json'))\n\n    model_stats = {\n        'depth': len(features),\n        'max_width': max(features),\n        'min_width': min(features),\n    }\n    model_stats.update(\n        {'feature_{}'.format(i): value for i, value in enumerate(features)})\n\n    if FLAGS.dump_json:\n      utils.dump_dict_json(model_stats,\n                           path.join(experiment_dir, 'model_stats.json'))\n\n  model, initial_state = model_factory.create_model(\n      MODEL,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes,\n      features=features, masks=mask)\n\n  if FLAGS.opt == 'Adam':\n    optimizer = flax.optim.Adam(\n        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)\n  elif FLAGS.opt == 'Momentum':\n    optimizer = flax.optim.Momentum(\n        learning_rate=FLAGS.lr,\n        beta=FLAGS.momentum,\n        weight_decay=FLAGS.weight_decay,\n        nesterov=False)\n  else:\n    raise ValueError('Unknown Optimizer: {}'.format(FLAGS.opt))\n\n  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size\n\n  if FLAGS.lr_schedule == 'constant':\n    lr_fn = lr_schedule.create_constant_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch)\n  elif FLAGS.lr_schedule == 'stepped':\n    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)\n    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, lr_schedule_steps)\n  elif FLAGS.lr_schedule == 'cosine':\n    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, FLAGS.epochs)\n  else:\n    raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))\n\n  if jax.host_id() == 0:\n    trainer = training.Trainer(\n        optimizer,\n        model,\n        initial_state,\n        dataset,\n        rng,\n        
summary_writer=summary_writer,\n    )\n  else:\n    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)\n\n  _, best_metrics = trainer.train(\n      FLAGS.epochs,\n      lr_fn=lr_fn,\n      update_iter=FLAGS.update_iterations,\n      update_epoch=FLAGS.update_epoch,\n  )\n\n  logging.info('Best metrics: %s', str(best_metrics))\n\n  if jax.host_id() == 0:\n    if FLAGS.dump_json:\n      utils.dump_dict_json(best_metrics,\n                           path.join(experiment_dir, 'best_metrics.json'))\n\n    for label, value in best_metrics.items():\n      summary_writer.scalar('best/{}'.format(label), value,\n                            FLAGS.epochs * steps_per_epoch)\n    summary_writer.close()\n\n\ndef main(argv: List[str]):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n  run_training()\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
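The driver in `fixed_param.py` selects between constant, stepped, and cosine learning-rate schedules via `--lr_schedule`. As a rough sketch, a cosine schedule decays the base rate with a half-cosine over the whole run; this is the textbook formula, and `flax.training.lr_schedule.create_cosine_learning_rate_schedule` may differ in details such as warmup handling:

```python
# Sketch of a half-cosine learning-rate decay of the kind selected by
# --lr_schedule=cosine (standard formula; flax's exact variant may differ).
import math


def cosine_lr(base_lr: float, steps_per_epoch: int, total_epochs: int,
              step: int) -> float:
    # Fraction of training completed, clamped to [0, 1].
    progress = min(step / (steps_per_epoch * total_epochs), 1.0)
    # Starts at base_lr (cos 0 = 1) and decays smoothly to 0 (cos pi = -1).
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The smooth decay to zero is what makes cosine schedules a common default when the epoch budget is fixed up front, as it is here via `FLAGS.epochs`.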
  {
    "path": "rigl/experimental/jax/fixed_param_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.fixed_param.\"\"\"\nimport glob\nfrom os import path\nimport tempfile\n\nfrom absl.testing import absltest\nfrom absl.testing import flagsaver\n\nfrom rigl.experimental.jax import fixed_param\n\n\nclass FixedParamTest(absltest.TestCase):\n\n  def test_run(self):\n    \"\"\"Tests if the driver for shuffled training runs correctly.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      fixed_param.main([])\n\n    with self.subTest(name='tf_summary_file_exists'):\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/models/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "rigl/experimental/jax/models/cifar10_cnn.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"CIFAR10 CNN.\n\nA small CNN for the CIFAR10 dataset, consists of a number of convolutional\nlayers (determined by length of filters parameter), followed by a\nfully-connected layer.\n\"\"\"\nfrom typing import Callable, Mapping, Optional, Sequence\n\nfrom absl import logging\nimport flax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import init\nfrom rigl.experimental.jax.pruning import masked\n\n\nclass CIFAR10CNN(flax.deprecated.nn.Module):\n  \"\"\"Small CIFAR10 CNN.\"\"\"\n\n  def apply(self,\n            inputs,\n            num_classes,\n            filter_shape = (3, 3),\n            filters = (32, 32, 64, 64, 128, 128),\n            init_fn=flax.deprecated.nn.initializers.kaiming_normal,\n            train=True,\n            activation_fn = flax.deprecated.nn.relu,\n            masks = None,\n            masked_layer_indices = None):\n    \"\"\"Applies a convolution to the inputs.\n\n    Args:\n      inputs: Input data with dimensions (batch, spatial_dims..., features).\n      num_classes: Number of classes in the dataset.\n      filter_shape: Shape of the convolutional filters.\n      filters: Number of filters in each convolutional layer, and number of conv\n        layers (given by length of sequence).\n      init_fn: Initialization function used for convolutional layers.\n      train: If model is being evaluated in 
training mode or not.\n      activation_fn: Activation function to be used for convolutional layers.\n      masks: Masks of the layers in this model, in the same form as\n         module params, or None.\n      masked_layer_indices: The layer indices of layers in model to be masked.\n\n    Returns:\n      A tensor of shape (batch, num_classes), containing the logit output.\n\n    Raises:\n      ValueError if the number of pooling layers is too many for the given input\n        size, or if the provided mask is not of the correct depth for the model.\n    \"\"\"\n    # Note: First dim is batch, last dim is channels, other dims are \"spatial\".\n    if not all([(dim >= 2**(len(filters)//2)) for dim in inputs.shape[1:-2]]):\n      raise ValueError(\n          'Input spatial size, {}, does not allow {} pooling layers.'.format(\n              str(inputs.shape[1:-2]), len(filters))\n          )\n\n    depth = 1 + len(filters)\n    masks = masked.generate_model_masks(depth, masks,\n                                        masked_layer_indices)\n\n    batch_norm = flax.deprecated.nn.BatchNorm.partial(\n        use_running_average=not train, momentum=0.99, epsilon=1e-5)\n\n    for i, filter_num in enumerate(filters):\n      if f'MaskedModule_{i}' in masks:\n        logging.info('Layer %d is masked in model', i)\n        mask = masks[f'MaskedModule_{i}']\n        inputs = masked.masked(flax.deprecated.nn.Conv, mask)(\n            inputs,\n            features=filter_num,\n            kernel_size=filter_shape,\n            kernel_init=init.sparse_init(\n                init_fn(), mask['kernel'] if mask is not None else None))\n      else:\n        inputs = flax.deprecated.nn.Conv(\n            inputs,\n            features=filter_num,\n            kernel_size=filter_shape,\n            kernel_init=init_fn())\n      inputs = batch_norm(inputs, name='bn_conv_{}'.format(i))\n      inputs = activation_fn(inputs)\n\n      if i % 2 == 1:\n        inputs = 
flax.deprecated.nn.max_pool(\n            inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID')\n\n    # Global average pooling if we have spatial dimensions left.\n    inputs = flax.deprecated.nn.avg_pool(\n        inputs, window_shape=(inputs.shape[1:-1]), padding='VALID')\n    inputs = inputs.reshape((inputs.shape[0], -1))\n\n    # This is effectively a Dense layer, but we cast it as a convolution layer\n    # to allow us to easily propagate masks, avoiding b/156135283.\n    inputs = flax.deprecated.nn.Conv(\n        inputs,\n        features=num_classes,\n        kernel_size=inputs.shape[1:-1],\n        kernel_init=flax.deprecated.nn.initializers.xavier_normal())\n    inputs = batch_norm(inputs, name='bn_dense_1')\n    inputs = jnp.squeeze(inputs)\n    return flax.deprecated.nn.log_softmax(inputs)\n"
  },
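The ValueError raised by `CIFAR10CNN.apply` guards against pooling the input below 1x1: with a stride-2 pool after every second conv layer, each spatial dimension must be at least 2 to the power of the number of pooling layers. A standalone sketch of that arithmetic (helper names here are illustrative, not from the repo):

```python
# Sketch of the pooling-capacity check in CIFAR10CNN.apply: a 2x2 stride-2
# max pool halves each spatial dimension, so an input can only survive
# num_pools poolings if every dimension is at least 2**num_pools.
def max_pool_layers(spatial_dim: int) -> int:
    # Number of times a dimension can be halved before dropping below 2.
    count = 0
    while spatial_dim >= 2:
        spatial_dim //= 2
        count += 1
    return count


def check_input(spatial_dims, num_conv_layers: int) -> None:
    num_pools = num_conv_layers // 2  # one pool after every second conv layer
    if not all(dim >= 2 ** num_pools for dim in spatial_dims):
        raise ValueError(
            f'Input spatial size {spatial_dims} does not allow '
            f'{num_pools} pooling layers.')
```

For a 32x32 CIFAR10 input and the default six conv layers (three pools), 32 >= 8 holds; the test's deliberately deep `filters=20 * (32,)` configuration trips the check.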
  {
    "path": "rigl/experimental/jax/models/cifar10_cnn_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.models.cifar10_cnn.\"\"\"\nfrom absl.testing import absltest\nimport flax\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.models import cifar10_cnn\n\n\nclass CIFAR10CNNTest(absltest.TestCase):\n  \"\"\"Tests the CIFAR10CNN model.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._num_classes = 10\n    self._batch_size = 2\n    self._input_shape = ((self._batch_size, 32, 32, 3), jnp.float32)\n    self._input = jnp.zeros(*self._input_shape)\n\n  def test_output_shapes(self):\n    \"\"\"Tests the output shapes of the model.\"\"\"\n    with flax.deprecated.nn.stateful() as initial_state:\n      _, initial_params = cifar10_cnn.CIFAR10CNN.init_by_shape(\n          self._rng, (self._input_shape,), num_classes=self._num_classes)\n      model = flax.deprecated.nn.Model(cifar10_cnn.CIFAR10CNN, initial_params)\n\n    with flax.deprecated.nn.stateful(initial_state, mutable=False):\n      logits = model(self._input, num_classes=self._num_classes, train=False)\n\n    self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes))\n\n  def test_invalid_spatial_dimensions(self):\n    \"\"\"Tests model with an invalid spatial dimension parameters.\"\"\"\n    with self.assertRaisesRegex(ValueError, 'Input spatial size, '):\n      
cifar10_cnn.CIFAR10CNN.init_by_shape(\n          self._rng, (self._input_shape,),\n          num_classes=self._num_classes,\n          filters=20 * (32,))\n\n  def test_invalid_masks_depth(self):\n    \"\"\"Tests model mask with the incorrect depth for the given model.\"\"\"\n    invalid_masks = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros((self._batch_size, 3, 3, 32))\n        }\n    }\n\n    with self.assertRaisesRegex(\n        ValueError, 'Mask is invalid for model.'):\n      cifar10_cnn.CIFAR10CNN.init_by_shape(\n          self._rng, (self._input_shape,),\n          num_classes=self._num_classes,\n          masks=invalid_masks)\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/models/mnist_cnn.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"MNIST CNN.\n\nA small CNN for the MNIST dataset, consisting of a number of convolutional\nlayers (determined by length of filters parameter), followed by a\nfully-connected layer.\n\"\"\"\nfrom typing import Callable, Mapping, Optional, Sequence\n\nfrom absl import logging\nimport flax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import init\nfrom rigl.experimental.jax.pruning import masked\n\n\nclass MNISTCNN(flax.deprecated.nn.Module):\n  \"\"\"Small MNIST CNN.\"\"\"\n\n  def apply(self,\n            inputs,\n            num_classes,\n            filter_shape = (5, 5),\n            filters = (16, 32),\n            dense_size = 64,\n            train=True,\n            init_fn = flax.deprecated.nn.initializers.kaiming_normal,\n            activation_fn = flax.deprecated.nn.relu,\n            masks = None,\n            masked_layer_indices = None):\n    \"\"\"Applies the CNN to the inputs.\n\n    Args:\n      inputs: Input data with dimensions (batch, spatial_dims..., features).\n      num_classes: Number of classes in the dataset.\n      filter_shape: Shape of the convolutional filters.\n      filters: Number of filters in each convolutional layer, and number of conv\n        layers (given by length of sequence).\n      dense_size: Number of units in the fully-connected layer, independent of\n        the number of conv layers (given 
by length of sequence).\n      train: If model is being evaluated in training mode or not.\n      init_fn: Initialization function used for convolutional layers.\n      activation_fn: Activation function to be used for convolutional layers.\n      masks: Masks of the layers in this model, in the same form as\n             module params, or None.\n      masked_layer_indices: The layer indices of layers in model to be masked.\n\n    Returns:\n      A tensor of shape (batch, num_classes), containing the logit output.\n    Raises:\n      ValueError if the number of pooling layers is too many for the given input\n        size.\n    \"\"\"\n    # Note: First dim is batch, last dim is channels, other dims are \"spatial\".\n    if not all([(dim >= 2**len(filters)) for dim in inputs.shape[1:-2]]):\n      raise ValueError(\n          'Input spatial size, {}, does not allow {} pooling layers.'.format(\n              str(inputs.shape[1:-2]), len(filters))\n          )\n\n    depth = 2 + len(filters)\n    masks = masked.generate_model_masks(depth, masks,\n                                        masked_layer_indices)\n\n    batch_norm = flax.deprecated.nn.BatchNorm.partial(\n        use_running_average=not train, momentum=0.99, epsilon=1e-5)\n\n    for i, filter_num in enumerate(filters):\n      if f'MaskedModule_{i}' in masks:\n        logging.info('Layer %d is masked in model', i)\n        mask = masks[f'MaskedModule_{i}']\n        inputs = masked.masked(flax.deprecated.nn.Conv, mask)(\n            inputs,\n            features=filter_num,\n            kernel_size=filter_shape,\n            kernel_init=init.sparse_init(\n                init_fn(), mask['kernel'] if mask is not None else None))\n      else:\n        inputs = flax.deprecated.nn.Conv(\n            inputs,\n            features=filter_num,\n            kernel_size=filter_shape,\n            kernel_init=init_fn())\n      inputs = batch_norm(inputs, name='bn_conv_{}'.format(i))\n      inputs = 
activation_fn(inputs)\n\n      if i < len(filters) - 1:\n        inputs = flax.deprecated.nn.max_pool(\n            inputs, window_shape=(2, 2), strides=(2, 2), padding='VALID')\n\n    # Global average pool at end of convolutional layers.\n    inputs = flax.deprecated.nn.avg_pool(\n        inputs, window_shape=inputs.shape[1:-1], padding='VALID')\n\n    # This is effectively a Dense layer, but we cast it as a convolution layer\n    # to allow us to easily propagate masks, avoiding b/156135283.\n    if f'MaskedModule_{depth - 2}' in masks:\n      mask_dense_1 = masks[f'MaskedModule_{depth - 2}']\n      inputs = masked.masked(flax.deprecated.nn.Conv, mask_dense_1)(\n          inputs,\n          features=dense_size,\n          kernel_size=inputs.shape[1:-1],\n          kernel_init=init.sparse_init(\n              init_fn(),\n              mask_dense_1['kernel'] if mask_dense_1 is not None else None))\n    else:\n      inputs = flax.deprecated.nn.Conv(\n          inputs,\n          features=dense_size,\n          kernel_size=inputs.shape[1:-1],\n          kernel_init=init_fn())\n    inputs = batch_norm(inputs, name='bn_dense_1')\n    inputs = activation_fn(inputs)\n\n    inputs = flax.deprecated.nn.Dense(\n        inputs,\n        features=num_classes,\n        kernel_init=flax.deprecated.nn.initializers.xavier_normal())\n    inputs = batch_norm(inputs, name='bn_dense_2')\n    inputs = jnp.squeeze(inputs)\n    return flax.deprecated.nn.log_softmax(inputs)\n"
  },
  {
    "path": "rigl/experimental/jax/models/mnist_cnn_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.models.mnist_cnn.\"\"\"\nfrom absl.testing import absltest\nimport flax\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.models import mnist_cnn\n\n\nclass MNISTCNNTest(absltest.TestCase):\n  \"\"\"Tests the MNISTCNN model.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._num_classes = 10\n    self._batch_size = 2\n    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)\n    self._input = jnp.zeros(*self._input_shape)\n\n  def test_output_shapes(self):\n    \"\"\"Tests the output shapes of the model.\"\"\"\n    with flax.deprecated.nn.stateful() as initial_state:\n      _, initial_params = mnist_cnn.MNISTCNN.init_by_shape(\n          self._rng, (self._input_shape,), num_classes=self._num_classes)\n      model = flax.deprecated.nn.Model(mnist_cnn.MNISTCNN, initial_params)\n\n    with flax.deprecated.nn.stateful(initial_state, mutable=False):\n      logits = model(self._input, num_classes=self._num_classes, train=False)\n\n    self.assertTupleEqual(logits.shape, (self._batch_size, self._num_classes))\n\n  def test_invalid_depth(self):\n    \"\"\"Tests model mask with the incorrect depth for the given model.\"\"\"\n    with self.assertRaisesRegex(ValueError, 'Input spatial size, '):\n      mnist_cnn.MNISTCNN.init_by_shape(\n          
self._rng, (self._input_shape,),\n          num_classes=self._num_classes,\n          filters=10 * (32,))\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/models/mnist_fc.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"MNIST Fully-Connected Neural Network.\n\nA fully-connected model for the MNIST dataset, consists of a number of\ndense layers (determined by length of features parameter).\n\"\"\"\nimport math\nfrom typing import Callable, Mapping, Optional, Sequence, Tuple\n\nfrom absl import logging\nimport flax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import init\nfrom rigl.experimental.jax.pruning import masked\n\n\ndef feature_dim_for_param(input_len,\n                          param_count,\n                          depth,\n                          depth_mult = 2.):\n  \"\"\"Calculates feature dimensions for a fixed parameter count and depth.\n\n  This is calculated for the specific case of a fully-connected neural\n  network, where each layer consists of l * a**i neurons, where a is a\n  multiplier for each layer.\n\n  Assume,\n    x is the input size,\n    a is the depth multiplier,\n    l is the initial layer width,\n    d is the depth.\n\n  The total number of parameters, n, is then given by,\n  $$n = x*l + l^2 * sum_{i=2}^d a^{2i-3})$$.\n\n  Args:\n    input_len: Input size.\n    param_count: Number of parameters model should maintain.\n    depth: Depth of the model.\n    depth_mult: The layer width multiplier w.r.t. 
depth.\n\n  Returns:\n    The feature specification for a fully-connected model, as a tuple of layer\n    widths.\n\n  Raises:\n    ValueError: If the given number of parameters is too low for the given\n    depth and input size.\n  \"\"\"\n  # Calculate the initial width for the first layer.\n  if depth == 1:\n    initial_width = param_count / input_len\n  else:\n    # l = ((x^2 + 4cn)^{1/2} - x)/(2c) where c = sum_{i=2}^d a^{2i-3}.\n    depth_sum = sum(depth_mult**(2 * i - 3) for i in range(2, depth + 1))\n    initial_width = (math.sqrt(input_len**2 + 4 * depth_sum * param_count) -\n                     input_len) / (2 * depth_sum)\n\n  if initial_width < 1:\n    raise ValueError(\n        'Expected parameter count too low for given depth and input size.')\n\n  return tuple(int(int(initial_width) * depth_mult**i) for i in range(depth))\n\n\nclass MNISTFC(flax.deprecated.nn.Module):\n  \"\"\"MNIST Fully-Connected Neural Network.\"\"\"\n\n  def apply(self,\n            inputs,\n            num_classes,\n            features = (32, 32),\n            train=True,\n            init_fn = flax.deprecated.nn.initializers.kaiming_normal,\n            activation_fn = flax.deprecated.nn.relu,\n            masks = None,\n            masked_layer_indices = None,\n            dropout_rate = 0.):\n    \"\"\"Applies fully-connected neural network to the inputs.\n\n    Args:\n      inputs: Input data with dimensions (batch, features), if features has more\n        than one dimension, it is flattened.\n      num_classes: Number of classes in the dataset.\n      features: Number of neurons in each layer, and number of layers (given by\n        length of sequence) + one layer for softmax.\n      train: If model is being evaluated in training mode or not.\n      init_fn: Initialization function used for dense layers.\n      activation_fn: Activation function to be used for dense layers.\n      masks: Masks of the layers in this model, in the same form as module\n        params, or 
None.\n      masked_layer_indices: The layer indices of layers in model to be masked.\n      dropout_rate: Dropout rate, if 0 then dropout is not used (default).\n\n    Returns:\n      A tensor of shape (batch, num_classes), containing the logit output.\n    \"\"\"\n    batch_norm = flax.deprecated.nn.BatchNorm.partial(\n        use_running_average=not train, momentum=0.99, epsilon=1e-5)\n\n    depth = 1 + len(features)\n    masks = masked.generate_model_masks(depth, masks,\n                                        masked_layer_indices)\n\n    # If inputs are in image dimensions, flatten image.\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    for i, feature_num in enumerate(features):\n      if f'MaskedModule_{i}' in masks:\n        logging.info('Layer %d is masked in model', i)\n        mask = masks[f'MaskedModule_{i}']\n        inputs = masked.masked(flax.deprecated.nn.Dense, mask)(\n            inputs,\n            features=feature_num,\n            kernel_init=init.sparse_init(\n                init_fn(), mask['kernel'] if mask is not None else None))\n      else:\n        inputs = flax.deprecated.nn.Dense(\n            inputs, features=feature_num, kernel_init=init_fn())\n      inputs = batch_norm(inputs, name=f'bn_conv_{i}')\n      inputs = activation_fn(inputs)\n      if dropout_rate > 0.0:\n        inputs = flax.deprecated.nn.dropout(\n            inputs, dropout_rate, deterministic=not train)\n\n    inputs = flax.deprecated.nn.Dense(\n        inputs,\n        features=num_classes,\n        kernel_init=flax.deprecated.nn.initializers.xavier_normal())\n\n    return flax.deprecated.nn.log_softmax(inputs)\n"
  },
  {
    "path": "rigl/experimental/jax/models/mnist_fc_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.models.mnist_fc.\"\"\"\nfrom typing import Sequence\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport flax\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.models import mnist_fc\nfrom rigl.experimental.jax.utils import utils\n\nPARAM_COUNT_PARAM: Sequence[str] = ('kernel',)\n\n\nclass MNISTFCTest(parameterized.TestCase):\n  \"\"\"Tests the MNISTFC model.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._num_classes = 10\n    self._batch_size = 2\n    self._input_len = 28*28*1\n    self._input_shape = ((self._batch_size, self._input_len), jnp.float32)\n    self._input = jnp.zeros((self._batch_size, self._input_len), jnp.float32)\n    self._param_count = 1e7\n\n  def test_output_shapes(self):\n    \"\"\"Tests the output shape from the model.\"\"\"\n    with flax.deprecated.nn.stateful() as initial_state:\n      _, initial_params = mnist_fc.MNISTFC.init_by_shape(\n          self._rng, (self._input_shape,), num_classes=self._num_classes)\n      model = flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params)\n\n    with flax.deprecated.nn.stateful(initial_state, mutable=False):\n      logits = model(self._input, num_classes=self._num_classes, train=False)\n\n    self.assertTupleEqual(logits.shape, (self._batch_size, 
self._num_classes))\n\n  def test_invalid_masks_depth(self):\n    \"\"\"Tests a model with an invalid mask.\"\"\"\n    invalid_masks = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros((self._batch_size, 5 * 5 * 16))\n        }\n    }\n\n    with self.assertRaisesRegex(\n        ValueError, 'Mask is invalid for model.'):\n      mnist_fc.MNISTFC.init_by_shape(\n          self._rng,\n          (self._input_shape,),\n          num_classes=self._num_classes,\n          masks=invalid_masks)\n\n  def _create_model(self, features):\n    \"\"\"Convenience function to create a FLAX model.\"\"\"\n    _, initial_params = mnist_fc.MNISTFC.init_by_shape(\n        self._rng,\n        (self._input_shape,),\n        num_classes=self._num_classes,\n        features=features)\n    return flax.deprecated.nn.Model(mnist_fc.MNISTFC, initial_params)\n\n  @parameterized.parameters(*range(1, 6))\n  def test_feature_dim_for_param_depth(self, depth):\n    \"\"\"Tests feature_dim_for_param with multiple depths.\"\"\"\n    features = mnist_fc.feature_dim_for_param(self._input_len,\n                                              self._param_count, depth)\n    model = self._create_model(features)\n    total_size = utils.count_param(model, PARAM_COUNT_PARAM)\n\n    with self.subTest(name='FeatureDimLen'):\n      self.assertLen(features, depth)\n\n    with self.subTest(name='FeatureDimParamCount'):\n      self.assertBetween(total_size, self._param_count * 0.95,\n                         self._param_count * 1.05)\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/models/model_factory.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Factory for neural network models.\n\nAttributes:\n  MODELS: A list of the models that can be created.\n\"\"\"\nfrom typing import Any, Callable, Mapping, Sequence, Tuple, Type\n\nimport flax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.models import cifar10_cnn\nfrom rigl.experimental.jax.models import mnist_cnn\nfrom rigl.experimental.jax.models import mnist_fc\n\nMODELS: Mapping[str, Type[flax.deprecated.nn.Model]] = {\n    'MNIST_CNN': mnist_cnn.MNISTCNN,\n    'MNIST_FC': mnist_fc.MNISTFC,\n    'CIFAR10_CNN': cifar10_cnn.CIFAR10CNN,\n}\n\n\ndef create_model(\n    name, rng,\n    input_specs, **kwargs\n):\n  \"\"\"Creates a Model.\n\n  Args:\n      name: the name of the model to instantiate.\n      rng : the random number generator to use for init.\n      input_specs: an iterable of (shape, dtype) pairs specifying the inputs.\n      **kwargs: list of model specific keyword arguments.\n\n  Returns:\n      A tuple of FLAX model (flax.deprecated.nn.Model), and initial model state.\n\n  Raises:\n      ValueError if a model with the given name does not exist.\n  \"\"\"\n  if name not in MODELS:\n    raise ValueError('No such model: {}'.format(name))\n\n  with flax.deprecated.nn.stateful() as init_state:\n    with flax.deprecated.nn.stochastic(rng):\n      model_class = MODELS[name].partial(**kwargs)\n      _, params = 
model_class.init_by_shape(rng, input_specs)\n\n  return flax.deprecated.nn.Model(model_class, params), init_state\n\n\ndef update_model(model,\n                 **kwargs):\n  \"\"\"Updates a model to use different model arguments, but same parameters.\n\n  Args:\n      model: The model to update.\n      **kwargs: List of model specific keyword arguments.\n\n  Returns:\n      A FLAX model.\n  \"\"\"\n  return flax.deprecated.nn.Model(model.module.partial(**kwargs), model.params)\n"
  },
  {
    "path": "rigl/experimental/jax/models/model_factory_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.models.model_factory.\"\"\"\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport flax\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.models import model_factory\n\n\nclass ModelCommonTest(parameterized.TestCase):\n  \"\"\"Tests the model factory.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._input_shape = ((1, 28, 28, 1), jnp.float32)\n    self._num_classes = 10\n\n  def _create_model(self, model_name):\n    return model_factory.create_model(\n        model_name,\n        self._rng, (self._input_shape,),\n        num_classes=self._num_classes)\n\n  @parameterized.parameters(*model_factory.MODELS.keys())\n  def test_model_supported(self, model_name):\n    \"\"\"Tests supported models.\"\"\"\n    model, state = self._create_model(model_name)\n\n    with self.subTest(name='test_model_supported_model_instance'):\n      self.assertIsInstance(model, flax.deprecated.nn.Model)\n\n    with self.subTest(name='test_model_supported_collection_instance'):\n      self.assertIsInstance(state, flax.deprecated.nn.Collection)\n\n  def test_model_unsupported(self):\n    \"\"\"Tests unsupported models.\"\"\"\n    with self.assertRaisesRegex(ValueError, 'No such model: unsupported'):\n      self._create_model('unsupported')\n\n\nif 
__name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/prune.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Weight Symmetry: Iteratively Prune Model during Training.\n\nCommand for training and pruning an MNIST fully-connected model for 10 epochs\nwith a fixed pruning rate of 0.95:\n\nprune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10\n--pruning_rate=0.95\n\nCommand for training and pruning an MNIST fully-connected model for 10\nepochs, with pruning rates 0.3, 0.6 and 0.95 at epochs 2, 5, and 8 respectively\nfor all layers:\n\nprune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10\n--pruning_schedule='[(2, 0.3), (5, 0.6), (8, 0.95)]'\n\nCommand for doing the same, but performing pruning only on the second layer:\n\nprune --xm_runlocal --dataset=MNIST --model=MNIST_FC --epochs=10\n--pruning_schedule=\"{'1': [(2, 0.3), (5, 0.6), (8, 0.95)]}\"\n\"\"\"\nimport ast\nfrom collections import abc\nimport functools\nfrom os import path\nfrom typing import List\nimport uuid\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nimport flax\nfrom flax.metrics import tensorboard\nfrom flax.training import lr_schedule\nimport jax\nimport jax.numpy as jnp\nfrom rigl.experimental.jax.datasets import dataset_factory\nfrom rigl.experimental.jax.models import model_factory\nfrom rigl.experimental.jax.training import training\nfrom rigl.experimental.jax.utils import utils\n\n  experiment_dir = path.join(FLAGS.experiment_dir, 
str(work_unit_id))\n\n  logging.info('Saving experimental results to %s', experiment_dir)\n\n  host_count = jax.host_count()\n  local_device_count = jax.local_device_count()\n  logging.info('Device count: %d, host count: %d, local device count: %d',\n               jax.device_count(), host_count, local_device_count)\n\n  if jax.host_id() == 0:\n    summary_writer = tensorboard.SummaryWriter(experiment_dir)\n\n  dataset = dataset_factory.create_dataset(\n      FLAGS.dataset,\n      FLAGS.batch_size,\n      FLAGS.batch_size_test,\n      shuffle_buffer_size=FLAGS.shuffle_buffer_size)\n\n  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)\n\n  rng = jax.random.PRNGKey(FLAGS.random_seed)\n\n  input_shape = (1,) + dataset.shape\n  base_model, _ = model_factory.create_model(\n      FLAGS.model,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes)\n\n  initial_model, initial_state = model_factory.create_model(\n      FLAGS.model,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes,\n      masked_layer_indices=FLAGS.masked_layer_indices)\n\n  if FLAGS.optimizer == 'Adam':\n    optimizer = flax.optim.Adam(\n        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)\n  elif FLAGS.optimizer == 'Momentum':\n    optimizer = flax.optim.Momentum(\n        learning_rate=FLAGS.lr,\n        beta=FLAGS.momentum,\n        weight_decay=FLAGS.weight_decay,\n        nesterov=False)\n\n  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size\n\n  if FLAGS.lr_schedule == LR_SCHEDULE_CONSTANT:\n    lr_fn = lr_schedule.create_constant_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch)\n  elif FLAGS.lr_schedule == LR_SCHEDULE_STEPPED:\n    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)\n    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, lr_schedule_steps)\n  elif FLAGS.lr_schedule == LR_SCHEDULE_COSINE:\n    lr_fn = 
lr_schedule.create_cosine_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, FLAGS.epochs)\n  else:\n    raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}')\n\n  # Reuses the FLAX learning rate schedule framework for pruning rate schedule.\n  pruning_fn_p = functools.partial(\n      lr_schedule.create_stepped_learning_rate_schedule, FLAGS.pruning_rate,\n      steps_per_epoch)\n  if FLAGS.pruning_schedule:\n    pruning_schedule = ast.literal_eval(FLAGS.pruning_schedule)\n    if isinstance(pruning_schedule, abc.Mapping):\n      pruning_rate_fn = {\n          f'MaskedModule_{layer_num}': pruning_fn_p(schedule)\n          for layer_num, schedule in pruning_schedule.items()\n      }\n    else:\n      pruning_rate_fn = pruning_fn_p(pruning_schedule)\n  else:\n    pruning_rate_fn = lr_schedule.create_constant_learning_rate_schedule(\n        FLAGS.pruning_rate, steps_per_epoch)\n\n  if jax.host_id() == 0:\n    trainer = training.Trainer(\n        optimizer,\n        initial_model,\n        initial_state,\n        dataset,\n        rng,\n        summary_writer=summary_writer,\n    )\n  else:\n    trainer = training.Trainer(\n        optimizer, initial_model, initial_state, dataset, rng)\n\n  _, best_metrics = trainer.train(\n      FLAGS.epochs,\n      lr_fn=lr_fn,\n      pruning_rate_fn=pruning_rate_fn,\n      update_iter=FLAGS.update_iterations,\n      update_epoch=FLAGS.update_epoch,\n  )\n\n  logging.info('Best metrics: %s', str(best_metrics))\n\n  if jax.host_id() == 0:\n    if FLAGS.dump_json:\n      utils.dump_dict_json(best_metrics,\n                           path.join(experiment_dir, 'best_metrics.json'))\n\n    for label, value in best_metrics.items():\n      summary_writer.scalar(f'best/{label}', value,\n                            FLAGS.epochs * steps_per_epoch)\n    summary_writer.close()\n\n\ndef main(argv: List[str]):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n  run_training()\n\nif __name__ == 
'__main__':\n  app.run(main)\n"
  },
  {
    "path": "rigl/experimental/jax/prune_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.prune.\"\"\"\nimport glob\nfrom os import path\n\nfrom absl.testing import absltest\nfrom absl.testing import flagsaver\n\nfrom rigl.experimental.jax import prune\n\n\nclass PruneTest(absltest.TestCase):\n\n  def test_prune_fixed_schedule(self):\n    \"\"\"Tests training/pruning driver with a fixed global sparsity.\"\"\"\n    experiment_dir = self.create_tempdir().full_path\n    eval_flags = dict(\n        epochs=1,\n        pruning_rate=0.95,\n        experiment_dir=experiment_dir,\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      prune.main([])\n\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_prune_global_pruning_schedule(self):\n    \"\"\"Tests training/pruning driver with a global sparsity schedule.\"\"\"\n    experiment_dir = self.create_tempdir().full_path\n    eval_flags = dict(\n        epochs=10,\n        pruning_schedule='[(5, 0.33), (7, 0.66), (9, 0.95)]',\n        experiment_dir=experiment_dir,\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      prune.main([])\n\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def 
test_prune_local_pruning_schedule(self):\n    \"\"\"Tests training/pruning driver with a single layer sparsity schedule.\"\"\"\n    experiment_dir = self.create_tempdir().full_path\n    eval_flags = dict(\n        epochs=10,\n        pruning_schedule='{1:[(5, 0.33), (7, 0.66), (9, 0.95)]}',\n        experiment_dir=experiment_dir,\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      prune.main([])\n\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/init.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tools for initialization of masked models.\"\"\"\nimport functools\nfrom typing import Callable, Sequence, Optional\n\nimport flax\nimport jax\nimport jax.numpy as jnp\n\n\ndef sparse_init(\n    base_init,\n    mask,\n    dtype=jnp.float32):\n  \"\"\"Weight initializer with correct fan in/fan out for a masked model.\n\n  The weight initializer uses any dense initializer to correctly initialize a\n  masked weight matrix by calling the given initialization method with the\n  correct fan in/fan out for every neuron in the layer. 
If the mask is None, it\n  reverts to the original initialization method.\n\n  Args:\n    base_init: The base (dense) initialization method to use.\n    mask: The layer's mask, or None.\n    dtype: The weight array jnp.dtype.\n\n  Returns:\n    An initialization method that is mask aware for the given layer and mask.\n  \"\"\"\n  def init(rng, shape, dtype=dtype):\n    if mask is None:\n      return base_init(rng, shape, dtype)\n\n    # Find the ablated neurons in the mask, to determine correct fan_out.\n    neuron_weight_count = jnp.sum(\n        jnp.reshape(mask, (-1, mask.shape[-1])), axis=0)\n    non_zero_neurons = jnp.sum(neuron_weight_count != 0)\n\n    # Special case of completely ablated weight matrix/layer.\n    if jnp.sum(non_zero_neurons) == 0:\n      print('Empty weight mask!')\n      return jnp.zeros(shape, dtype)\n\n    # Neurons have different fan_in w/mask, build up initialization per-unit.\n    init_cols = []\n    rng, *split_rngs = jax.random.split(rng, mask.shape[-1] + 1)\n    for i in range(mask.shape[-1]):\n      # Special case of ablated neuron.\n      if neuron_weight_count[i] == 0:\n        init_cols.append(jnp.zeros(shape[:-1] + (1,), dtype))\n        continue\n\n      # Fake shape of weight matrix with correct fan_in, and fan_out.\n      sparse_shape = (int(neuron_weight_count[i]), int(non_zero_neurons))\n\n      # Use only the first column of init from initializer, since faked fan_out.\n      init = base_init(split_rngs[i], sparse_shape, dtype)[Ellipsis, 0]\n\n      # Expand out to full sparse array.\n      expanded_init = jnp.zeros(\n          mask[Ellipsis, i].shape,\n          dtype).flatten().at[jnp.where(mask[Ellipsis, i].flatten() == 1)].set(init)\n      expanded_init = jnp.reshape(expanded_init, mask[Ellipsis, i].shape)\n      init_cols.append(expanded_init[Ellipsis, jnp.newaxis])\n\n    return jnp.concatenate(init_cols, axis=-1)\n\n  return init\n\n\nxavier_sparse_normal = glorot_sparse_normal = functools.partial(\n    
sparse_init, flax.deprecated.nn.initializers.xavier_normal())\nkaiming_sparse_normal = he_sparse_normal = functools.partial(\n    sparse_init, flax.deprecated.nn.initializers.kaiming_normal())\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/init_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.pruning.init.\"\"\"\nfrom typing import Any, Mapping, Optional\n\nfrom absl.testing import absltest\nimport flax\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import init\nfrom rigl.experimental.jax.pruning import masked\n\n\nclass MaskedDense(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    layer_mask = mask['MaskedModule_0'] if mask else None\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=layer_mask,\n        kernel_init=flax.deprecated.nn.initializers.kaiming_normal())\n\n\nclass MaskedDenseSparseInit(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            *args,\n            mask = None,\n            **kwargs):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    layer_mask = mask['MaskedModule_0'] if mask else None\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=layer_mask,\n        
kernel_init=init.kaiming_sparse_normal(\n            layer_mask['kernel'] if layer_mask is not None else None),\n        **kwargs)\n\n\nclass MaskedCNN(flax.deprecated.nn.Module):\n  \"\"\"Single-layer CNN Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            mask = None):\n\n    layer_mask = mask['MaskedModule_0'] if mask else None\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Conv,\n        kernel_size=(3, 3),\n        mask=layer_mask,\n        kernel_init=flax.deprecated.nn.initializers.kaiming_normal())\n\n\nclass MaskedCNNSparseInit(flax.deprecated.nn.Module):\n  \"\"\"Single-layer CNN Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            *args,\n            mask = None,\n            **kwargs):\n\n    layer_mask = mask['MaskedModule_0'] if mask else None\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Conv,\n        kernel_size=(3, 3),\n        mask=layer_mask,\n        kernel_init=init.kaiming_sparse_normal(\n            layer_mask['kernel'] if layer_mask is not None else None),\n        **kwargs)\n\n\nclass InitTest(absltest.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._batch_size = 2\n    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)\n    self._input = jnp.ones(*self._input_shape)\n\n  def test_init_kaiming_sparse_normal_output(self):\n    \"\"\"Tests the output shape/type of kaiming normal sparse initialization.\"\"\"\n    input_array = jnp.ones((64, 16), jnp.float32)\n    mask = jax.random.bernoulli(self._rng, shape=(64, 16))\n\n    base_init = flax.deprecated.nn.initializers.kaiming_normal()(\n        self._rng, input_array.shape, input_array.dtype)\n    sparse_init = init.kaiming_sparse_normal(mask)(self._rng, 
input_array.shape,\n                                                   input_array.dtype)\n\n    with self.subTest(name='test_sparse_init_output_shape'):\n      self.assertSequenceEqual(sparse_init.shape, base_init.shape)\n\n    with self.subTest(name='test_sparse_init_output_dtype'):\n      self.assertEqual(sparse_init.dtype, base_init.dtype)\n\n    with self.subTest(name='test_sparse_init_output_notallzero'):\n      self.assertTrue((sparse_init != 0).any())\n\n  def test_dense_no_mask(self):\n    \"\"\"Checks that in the special case of no mask, init is same as base_init.\"\"\"\n    _, initial_params = MaskedDense.init_by_shape(self._rng,\n                                                  (self._input_shape,))\n    self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)\n\n    _, initial_params = MaskedDenseSparseInit.init_by_shape(\n        jax.random.PRNGKey(42), (self._input_shape,), mask=None)\n    self._masked_model_sparse_init = flax.deprecated.nn.Model(\n        MaskedDenseSparseInit, initial_params)\n\n    self.assertTrue(\n        jnp.isclose(\n            self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']\n            ['kernel'], self._unmasked_model.params['MaskedModule_0']\n            ['unmasked']['kernel']).all())\n\n  def test_dense_sparse_init_kaiming(self):\n    \"\"\"Checks kaiming normal sparse initialization for dense layer.\"\"\"\n    _, initial_params = MaskedDense.init_by_shape(self._rng,\n                                                  (self._input_shape,))\n    self._unmasked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)\n\n    mask = masked.simple_mask(self._unmasked_model, jnp.ones,\n                              masked.WEIGHT_PARAM_NAMES)\n\n    _, initial_params = MaskedDenseSparseInit.init_by_shape(\n        jax.random.PRNGKey(42), (self._input_shape,), mask=mask)\n    self._masked_model_sparse_init = flax.deprecated.nn.Model(\n        MaskedDenseSparseInit, initial_params)\n\n  
  mean_init = jnp.mean(\n        self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])\n\n    stddev_init = jnp.std(\n        self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])\n\n    mean_sparse_init = jnp.mean(\n        self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']\n        ['kernel'])\n\n    stddev_sparse_init = jnp.std(\n        self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']\n        ['kernel'])\n\n    with self.subTest(name='test_dense_sparse_init_mean'):\n      self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,\n                         mean_init + 2 * stddev_init)\n\n    with self.subTest(name='test_dense_sparse_init_stddev'):\n      self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,\n                         1.5 * stddev_init)\n\n  def test_cnn_sparse_init_kaiming(self):\n    \"\"\"Checks kaiming normal sparse initialization for convolutional layer.\"\"\"\n    _, initial_params = MaskedCNN.init_by_shape(self._rng, (self._input_shape,))\n    self._unmasked_model = flax.deprecated.nn.Model(MaskedCNN, initial_params)\n\n    mask = masked.simple_mask(self._unmasked_model, jnp.ones,\n                              masked.WEIGHT_PARAM_NAMES)\n\n    _, initial_params = MaskedCNNSparseInit.init_by_shape(\n        jax.random.PRNGKey(42), (self._input_shape,), mask=mask)\n    self._masked_model_sparse_init = flax.deprecated.nn.Model(\n        MaskedCNNSparseInit, initial_params)\n\n    mean_init = jnp.mean(\n        self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])\n\n    stddev_init = jnp.std(\n        self._unmasked_model.params['MaskedModule_0']['unmasked']['kernel'])\n\n    mean_sparse_init = jnp.mean(\n        self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']\n        ['kernel'])\n\n    stddev_sparse_init = jnp.std(\n        self._masked_model_sparse_init.params['MaskedModule_0']['unmasked']\n        ['kernel'])\n\n    with 
self.subTest(name='test_cnn_sparse_init_mean'):\n      self.assertBetween(mean_sparse_init, mean_init - 2 * stddev_init,\n                         mean_init + 2 * stddev_init)\n\n    with self.subTest(name='test_cnn_sparse_init_stddev'):\n      self.assertBetween(stddev_sparse_init, 0.5 * stddev_init,\n                         1.5 * stddev_init)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/mask_factory.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Pruning mask factory.\n\nAttributes:\n  MaskFnType: A type alias for functions to create sparse masks.\n  MASK_TYPES: Mask types that can be created.\n\"\"\"\nfrom typing import Any, Callable, Mapping\n\nimport flax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import masked\n\n# A function to create a mask, takes as arguments: a flax model, JAX PRNG Key,\n# sparsity level as a float in [0, 1].\nMaskFnType = Callable[\n    [flax.deprecated.nn.Model, Callable[[int],\n                                        jnp.array], float], masked.MaskType]\n\nMASK_TYPES: Mapping[str, MaskFnType] = {\n    'random':\n        masked.shuffled_mask,\n    'per_neuron':\n        masked.shuffled_neuron_mask,\n    'per_neuron_no_input_ablation':\n        masked.shuffled_neuron_no_input_ablation_mask,\n    'symmetric':\n        masked.symmetric_mask,\n}\n\n\ndef create_mask(mask_type, base_model,\n                rng, sparsity,\n                **kwargs):\n  \"\"\"Creates a mask of the given type.\n\n  Args:\n      mask_type: the name of the type of mask to instantiate.\n      base_model: the model to create a mask for.\n      rng: the random number generator to use for init.\n      sparsity: the mask sparsity.\n      **kwargs: mask-specific keyword arguments.\n\n  Returns:\n      A mask for a FLAX model.\n\n  Raises:\n      ValueError if a 
mask type with the given name does not exist.\n  \"\"\"\n  if mask_type not in MASK_TYPES:\n    raise ValueError(f'Unknown mask type: {mask_type}')\n\n  return MASK_TYPES[mask_type](base_model, rng, sparsity, **kwargs)\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/mask_factory_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.pruning.mask_factory.\"\"\"\nfrom typing import Mapping, Optional\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport flax\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import mask_factory\nfrom rigl.experimental.jax.pruning import masked\n\n\nclass MaskedDense(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_0'] if mask else None)\n\n\nclass MaskFactoryTest(parameterized.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._input_shape = ((1, 28, 28, 1), jnp.float32)\n    self._num_classes = 10\n    self._sparsity = 0.9\n\n    _, initial_params = MaskedDense.init_by_shape(self._rng,\n                                                  (self._input_shape,))\n    # Use the same initialization for both masked/unmasked models.\n    self._model = flax.deprecated.nn.Model(MaskedDense, initial_params)\n\n  def _create_mask(self, mask_type):\n    return 
mask_factory.create_mask(\n        mask_type, self._model,\n        self._rng, self._sparsity)\n\n  @parameterized.parameters(*mask_factory.MASK_TYPES.keys())\n  def test_mask_supported(self, mask_type):\n    \"\"\"Tests supported mask types.\"\"\"\n    mask = self._create_mask(mask_type)\n\n    with self.subTest(name='test_mask_type'):\n      self.assertIsInstance(mask, dict)\n\n  def test_mask_unsupported(self):\n    \"\"\"Tests unsupported mask types.\"\"\"\n    with self.assertRaisesRegex(ValueError,\n                                'Unknown mask type: unsupported'):\n      self._create_mask('unsupported')\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/masked.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Masked wrapper for FLAX modules.\n\nAttributes:\n  WEIGHT_PARAM_NAMES: The names of the weight parameters to use.\n  MaskType: Model mask type for static type checking.\n  MaskLayerType: Mask layer type for static type checking.\n  MutableMaskType: Mutable model mask type for static type checking.\n  MutableMaskLayerType: Mutable mask layer type for static type checking.\n\"\"\"\nimport functools\nimport operator\nfrom typing import Any, Callable, Iterator, Mapping, MutableMapping, Optional, Sequence, Tuple, Type\n\nfrom absl import logging\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport jax.ops\n\n# Model weight param names, e.g. 
'kernel', (as opposed to batch norm params, etc.).\nWEIGHT_PARAM_NAMES = ('kernel',)  # Note: Bias is not typically masked.\n\n\n# Mask layer type for static type checking.\nMaskLayerType = Mapping[str, Optional[jnp.array]]\n\n\n# Model mask type for static type checking.\nMaskType = Mapping[str, Optional[MaskLayerType]]\n\n\n# Mutable mask layer type for static type checking.\nMutableMaskLayerType = MutableMapping[str, Optional[jnp.array]]\n\n\n# Mutable model mask type for static type checking.\nMutableMaskType = MutableMapping[str, MutableMaskLayerType]\n\n\nclass MaskedModule(flax.deprecated.nn.Module):\n  \"\"\"Generic FLAX Masking Module.\n\n     Masks a FLAX module, given a mask for params of each layer.\n\n     Attributes:\n       UNMASKED: The key to use for the unmasked parameter dictionary.\n  \"\"\"\n\n  UNMASKED = 'unmasked'\n\n  def apply(self,\n            *args,\n            wrapped_module,\n            mask = None,\n            **kwargs):\n    \"\"\"Apply the wrapped module, while applying the given masks to its params.\n\n    Args:\n      *args: The positional arguments for the wrapped module.\n      wrapped_module: The module class to be wrapped.\n      mask: The mask nested dictionary containing masks for the wrapped module's\n        params, in the same format/with the same keys as the module param dict\n        (or None if not to mask).\n      **kwargs: The keyword arguments for the wrapped module.\n\n    Returns:\n    The output of the wrapped module, with the mask applied to its params.\n\n    Raises:\n    ValueError: If the given mask is not valid for the wrapped module, i.e. 
the\n                pytrees do not match.\n    \"\"\"\n\n    # Explicitly create the parameters of the wrapped module.\n    def init_fn(rng, input_shape):\n      del input_shape  # Unused.\n\n      # Call init to get the params of the wrapped module.\n      _, params = wrapped_module.init(rng, *args, **kwargs)\n      return params\n\n    unmasked_params = self.param(self.UNMASKED, None, init_fn)\n\n    if mask is not None:\n      try:\n        masked_params = jax.tree_util.tree_map(\n            lambda x, *xs: x\n            if xs[0] is None else x * xs[0], unmasked_params, mask)\n      except ValueError as err:\n        raise ValueError('Mask is invalid for model.') from err\n\n      # Call the wrapped module with the masked params.\n      return wrapped_module.call(masked_params, *args, **kwargs)\n    else:\n      logging.warning('Using masked module without mask!')\n      # Call the wrapped module with the unmasked params.\n      return wrapped_module.call(unmasked_params, *args, **kwargs)\n\n\ndef masked(module, mask):\n  \"\"\"Convenience function for masking a FLAX module with MaskedModule.\"\"\"\n  return MaskedModule.partial(wrapped_module=module, mask=mask)\n\n\ndef generate_model_masks(\n    depth,\n    mask = None,\n    masked_layer_indices = None):\n  \"\"\"Creates empty masks for this model, or initializes with existing mask.\n\n  Args:\n    depth: Number of layers in the model.\n    mask: Existing model mask for layers in this model, if not given, all\n      module masks are initialized to None.\n    masked_layer_indices: The layer indices of layers in model to be masked, or\n      all if None.\n\n  Returns:\n    A model mask, with None where no mask is given for a model layer, or that\n    specific layer is indicated as not to be masked by the masked_layer_indices\n    parameter.\n  \"\"\"\n  if depth <= 0:\n    raise ValueError(f'Invalid model depth: {depth}')\n\n  if mask is None:\n    mask = {f'MaskedModule_{i}': None for i in range(depth)}\n\n  
# Have to explicitly check for None to differentiate from empty array.\n  if masked_layer_indices is not None:\n    # Check none of the indices are outside of model's layer bounds.\n    if any(i < 0 or i >= depth for i in masked_layer_indices):\n      raise ValueError(\n          f'Invalid indices for given depth ({depth}): {masked_layer_indices}')\n    mask = {\n        f'MaskedModule_{i}': mask[f'MaskedModule_{i}']\n        for i in masked_layer_indices\n    }\n\n  return mask\n\n\ndef _filter_param(param_names,\n                  invert = False):\n  \"\"\"Convenience function for filtering maskable parameters from paths.\n\n  Args:\n    param_names: Names of parameters we are looking for.\n    invert: Inverts filter to exclude, rather than include, given parameters.\n\n  Returns:\n    A function to use with flax.deprecated.nn.optim.ModelParamTraversal for\n    filtering.\n  \"\"\"\n\n  def filter_fn(path, value):\n    del value  # Unused.\n    parameter_found = any([\n        '{}/{}'.format(MaskedModule.UNMASKED, param_name) in path\n        for param_name in param_names\n    ])\n    return not parameter_found if invert else parameter_found\n\n  return filter_fn\n\n\ndef mask_map(model,\n             fn):\n  \"\"\"Convenience function to create a mask for a model.\n\n  Args:\n    model: The Flax model, with at least one MaskedModule layer.\n    fn: The function to call on each masked parameter, to create the mask for\n      that parameter, takes the parameter name, and parameter value as arguments\n      and returns the new parameter value.\n\n  Returns:\n    A model parameter dictionary, with all masked parameters set by the given\n    function, and all other parameters set to None.\n\n  Raises:\n    ValueError: If the given model does not support masking, i.e. 
none of the\n                layers are wrapped by a MaskedModule.\n  \"\"\"\n  maskable = False\n  for layer_key, layer in model.params.items():\n    if MaskedModule.UNMASKED not in layer:\n      logging.warning(\n          'Layer \\'%s\\' does not support masking, i.e. it is not '\n          'wrapped by a MaskedModule', layer_key)\n    else:\n      maskable = True\n\n  if not maskable:\n    raise ValueError('Model does not support masking, i.e. no layers are '\n                     'wrapped by a MaskedModule.')\n\n  # First set all non-masked params to None in copy of model pytree.\n  filter_non_masked = _filter_param(WEIGHT_PARAM_NAMES, invert=True)\n  nonmasked_traversal = flax.optim.ModelParamTraversal(filter_non_masked)  # pytype: disable=module-attr\n  mask_model = nonmasked_traversal.update(lambda _: None, model)\n\n  # Then find params to mask, and set to array.\n  for param_name in WEIGHT_PARAM_NAMES:\n    filter_masked = _filter_param(WEIGHT_PARAM_NAMES)\n    mask_traversal = flax.optim.ModelParamTraversal(filter_masked)  # pytype: disable=module-attr\n    mask_model = mask_traversal.update(\n        functools.partial(fn, param_name), mask_model)\n\n  mask = mask_model.params\n  # Remove unneeded unmasked param for mask.\n  for layer_key, layer in mask.items():\n    if MaskedModule.UNMASKED in layer:\n      mask[layer_key] = layer[MaskedModule.UNMASKED]\n\n  return mask\n\n\ndef iterate_mask(\n    mask,\n    param_names = None\n):\n  \"\"\"Iterates over the parameters in a mask.\n\n  Args:\n    mask: The model mask.\n    param_names: The parameter names to iterate over in each layer, if None\n      iterates over all parameters of all layers.\n\n  Yields:\n    An iterator of tuples containing the parameter path and parameter value\n    in sorted order of layer parameters matching the names in param_names (or\n    all parameters if None).\n  \"\"\"\n  flat_mask = flax.traverse_util.flatten_dict(mask)\n  for key, value in flat_mask.items():\n    if 
param_names is None or key in param_names:\n      path = '/' + '/'.join(key)\n      yield path, value\n\n\ndef shuffled_mask(model, rng,\n                  sparsity):\n  \"\"\"Returns a randomly shuffled mask with a given sparsity for all layers.\n\n  Returns a random weight mask for a model param array, by randomly shuffling a\n  mask with a fixed number of non-zero/zero entries, given by the sparsity.\n\n  Args:\n    model: Flax model that contains masked modules.\n    rng: Random number generator, i.e. jax.random.PRNGKey.\n    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will\n      mask all weights, while 0 will mask none.\n\n  Returns:\n    A randomly shuffled weight mask, in the same form as flax.Module.params.\n\n  Raises:\n    ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are\n                maskable, i.e. is wrapped by MaskedModule.\n  \"\"\"\n  if sparsity > 1 or sparsity < 0:\n    raise ValueError(\n        'Given sparsity, {}, is not in range [0, 1]'.format(sparsity))\n\n  def create_shuffled_mask(param_name, param):\n    del param_name  # Unused.\n    mask = jnp.arange(param.size)\n    mask = jnp.where(mask >= sparsity * param.size, jnp.ones_like(mask),\n                     jnp.zeros_like(mask))\n    mask = jax.random.permutation(rng, mask)\n    return mask.reshape(param.shape)\n\n  return mask_map(model, create_shuffled_mask)\n\n\ndef random_mask(model,\n                rng,\n                mean_sparsity = 0.5):\n  \"\"\"Returns a random weight mask for a masked model.\n\n  Args:\n    model: Flax model that contains masked modules.\n    rng: Random number generator, i.e. jax.random.PRNGKey.\n    mean_sparsity: The mean number of 0's in the mask, i.e. 
mean = (1 -\n      mean_sparsity) for the Bernoulli distribution to sample from.\n\n  Returns:\n    A random weight mask, in the same form as flax.Module.params\n\n  Raises:\n    ValueError: If the sparsity is beyond the bounds [0, 1], or if a layer to\n                mask is not maskable, i.e. is not wrapped by MaskedModule.\n  \"\"\"\n  if mean_sparsity > 1 or mean_sparsity < 0:\n    raise ValueError(\n        'Given sparsity, {}, is not in range [0, 1]'.format(mean_sparsity))\n\n  # Invert mean_sparsity to get mean for Bernoulli distribution.\n  mean = 1. - mean_sparsity\n\n  def create_random_mask(param_name, param):\n    del param_name  # Unused.\n    return jax.random.bernoulli(\n        rng, p=mean,\n        shape=param.shape).astype(jnp.int32)  # TPU doesn't support uint8.\n\n  return mask_map(model, create_random_mask)\n\n\ndef simple_mask(model,\n                init_fn,\n                masked_param):\n  \"\"\"Creates a mask given a model and numpy initialization function.\n\n  Args:\n    model: The model to create a mask for.\n    init_fn: The numpy initialization function, e.g. numpy.ones.\n    masked_param: The list of parameters to mask.\n\n  Returns:\n    A mask for the model.\n  \"\"\"\n\n  def create_init_fn_mask(param_name, param):\n    if param_name in masked_param:\n      return init_fn(param.shape)\n    return None\n\n  return mask_map(model, create_init_fn_mask)\n\n\ndef symmetric_mask(model,\n                   rng,\n                   sparsity = 0.5):\n  \"\"\"Generates a random weight mask that's symmetric, i.e. structurally pruned.\n\n  Args:\n    model: Flax model that contains masked modules.\n    rng: Random number generator, i.e. jax.random.PRNGKey.\n    sparsity: The per-layer sparsity of the mask (i.e. 
% of zeroes), in the\n      range  [0, 1]: 1.0 will mask all weights, while 0 will mask none.\n\n  Returns:\n    A symmetric random weight mask, in the same form as flax.Module.params.\n  \"\"\"\n  if sparsity > 1 or sparsity < 0:\n    raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')\n\n  def create_neuron_symmetric_mask(param_name, param):\n    del param_name  # Unused.\n    neuron_length = functools.reduce(operator.mul, param.shape[:-1])\n    neuron_mask = jnp.arange(neuron_length)\n    neuron_mask = jnp.where(neuron_mask >= sparsity * neuron_mask.size,\n                            jnp.ones_like(neuron_mask),\n                            jnp.zeros_like(neuron_mask))\n    neuron_mask = jax.random.shuffle(rng, neuron_mask)\n    mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1)\n    return mask.reshape(param.shape)\n\n  return mask_map(model, create_neuron_symmetric_mask)\n\n\nclass _PerNeuronShuffle:\n  \"\"\"This class is needed to get around the fact that JAX RNG is stateless.\"\"\"\n\n  def __init__(self, init_rng, sparsity):\n    \"\"\"Creates the per-neuron shuffle class, with initial RNG state.\n\n    Args:\n      init_rng: The initial random number generator state to use.\n      sparsity: The per-layer sparsity of the mask (i.e. 
% of zeroes), 1.0 will\n        mask all weights, while 0 will mask none.\n    \"\"\"\n    self._rng = init_rng\n    self._sparsity = sparsity\n\n  def __call__(self, param_name, param):\n    \"\"\"Shuffles the weight matrix/mask for a given parameter, per-neuron.\n\n    This is to be used with mask_map, and accepts the standard mask_map\n    function parameters.\n\n    Args:\n      param_name: The parameter's name.\n      param: The parameter's weight or mask matrix.\n\n    Returns:\n      A shuffled weight/mask matrix, with each neuron shuffled independently.\n    \"\"\"\n    del param_name  # Unused.\n    neuron_length = functools.reduce(operator.mul, param.shape[:-1])\n    neuron_mask = jnp.arange(neuron_length)\n    neuron_mask = jnp.where(neuron_mask >= self._sparsity * neuron_mask.size,\n                            jnp.ones_like(neuron_mask),\n                            jnp.zeros_like(neuron_mask))\n    mask = jnp.repeat(neuron_mask[Ellipsis, jnp.newaxis], param.shape[-1], axis=1)\n    self._rng, rng_input = jax.random.split(self._rng)\n    mask = jax.random.shuffle(rng_input, mask, axis=0)\n    return mask.reshape(param.shape)\n\n\ndef shuffled_neuron_mask(model,\n                         rng,\n                         sparsity):\n  \"\"\"Returns a shuffled mask with a given fixed sparsity for all neurons/layers.\n\n  Returns a randomly shuffled weight mask for a model param array, by setting a\n  fixed sparsity (i.e. number of ones/zeros) for every neuron's weight vector\n  in the model, and then randomly shuffling each neuron's weight mask with a\n  fixed number of non-zero/zero entries, given by the sparsity. This ensures no\n  neuron is ablated for a non-zero sparsity.\n\n  Note: This is much more complicated for convolutional layers due to the\n  receptive field being different for every pixel! 
We only take into account\n  channel-wise masks and not spatial ablations in propagation in that case.\n\n  Args:\n    model: Flax model that contains masked modules.\n    rng: Random number generator, i.e. jax.random.PRNGKey.\n    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will\n      mask all weights, while 0 will mask none.\n\n  Returns:\n    A randomly shuffled weight mask, in the same form as flax.Module.params.\n\n  Raises:\n    ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are\n                maskable, i.e. is wrapped by MaskedModule.\n  \"\"\"\n  if sparsity > 1 or sparsity < 0:\n    raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')\n\n  return mask_map(model, _PerNeuronShuffle(rng, sparsity))\n\n\ndef _fill_diagonal_wrap(shape,\n                        value,\n                        dtype = jnp.uint8):\n  \"\"\"Fills the diagonal of a 2D array, while also wrapping tall arrays.\n\n  For a matrix of dimensions (N x M):\n    if N <= M, i.e. the array is wide rectangular, the array's diagonal is\n    filled, for example:\n\n    _fill_diagonal_wrap(jnp.zeros((2, 3), dtype=uint8), 1)\n    > [[1, 0, 0],\n       [0, 1, 0]]\n\n    if N > M, i.e. the array is tall rectangular, the array's diagonal, and\n    offset diagonals are filled. This differs from\n    numpy.fill_diagonal(..., wrap=True), in that it does not include a single\n    row gap between the diagonals, and it is not in-place but returns a copy of\n    the given array. 
For example,\n\n    _fill_diagonal_wrap(jnp.zeros((3, 2), dtype=uint8), 1)\n    > [[1, 0],\n       [0, 1],\n       [1, 0]]\n\n  Args:\n    shape: The shape of the 2D array to return with the diagonal filled.\n    value: The value to fill in for the diagonal, and offset diagonals.\n    dtype: The datatype of the jax numpy array to return.\n  Returns:\n    A copy of the given array with the main diagonal filled, and offset\n    diagonals filled if the given array is tall.\n  \"\"\"\n  if len(shape) != 2:\n    raise ValueError(\n        f'Expected a 2D array, however array has dimensions: {shape}')\n\n  array = jnp.zeros(shape, dtype=dtype)\n  rows, cols = shape\n\n  def diagonal_indices(offset):  # Returns jax.ops._Indexable.\n    \"\"\"Returns slice of the nth diagonal of an array, where n is offset.\"\"\"\n    # This is a numpy-style advanced slice of the form [start:end:step], that\n    # gives you the offset (vertically) diagonal of an array. If it was the main\n    # diagonal of a (flattened) square matrix of n X n it would be 0:n**2:n+1,\n    # i.e. start at 0, and take every (n+1)th element, ending when you reach the\n    # end of the array. 
We need to look at vertically-offset diagonals as well, which is\n    # handled by offset.\n    return jnp.index_exp[cols * offset:cols * (offset + cols):cols + 1]\n\n  # Fills (square) matrix diagonals with the given value, tiling over tall\n  # rectangular arrays by offsetting the filled diagonals by multiples of the\n  # height of the square arrays.\n  diagonals = [\n      array.ravel().at[diagonal_indices(offset)].set(value).reshape(array.shape)\n      for offset in range(0, rows, cols)\n  ]\n  return functools.reduce(jnp.add, diagonals)\n\n\ndef _random_neuron_mask(neuron_length,\n                        unmasked_count,\n                        rng,\n                        dtype = jnp.uint32):\n  \"\"\"Generates a random mask for a neuron.\n\n  Args:\n    neuron_length: The length of the neuron's weight vector.\n    unmasked_count: The number of elements that should be unmasked.\n    rng: A jax.random.PRNGKey random seed.\n    dtype: Type of array to create.\n  Returns:\n    A random neuron weight vector mask.\n  \"\"\"\n  if unmasked_count > neuron_length:\n    raise ValueError('unmasked_count cannot be greater than neuron_length: '\n                     f'{unmasked_count} > {neuron_length}')\n  neuron_mask = jnp.concatenate(\n      (jnp.ones(unmasked_count), jnp.zeros(neuron_length - unmasked_count)),\n      axis=0)\n  neuron_mask = jax.random.shuffle(rng, neuron_mask)\n  return neuron_mask.astype(dtype)\n\n\nclass _PerNeuronNoInputAblationShuffle:\n  \"\"\"This class is needed to get around the fact that JAX RNG is stateless.\"\"\"\n\n  def __init__(self, init_rng, sparsity):\n    \"\"\"Creates the per-neuron shuffle class, with initial RNG state.\n\n    Args:\n      init_rng: The initial random number generator state to use.\n      sparsity: The per-layer sparsity of the mask (i.e. 
% of zeroes), 1.0 will\n        mask all weights, while 0 will mask none.\n    \"\"\"\n    self._rng = init_rng\n    self._sparsity = sparsity\n\n  def _get_rng(self):\n    \"\"\"Creates a new JAX RNG, while updating RNG state.\"\"\"\n    self._rng, rng_input = jax.random.split(self._rng)\n    return rng_input\n\n  def __call__(self, param_name, param):\n    \"\"\"Shuffles the weight matrix/mask for a given parameter, per-neuron.\n\n    This is to be used with mask_map, and accepts the standard mask_map\n    function parameters.\n\n    Args:\n      param_name: The parameter's name.\n      param: The parameter's weight or mask matrix.\n\n    Returns:\n      A shuffled weight/mask matrix, with each neuron shuffled independently.\n    \"\"\"\n    del param_name  # Unused.\n\n    incoming_connections = jnp.prod(jnp.array(param.shape[:-1]))\n    num_neurons = param.shape[-1]\n\n    # Ensure each input neuron has at least one connection unmasked.\n    mask = _fill_diagonal_wrap((incoming_connections, num_neurons), 1,\n                               dtype=jnp.uint8)\n\n    # Randomly shuffle which of the neurons have these connections.\n    mask = jax.random.shuffle(self._get_rng(), mask, axis=0)\n\n    # Add extra required random connections to mask to satisfy sparsity.\n    mask_cols = []\n    for col in range(mask.shape[-1]):\n      neuron_mask = mask[:, col]\n      off_diagonal_count = max(\n          round((1 - self._sparsity) * incoming_connections)\n          - jnp.count_nonzero(neuron_mask), 0)\n\n      zero_indices = jnp.flatnonzero(neuron_mask == 0)\n      random_entries = _random_neuron_mask(\n          len(zero_indices), off_diagonal_count, self._get_rng())\n\n      neuron_mask = neuron_mask.at[zero_indices].set(random_entries)\n      mask_cols.append(neuron_mask)\n\n    return jnp.column_stack(mask_cols).reshape(param.shape)\n\n\ndef shuffled_neuron_no_input_ablation_mask(model,\n                                           rng,\n                                
           sparsity):\n  \"\"\"Returns a shuffled mask with a given fixed sparsity for all neurons/layers.\n\n  Returns a randomly shuffled weight mask for a model param array, by setting a\n  fixed sparsity (i.e. number of ones/zeros) for every neuron's weight vector\n  in the model, and then randomly shuffling each neuron's weight mask with a\n  fixed number of non-zero/zero entries, given by the sparsity. This ensures no\n  neuron is ablated for a non-zero sparsity.\n\n  This function also ensures that no neurons in the previous layer are\n  effectively ablated, by ensuring that each neuron has at least one connection.\n\n  Note: This is much more complicated for convolutional layers due to the\n  receptive field being different for every pixel! We only take into account\n  channel-wise masks and not spatial ablations in propagation in that case.\n\n  Args:\n    model: Flax model that contains masked modules.\n    rng: Random number generator, i.e. jax.random.PRNGKey.\n    sparsity: The per-layer sparsity of the mask (i.e. % of zeroes), 1.0 will\n      mask all weights, except for the minimum number required to maintain\n      connectivity with the input layer, while 0 will mask none.\n\n  Returns:\n    A randomly shuffled weight mask, in the same form as flax.Module.params.\n\n  Raises:\n    ValueError: If the sparsity is beyond the bounds [0, 1], or no layers are\n                maskable, i.e. no layer 
is wrapped by MaskedModule.\n  \"\"\"\n  if sparsity > 1.0 or sparsity < 0.0:\n    raise ValueError(f'Given sparsity, {sparsity}, is not in range [0, 1]')\n\n  # First, generate a random permutation matrix, and ensure our mask has at\n  # least N connections, where there are N neurons in the previous layer.\n  return mask_map(model, _PerNeuronNoInputAblationShuffle(rng, sparsity))\n\n\ndef propagate_masks(\n    mask,\n    param_names = WEIGHT_PARAM_NAMES\n):\n  \"\"\"Accounts for implicitly pruned neurons in a model's weight masks.\n\n  When neurons are ablated in one layer, they can effectively ablate neurons\n  in the next layer, since all of a downstream neuron's incoming weights may\n  then be zero. This method accounts for this by propagating mask information\n  forward through the entire model.\n\n  Args:\n    mask: Model masks to check, in same pytree structure as Model.params.\n    param_names: List of param keys in mask to count.\n\n  Returns:\n    A refined model mask with weights that are effectively ablated in the\n    original mask set to zero.\n  \"\"\"\n\n  flat_mask = flax.traverse_util.flatten_dict(mask)\n  mask_layer_list = list(flat_mask.values())\n  mask_layer_keys = list(flat_mask.keys())\n\n  mask_layer_param_names = [layer_param[-1] for layer_param in mask_layer_keys]\n\n  for param_name in param_names:\n    # Find which of the param arrays correspond to leaf nodes with this name.\n    param_indices = [\n        i for i, names in enumerate(mask_layer_param_names)\n        if param_name in names\n    ]\n\n    for i in range(1, len(param_indices)):\n      last_weight_mask = mask_layer_list[param_indices[i - 1]]\n      weight_mask = mask_layer_list[param_indices[i]]\n\n      if last_weight_mask is None or weight_mask is None:\n        continue\n\n      last_weight_mask_reshaped = jnp.reshape(last_weight_mask,\n                                              (-1, last_weight_mask.shape[-1]))\n\n      # Neurons still alive (any unmasked weight) in the previous 
layer.\n      alive_incoming = jnp.sum(last_weight_mask_reshaped, axis=0) != 0\n\n      # Combine effective mask of previous layer with neuron's current mask.\n      if len(weight_mask.shape) > 2:\n        # Convolutional layer, only consider channel-wise masks, if any spatial\n        # weight is non-zero that channel is considered non-masked.\n        spatial_dim = len(weight_mask.shape) - 2\n        new_weight_mask = alive_incoming[:, jnp.newaxis] * jnp.amax(\n            weight_mask, axis=tuple(range(spatial_dim)))\n        new_weight_mask = jnp.tile(new_weight_mask,\n                                   weight_mask.shape[:-2] + (1, 1))\n      else:\n        # Check for case of dense following convolution, i.e. spatial input into\n        # dense, to prevent b/156135283. Must use convolution for these layers.\n        if len(last_weight_mask.shape) > 2:\n          raise ValueError(\n              'propagate_masks requires knowledge of the spatial '\n              'dimensions of the previous layer. Use a functionally equivalent '\n              'conv. layer in place of a dense layer in a model with a mixed '\n              'conv/dense setting.')\n        new_weight_mask = alive_incoming[:, jnp.newaxis] * weight_mask\n\n      mask_layer_list[param_indices[i]] = jnp.reshape(\n          new_weight_mask, mask_layer_list[param_indices[i]].shape)\n\n  return flax.traverse_util.unflatten_dict(\n      dict(zip(mask_layer_keys, mask_layer_list)))\n\n\ndef mask_layer_sparsity(mask_layer):\n  \"\"\"Calculates the sparsity of a single layer's mask.\n\n  Args:\n    mask_layer: mask layer to calculate the sparsity of.\n\n  Returns:\n    The sparsity of the mask.\n  \"\"\"\n  parameter_count = 0\n  masked_count = 0\n\n  for key in mask_layer:\n    if mask_layer[key] is not None and key in WEIGHT_PARAM_NAMES:\n      parameter_count += mask_layer[key].size\n      masked_count += jnp.sum(mask_layer[key])\n\n  if parameter_count == 0:\n    return 0.\n\n  return 1. 
- masked_count/parameter_count\n\n\ndef mask_sparsity(\n    mask,\n    param_names = None):\n  \"\"\"Calculates the sparsity of the given parameters over a model mask.\n\n  Args:\n    mask: Model mask to calculate sparsity over.\n    param_names: List of param keys in mask to count.\n\n  Returns:\n    The overall sparsity of the mask.\n  \"\"\"\n  if param_names is None:\n    param_names = WEIGHT_PARAM_NAMES\n\n  parameter_count = 0\n  masked_count = 0\n\n  for path, value in iterate_mask(mask):\n    if value is not None and any(\n        param_name in path for param_name in param_names):\n      parameter_count += value.size\n      masked_count += jnp.sum(value.flatten())\n\n  if parameter_count == 0:\n    return 0.\n\n  return 1.0 - float(masked_count / parameter_count)\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/masked_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.pruning.masked.\"\"\"\nfrom typing import Mapping, Optional, Sequence\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom rigl.experimental.jax.pruning import masked\n\n\nclass Dense(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Dense Non-Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self, inputs):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n    return flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES)\n\n\nclass MaskedDense(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_0'] if mask else None)\n\n\nclass DenseTwoLayer(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Dense Non-Masked Network.\"\"\"\n\n  NUM_FEATURES: Sequence[int] = (32, 64)\n\n  def apply(self, inputs):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n    inputs = flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[0])\n    return 
flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[1])\n\n\nclass MaskedTwoLayerDense(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: Sequence[int] = (32, 64)\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[0],\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_0'] if mask else None)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[1],\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_1'] if mask else None)\n\n\nclass MaskedConv(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Conv Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 16\n\n  def apply(self,\n            inputs,\n            mask = None):\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        kernel_size=(3, 3),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n\n\nclass MaskedTwoLayerConv(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Conv Masked Network.\"\"\"\n\n  NUM_FEATURES: Sequence[int] = (16, 32)\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[0],\n        kernel_size=(5, 5),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[1],\n        kernel_size=(3, 3),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_1'] if mask is not None else None)\n\n\nclass MaskedThreeLayerConvDense(flax.deprecated.nn.Module):\n  \"\"\"Three-layer Conv Masked Network with Dense layer.\"\"\"\n\n  NUM_FEATURES: 
Sequence[int] = (16, 32, 64)\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[0],\n        kernel_size=(5, 5),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[1],\n        kernel_size=(3, 3),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_1'] if mask is not None else None)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[2],\n        kernel_size=inputs.shape[1:-1],\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_2'] if mask is not None else None)\n\n\nclass MaskedTwoLayerMixedConvDense(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Mixed Conv/Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: Sequence[int] = (16, 32)\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[0],\n        kernel_size=(5, 5),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[1],\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_1'] if mask is not None else None)\n\n\nclass MaskedTest(parameterized.TestCase):\n  \"\"\"Tests the flax layer mask.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._batch_size = 2\n    self._input_dimensions = (28, 28, 1)\n    self._input_shape = ((self._batch_size,) + self._input_dimensions,\n                         jnp.float32)\n    self._input = jnp.ones(*self._input_shape)\n\n    _, initial_params = Dense.init_by_shape(self._rng, (self._input_shape,))\n    
self._unmasked_model = flax.deprecated.nn.Model(Dense, initial_params)\n    self._unmasked_output = self._unmasked_model(self._input)\n\n    # Use the same initialization for both masked/unmasked models.\n    masked_initial_params = {\n        'MaskedModule_0': {\n            'unmasked': initial_params['Dense_0']\n        }\n    }\n    self._masked_model = flax.deprecated.nn.Model(MaskedDense,\n                                                  masked_initial_params)\n\n    _, initial_params = DenseTwoLayer.init_by_shape(self._rng,\n                                                    (self._input_shape,))\n    self._unmasked_model_twolayer = flax.deprecated.nn.Model(\n        DenseTwoLayer, initial_params)\n    self._unmasked_output_twolayer = self._unmasked_model_twolayer(self._input)\n\n    # Use the same initialization for both masked/unmasked models.\n    masked_initial_params = {\n        'MaskedModule_0': {\n            'unmasked': initial_params['Dense_0']\n        },\n        'MaskedModule_1': {\n            'unmasked': initial_params['Dense_1']\n        },\n    }\n    _, initial_params = MaskedTwoLayerDense.init_by_shape(\n        self._rng, (self._input_shape,))\n    self._masked_model_twolayer = flax.deprecated.nn.Model(\n        MaskedTwoLayerDense, masked_initial_params)\n\n    _, initial_params = MaskedConv.init_by_shape(self._rng,\n                                                 (self._input_shape,))\n    self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv,\n                                                       initial_params)\n\n    _, initial_params = MaskedTwoLayerConv.init_by_shape(\n        self._rng, (self._input_shape,))\n    self._masked_conv_model_twolayer = flax.deprecated.nn.Model(\n        MaskedTwoLayerConv, initial_params)\n\n    _, initial_params = MaskedTwoLayerMixedConvDense.init_by_shape(\n        self._rng, (self._input_shape,))\n    self._masked_mixed_model_twolayer = flax.deprecated.nn.Model(\n        
MaskedTwoLayerMixedConvDense, initial_params)\n\n    _, initial_params = MaskedThreeLayerConvDense.init_by_shape(\n        self._rng, (self._input_shape,))\n    self._masked_conv_fc_model_threelayer = flax.deprecated.nn.Model(\n        MaskedThreeLayerConvDense, initial_params)\n\n  def test_fully_masked_layer(self):\n    \"\"\"Tests masked module with full-sparsity mask.\"\"\"\n    full_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])\n\n    masked_output = self._masked_model(self._input, mask=full_mask)\n\n    with self.subTest(name='fully_masked_dense_values'):\n      self.assertTrue((masked_output == 0).all())\n\n    with self.subTest(name='fully_masked_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_no_mask_masked_layer(self):\n    \"\"\"Tests masked module with no mask.\"\"\"\n    masked_output = self._masked_model(self._input, mask=None)\n\n    with self.subTest(name='no_mask_masked_dense_values'):\n      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())\n\n    with self.subTest(name='no_mask_masked_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_empty_mask_masked_layer(self):\n    \"\"\"Tests masked module with an empty (not sparse) mask.\"\"\"\n    empty_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])\n\n    masked_output = self._masked_model(self._input, mask=empty_mask)\n\n    with self.subTest(name='empty_mask_masked_dense_values'):\n      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())\n\n    with self.subTest(name='empty_mask_masked_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_invalid_mask(self):\n    \"\"\"Tests using an invalid mask.\"\"\"\n    invalid_mask = {\n        'MaskedModule_0': {\n            'not_kernel':\n                
jnp.ones(self._unmasked_model.params['Dense_0']['kernel'].shape)\n        }\n    }\n\n    with self.assertRaisesRegex(ValueError, 'Mask is invalid for model.'):\n      self._masked_model(self._input, mask=invalid_mask)\n\n  def test_shuffled_mask_invalid_model(self):\n    \"\"\"Tests shuffled mask with model containing no masked layers.\"\"\"\n    with self.assertRaisesRegex(\n        ValueError, 'Model does not support masking, i.e. no layers are '\n        'wrapped by a MaskedModule.'):\n      masked.shuffled_mask(self._unmasked_model, self._rng, 0.5)\n\n  def test_shuffled_mask_invalid_sparsity(self):\n    \"\"\"Tests shuffled mask with invalid sparsity.\"\"\"\n\n    with self.subTest(name='sparsity_too_small'):\n      with self.assertRaisesRegex(\n          ValueError, r'Given sparsity, -0.5, is not in range \\[0, 1\\]'):\n        masked.shuffled_mask(self._masked_model, self._rng, -0.5)\n\n    with self.subTest(name='sparsity_too_large'):\n      with self.assertRaisesRegex(\n          ValueError, r'Given sparsity, 1.5, is not in range \\[0, 1\\]'):\n        masked.shuffled_mask(self._masked_model, self._rng, 1.5)\n\n  def test_shuffled_mask_sparsity_full(self):\n    \"\"\"Tests shuffled mask generation, for 100% sparsity.\"\"\"\n    mask = masked.shuffled_mask(self._masked_model, self._rng, 1.0)\n\n    with self.subTest(name='shuffled_full_mask'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_full_mask_values'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())\n\n    with self.subTest(name='shuffled_full_mask_not_masked_values'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_full_mask_dense_values'):\n      self.assertTrue((masked_output == 0).all())\n\n    with self.subTest(name='shuffled_full_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, 
self._unmasked_output.shape)\n\n  def test_shuffled_mask_sparsity_empty(self):\n    \"\"\"Tests shuffled mask generation, for 0% sparsity.\"\"\"\n    mask = masked.shuffled_mask(self._masked_model, self._rng, 0.0)\n\n    with self.subTest(name='shuffled_empty_mask'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_empty_mask_values'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())\n\n    with self.subTest(name='shuffled_empty_mask_not_masked_values'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_empty_dense_values'):\n      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())\n\n    with self.subTest(name='shuffled_empty_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_shuffled_mask_sparsity_half_full(self):\n    \"\"\"Tests shuffled mask generation, for a half-full mask.\"\"\"\n    mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)\n    param_len = self._masked_model.params['MaskedModule_0']['unmasked'][\n        'kernel'].size\n\n    with self.subTest(name='shuffled_mask_values'):\n      self.assertEqual(\n          jnp.sum(mask['MaskedModule_0']['kernel']), param_len // 2)\n\n  def test_shuffled_mask_sparsity_full_twolayer(self):\n    \"\"\"Tests shuffled mask generation for two layers, and 100% sparsity.\"\"\"\n    mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 1.0)\n\n    with self.subTest(name='shuffled_full_mask_layer1'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_full_mask_values_layer1'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())\n\n    with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    with 
self.subTest(name='shuffled_full_mask_layer2'):\n      self.assertIn('MaskedModule_1', mask)\n\n    with self.subTest(name='shuffled_full_mask_values_layer2'):\n      self.assertTrue((mask['MaskedModule_1']['kernel'] == 0).all())\n\n    with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):\n      self.assertIsNone(mask['MaskedModule_1']['bias'])\n\n    masked_output = self._masked_model_twolayer(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_full_mask_dense_values'):\n      self.assertTrue((masked_output == 0).all())\n\n    with self.subTest(name='shuffled_full_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape,\n                               self._unmasked_output_twolayer.shape)\n\n  def test_shuffled_mask_sparsity_empty_twolayer(self):\n    \"\"\"Tests shuffled mask generation for two layers, for 0% sparsity.\"\"\"\n    mask = masked.shuffled_mask(self._masked_model_twolayer, self._rng, 0.0)\n\n    with self.subTest(name='shuffled_empty_mask_layer1'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_empty_mask_values_layer1'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())\n\n    with self.subTest(name='shuffled_empty_mask_layer2'):\n      self.assertIn('MaskedModule_1', mask)\n\n    with self.subTest(name='shuffled_empty_mask_values_layer2'):\n      self.assertTrue((mask['MaskedModule_1']['kernel'] == 1).all())\n\n    masked_output = self._masked_model_twolayer(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_empty_dense_values'):\n      self.assertTrue(\n          jnp.isclose(masked_output, self._unmasked_output_twolayer).all())\n\n    with self.subTest(name='shuffled_empty_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape,\n                               self._unmasked_output_twolayer.shape)\n\n  def test_random_invalid_model(self):\n    \"\"\"Tests random mask with model containing no masked layers.\"\"\"\n    
with self.assertRaisesRegex(\n        ValueError, 'Model does not support masking, i.e. no layers are '\n        'wrapped by a MaskedModule.'):\n      masked.random_mask(self._unmasked_model, self._rng, 0.5)\n\n  def test_random_invalid_sparsity(self):\n    \"\"\"Tests random mask with invalid sparsity.\"\"\"\n\n    with self.subTest(name='random_sparsity_too_small'):\n      with self.assertRaisesRegex(\n          ValueError, r'Given sparsity, -0.5, is not in range \\[0, 1\\]'):\n        masked.random_mask(self._masked_model, self._rng, -0.5)\n\n    with self.subTest(name='random_sparsity_too_large'):\n      with self.assertRaisesRegex(\n          ValueError, r'Given sparsity, 1.5, is not in range \\[0, 1\\]'):\n        masked.random_mask(self._masked_model, self._rng, 1.5)\n\n  def test_random_mask_sparsity_full(self):\n    \"\"\"Tests random mask generation, for 100% sparsity.\"\"\"\n    mask = masked.random_mask(self._masked_model, self._rng, 1.)\n\n    with self.subTest(name='random_full_mask_values'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 0).all())\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='random_full_mask_dense_values'):\n      self.assertTrue((masked_output == 0).all())\n\n    with self.subTest(name='random_full_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_random_mask_sparsity_empty(self):\n    \"\"\"Tests random mask generation, for 0% sparsity.\"\"\"\n    mask = masked.random_mask(self._masked_model, self._rng, 0.)\n\n    with self.subTest(name='random_empty_mask_values'):\n      self.assertEqual(\n          jnp.sum(mask['MaskedModule_0']['kernel']),\n          mask['MaskedModule_0']['kernel'].size)\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='random_empty_dense_values'):\n      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())\n\n 
   with self.subTest(name='random_empty_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_random_mask_sparsity_half_full(self):\n    \"\"\"Tests random mask generation, for a half-full mask.\"\"\"\n    mask = masked.random_mask(self._masked_model, self._rng, 0.5)\n    param_len = self._masked_model.params['MaskedModule_0']['unmasked'][\n        'kernel'].size\n    half_full = param_len / 2\n\n    with self.subTest(name='random_mask_values'):\n      self.assertBetween(\n          jnp.sum(mask['MaskedModule_0']['kernel']), 0.66 * half_full,\n          1.33 * half_full)\n\n  def test_simple_mask_one_layer(self):\n    \"\"\"Tests generation of a simple mask.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros(self._masked_model.params['MaskedModule_0']\n                          ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        }\n    }\n\n    gen_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])\n\n    result, _ = jax.tree_flatten(\n        jax.tree_util.tree_map(lambda x, *xs: (x == xs[0]).all(), mask,\n                               gen_mask))\n\n    self.assertTrue(all(result))\n\n  def test_simple_mask_two_layer(self):\n    \"\"\"Tests generation of a simple mask.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0']\n                          ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n        'MaskedModule_1': {\n            'kernel':\n                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_1']\n                          ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n    }\n\n    gen_mask = masked.simple_mask(self._masked_model_twolayer, jnp.zeros,\n                                  ['kernel'])\n\n    
result, _ = jax.tree_flatten(\n        jax.tree_util.tree_map(lambda x, *xs: (x == xs[0]).all(), mask,\n                               gen_mask))\n\n    self.assertTrue(all(result))\n\n  def test_shuffled_mask_neuron_mask_sparsity_empty(self):\n    \"\"\"Tests shuffled neuron mask generation, for 0% sparsity.\"\"\"\n    mask = masked.shuffled_neuron_mask(self._masked_model, self._rng, 0.0)\n\n    with self.subTest(name='shuffled_neuron_empty_mask'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_neuron_empty_mask_values'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())\n\n    with self.subTest(name='shuffled_neuron_empty_mask_not_masked_values'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_neuron_empty_dense_values'):\n      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())\n\n    with self.subTest(name='shuffled_neuron_empty_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_shuffled_mask_neuron_mask_sparsity_half_full(self):\n    \"\"\"Tests shuffled mask generation, for a half-full mask.\"\"\"\n    mask = masked.shuffled_neuron_mask(self._masked_model, self._rng, 0.5)\n    param_len = len(\n        self._masked_model.params['MaskedModule_0']['unmasked']['kernel'][:, 0])\n    mask_sum = jnp.sum(mask['MaskedModule_0']['kernel'][:, 0])\n\n    with self.subTest(name='shuffled_mask_values'):\n      # Check that single neuron has the correct sparsity.\n      self.assertEqual(mask_sum, param_len // 2)\n\n    with self.subTest(name='shuffled_mask_rows_different'):\n      # Check that two rows are different.\n      self.assertFalse(\n          jnp.isclose(mask['MaskedModule_0']['kernel'][:, 0],\n                      mask['MaskedModule_0']['kernel'][:, 1]).all())\n\n  def 
test_symmetric_mask_sparsity_empty(self):\n    \"\"\"Tests symmetric mask generation, for 0% sparsity.\"\"\"\n    mask = masked.symmetric_mask(self._masked_model, self._rng, 0.0)\n\n    with self.subTest(name='symmetric_empty_mask'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='symmetric_empty_mask_values'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())\n\n    with self.subTest(name='symmetric_empty_mask_not_masked_values'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='symmetric_empty_dense_values'):\n      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())\n\n    with self.subTest(name='symmetric_empty_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_symmetric_mask_sparsity_half_full(self):\n    \"\"\"Tests symmetric mask generation, for a half-full mask.\"\"\"\n    mask = masked.symmetric_mask(self._masked_model, self._rng, 0.5)\n    param_len = len(\n        self._masked_model.params['MaskedModule_0']['unmasked']['kernel'][:, 0])\n    mask_sum = jnp.sum(mask['MaskedModule_0']['kernel'][:, 0])\n\n    with self.subTest(name='symmetric_mask_values'):\n      # Check that single neuron has the correct sparsity.\n      self.assertEqual(mask_sum, param_len // 2)\n\n    with self.subTest(name='symmetric_mask_rows_same'):\n      # Check that two rows are the same.\n      self.assertTrue(\n          jnp.isclose(mask['MaskedModule_0']['kernel'][:, 0],\n                      mask['MaskedModule_0']['kernel'][:, 1]).all())\n\n  def test_propagate_masks_ablated_neurons_one_layer(self):\n    \"\"\"Tests mask propagation on a single layer model.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jax.random.normal(\n                    self._rng,\n                    
self._masked_model_twolayer.params['MaskedModule_0']\n                    ['unmasked']['kernel'].shape,\n                    dtype=jnp.float32),\n            'bias':\n                None,\n        },\n    }\n\n    refined_mask = masked.propagate_masks(mask)\n\n    # Since this is a single layer, should not affect mask at all.\n    self.assertTrue((mask['MaskedModule_0']['kernel'] ==\n                     refined_mask['MaskedModule_0']['kernel']).all())\n\n  def test_propagate_masks_ablated_neurons_two_layers(self):\n    \"\"\"Tests mask propagation on a two-layer model.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros(self._masked_model_twolayer.params['MaskedModule_0']\n                          ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n        'MaskedModule_1': {\n            'kernel':\n                jnp.ones(self._masked_model_twolayer.params['MaskedModule_1']\n                         ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n    }\n\n    refined_mask = masked.propagate_masks(mask)\n\n    with self.subTest(name='layer_1'):\n      self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())\n\n    # Since layer 1 is all zero, layer 2 is also effectively zero.\n    with self.subTest(name='layer_2'):\n      self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())\n\n  def test_propagate_masks_ablated_neurons_two_layers_nonmasked(self):\n    \"\"\"Tests mask propagation where previous layer is not masked.\"\"\"\n    mask = {\n        'Dense_0': {\n            'kernel': None,\n            'bias': None,\n        },\n        'MaskedModule_1': {\n            'kernel':\n                jax.random.normal(\n                    self._rng,\n                    self._masked_model_twolayer.params['MaskedModule_1']\n                    ['unmasked']['kernel'].shape,\n                    dtype=jnp.float32),\n     
       'bias':\n                None,\n        },\n    }\n\n    refined_mask = masked.propagate_masks(mask)\n\n    with self.subTest(name='layer_1'):\n      self.assertIsNone(refined_mask['Dense_0']['kernel'])\n\n    with self.subTest(name='layer_2'):\n      # Since this is a single masked layer, should not affect mask at all.\n      self.assertTrue((mask['MaskedModule_1']['kernel'] ==\n                       refined_mask['MaskedModule_1']['kernel']).all())\n\n  def test_propagate_masks_ablated_neurons_one_conv_layer(self):\n    \"\"\"Tests mask propagation on a single-layer convolutional model.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jax.random.normal(\n                    self._rng,\n                    self._masked_conv_model.params['MaskedModule_0']['unmasked']\n                    ['kernel'].shape,\n                    dtype=jnp.float32),\n            'bias':\n                None,\n        },\n    }\n\n    refined_mask = masked.propagate_masks(mask)\n\n    # Since this is a single layer, should not affect mask at all.\n    self.assertTrue((mask['MaskedModule_0']['kernel'] ==\n                     refined_mask['MaskedModule_0']['kernel']).all())\n\n  def test_propagate_masks_ablated_neurons_two_conv_layers(self):\n    \"\"\"Tests mask propagation on a two-layer convolutional model.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros(\n                    self._masked_conv_model_twolayer.params['MaskedModule_0']\n                    ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n        'MaskedModule_1': {\n            'kernel':\n                jnp.ones(\n                    self._masked_conv_model_twolayer.params['MaskedModule_1']\n                    ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n    }\n\n    refined_mask = 
masked.propagate_masks(mask)\n\n    with self.subTest(name='layer_1'):\n      self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())\n\n    # Since layer 1 is all zero, layer 2 is also effectively zero.\n    with self.subTest(name='layer_2'):\n      self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())\n\n  def test_propagate_masks_ablated_neurons_three_conv_fc_layers(self):\n    \"\"\"Tests mask propagation on a three-layer conv/dense model.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros(self._masked_conv_fc_model_threelayer\n                          .params['MaskedModule_0']['unmasked']['kernel'].shape\n                         ),\n            'bias':\n                None,\n        },\n        'MaskedModule_1': {\n            'kernel':\n                jnp.ones(self._masked_conv_fc_model_threelayer\n                         .params['MaskedModule_1']['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n        'MaskedModule_2': {\n            'kernel':\n                jnp.ones(self._masked_conv_fc_model_threelayer\n                         .params['MaskedModule_2']['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n    }\n\n    refined_mask = masked.propagate_masks(mask)\n\n    with self.subTest(name='layer_1'):\n      self.assertTrue((refined_mask['MaskedModule_0']['kernel'] == 0).all())\n\n    # Since layer 1 is all zero, layer 2 is also effectively zero.\n    with self.subTest(name='layer_2'):\n      self.assertTrue((refined_mask['MaskedModule_1']['kernel'] == 0).all())\n\n    # Since layer 2 is all zero, layer 3 is also effectively zero.\n    with self.subTest(name='layer_3'):\n      self.assertTrue((refined_mask['MaskedModule_2']['kernel'] == 0).all())\n\n  def test_propagate_masks_ablated_neurons_mixed_conv_dense_layers(self):\n    \"\"\"Tests mask propagation on a two-layer 
convolutional/dense model.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros(\n                    self._masked_mixed_model_twolayer.params['MaskedModule_0']\n                    ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n        'MaskedModule_1': {\n            'kernel':\n                jnp.ones(\n                    self._masked_mixed_model_twolayer.params['MaskedModule_1']\n                    ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n    }\n\n    with self.assertRaisesRegex(\n        ValueError, 'propagate_masks requires knowledge of the spatial '\n        'dimensions of the previous layer. Use a functionally equivalent '\n        'conv. layer in place of a dense layer in a model with a mixed '\n        'conv/dense setting.'):\n      masked.propagate_masks(mask)\n\n  def test_mask_layer_sparsity_zero_mask(self):\n    \"\"\"Tests layer sparsity with a zero-sparsity (all-ones) mask.\"\"\"\n    zero_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])\n\n    self.assertEqual(\n        masked.mask_layer_sparsity(zero_mask['MaskedModule_0']), 0.)\n\n  def test_mask_layer_sparsity_half_mask(self):\n    \"\"\"Tests layer sparsity with a half-sparse mask.\"\"\"\n    half_mask = masked.shuffled_mask(self._masked_model, self._rng, 0.5)\n\n    self.assertAlmostEqual(\n        masked.mask_layer_sparsity(half_mask['MaskedModule_0']), 0.5)\n\n  def test_mask_layer_sparsity_ones_mask(self):\n    \"\"\"Tests layer sparsity with a fully sparse (all-zeros) mask.\"\"\"\n    one_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])\n\n    self.assertEqual(\n        masked.mask_layer_sparsity(one_mask['MaskedModule_0']), 1.)\n\n  def test_mask_sparsity_zero_mask(self):\n    \"\"\"Tests overall sparsity with a zero-sparsity (all-ones) mask.\"\"\"\n    zero_mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])\n\n    
self.assertEqual(masked.mask_sparsity(zero_mask), 0.)\n\n  def test_mask_sparsity_ones_mask(self):\n    \"\"\"Tests overall sparsity with a fully sparse (all-zeros) mask.\"\"\"\n    one_mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])\n\n    self.assertEqual(masked.mask_sparsity(one_mask), 1.)\n\n  def test_mask_sparsity_mixed_mask(self):\n    \"\"\"Tests overall sparsity with layers of differing sparsity.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel':\n                jnp.zeros(\n                    self._masked_conv_model_twolayer.params['MaskedModule_0']\n                    ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n        'MaskedModule_1': {\n            'kernel':\n                jnp.ones(\n                    self._masked_conv_model_twolayer.params['MaskedModule_1']\n                    ['unmasked']['kernel'].shape),\n            'bias':\n                None,\n        },\n    }\n\n    mask_sparsity = masked.mask_sparsity(mask)\n    true_sparsity = self._masked_conv_model_twolayer.params['MaskedModule_1'][\n        'unmasked']['kernel'].size / (\n            self._masked_conv_model_twolayer.params['MaskedModule_0']\n            ['unmasked']['kernel'].size + self._masked_conv_model_twolayer\n            .params['MaskedModule_1']['unmasked']['kernel'].size)\n\n    self.assertAlmostEqual(mask_sparsity, 1.0 - true_sparsity)\n\n  @parameterized.parameters(\n      # Simple masked 1-layer model.\n      (1,),\n      # Simple masked 2-layer model.\n      (2,),\n      # Simple masked 10-layer model.\n      (10,),\n  )\n  def test_generate_model_masks_depth_only(self, depth):\n    mask = masked.generate_model_masks(depth)\n    with self.subTest(name='test_model_mask_length'):\n      self.assertLen(mask, depth)\n\n    for i in range(depth):\n      with self.subTest(name=f'test_model_mask_value_layer_{i}'):\n        self.assertIsNone(mask[f'MaskedModule_{i}'])\n\n  
@parameterized.parameters(\n      # Simple masked 1-layer model, no masked indices.\n      (1, []),\n      # Simple masked 2-layer model, second layer masked.\n      (2, (1,)),\n      # Simple masked 10-layer model, 4 layers masked.\n      (10, (1, 2, 3, 9)),\n  )\n  def test_generate_model_masks_indices(self, depth, indices):\n    mask = masked.generate_model_masks(depth, None, indices)\n\n    with self.subTest(name='test_model_mask_length'):\n      self.assertLen(mask, len(indices))\n\n    for i in indices:\n      with self.subTest(name=f'test_model_mask_value_layer_{i}'):\n        self.assertIsNone(mask[f'MaskedModule_{i}'])\n\n  @parameterized.parameters(\n      # Existing 1-layer mask.\n      (1, {'MaskedModule_0': np.ones(1)}, None),\n      (2, {'MaskedModule_0': np.ones(1),\n           'MaskedModule_1': np.ones(1)}, None),\n      # Existing 2-layer mask, only using one due to mask indices.\n      (2, {'MaskedModule_0': np.ones(1),\n           'MaskedModule_1': np.ones(1),}, (1,)),\n  )\n  def test_generate_model_masks_existing_mask(self, depth, existing_mask,\n                                              indices):\n    mask = masked.generate_model_masks(depth, existing_mask, indices)\n\n    # Need to differentiate from empty sequence by explicitly checking is None.\n    if indices is None:\n      indices = range(depth)\n\n    with self.subTest(name='test_model_mask_length'):\n      self.assertLen(mask, len(indices))\n\n    for i in indices:\n      with self.subTest(name=f'test_model_mask_value_layer_{i}'):\n        self.assertIsNotNone(mask[f'MaskedModule_{i}'])\n\n    # Ensure existing mask layers that aren't in indices aren't in output.\n    for i in range(depth):\n      if i not in indices:\n        with self.subTest(\n            name=f'test_model_mask_only_allowed_indices_layer_{i}'):\n          self.assertNotIn(f'MaskedModule_{i}', mask)\n\n  def test_generate_model_masks_invalid_depth_zero(self):\n    with 
self.assertRaisesWithLiteralMatch(ValueError,\n                                           'Invalid model depth: 0'):\n      masked.generate_model_masks(0)\n\n  def test_generate_model_masks_invalid_index_toohigh(self):\n    with self.assertRaisesWithLiteralMatch(\n        ValueError, 'Invalid indices for given depth (2): (1, 2)'):\n      masked.generate_model_masks(2, None, (1, 2))\n\n  def test_generate_model_masks_invalid_index_negative(self):\n    with self.assertRaisesWithLiteralMatch(\n        ValueError, 'Invalid indices for given depth (2): (-1, 2)'):\n      masked.generate_model_masks(2, None, (-1, 2))\n\n  def test_shuffled_neuron_no_input_ablation_mask_invalid_model(self):\n    \"\"\"Tests shuffled mask with model containing no masked layers.\"\"\"\n    with self.assertRaisesRegex(\n        ValueError, 'Model does not support masking, i.e. no layers are '\n        'wrapped by a MaskedModule.'):\n      masked.shuffled_neuron_no_input_ablation_mask(self._unmasked_model,\n                                                    self._rng, 0.5)\n\n  def test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity(self):\n    \"\"\"Tests shuffled mask with invalid sparsity.\"\"\"\n\n    with self.subTest(name='sparsity_too_small'):\n      with self.assertRaisesRegex(\n          ValueError, r'Given sparsity, -0.5, is not in range \\[0, 1\\]'):\n        masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,\n                                                      self._rng, -0.5)\n\n    with self.subTest(name='sparsity_too_large'):\n      with self.assertRaisesRegex(\n          ValueError, r'Given sparsity, 1.5, is not in range \\[0, 1\\]'):\n        masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,\n                                                      self._rng, 1.5)\n\n  def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self):\n    \"\"\"Tests shuffled mask generation, for 100% sparsity.\"\"\"\n    mask = 
masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,\n                                                         self._rng, 1.0)\n\n    with self.subTest(name='shuffled_full_mask'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_full_mask_values'):\n      self.assertEqual(jnp.count_nonzero(mask['MaskedModule_0']['kernel']),\n                       jnp.prod(jnp.array(self._input_dimensions)))\n\n    with self.subTest(name='shuffled_full_no_input_ablation'):\n      # Check no row (neurons are columns) is completely ablated.\n      self.assertTrue((jnp.count_nonzero(\n          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())\n\n    with self.subTest(name='shuffled_full_mask_not_masked_values'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_full_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty(self):\n    \"\"\"Tests shuffled mask generation, for 0% sparsity.\"\"\"\n    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,\n                                                         self._rng, 0.0)\n\n    with self.subTest(name='shuffled_empty_mask'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_empty_mask_values'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())\n\n    with self.subTest(name='shuffled_empty_mask_not_masked_values'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    masked_output = self._masked_model(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_empty_dense_values'):\n      self.assertTrue(jnp.isclose(masked_output, self._unmasked_output).all())\n\n    with self.subTest(name='shuffled_empty_mask_dense_shape'):\n      
self.assertSequenceEqual(masked_output.shape, self._unmasked_output.shape)\n\n  def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self):\n    \"\"\"Tests shuffled mask generation, for a half-full mask.\"\"\"\n    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,\n                                                         self._rng, 0.5)\n    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][\n        'kernel'].shape\n\n    with self.subTest(name='shuffled_mask_values'):\n      self.assertEqual(\n          jnp.sum(mask['MaskedModule_0']['kernel']),\n          param_shape[0]//2 * param_shape[1])\n\n    with self.subTest(name='shuffled_half_no_input_ablation'):\n      # Check no row (neurons are columns) is completely ablated.\n      self.assertTrue((jnp.count_nonzero(\n          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())\n\n  def test_shuffled_neuron_no_input_ablation_mask_sparsity_quarter_full(self):\n    \"\"\"Tests shuffled mask generation, for a quarter-sparse mask.\"\"\"\n    mask = masked.shuffled_neuron_no_input_ablation_mask(self._masked_model,\n                                                         self._rng, 0.25)\n    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][\n        'kernel'].shape\n\n    with self.subTest(name='shuffled_mask_values'):\n      self.assertEqual(\n          jnp.sum(mask['MaskedModule_0']['kernel']),\n          0.75 * param_shape[0] * param_shape[1])\n\n    with self.subTest(name='shuffled_quarter_no_input_ablation'):\n      # Check no row (neurons are columns) is completely ablated.\n      self.assertTrue((jnp.count_nonzero(\n          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())\n\n  def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer(self):\n    \"\"\"Tests shuffled mask generation for two layers, and 100% sparsity.\"\"\"\n    mask = masked.shuffled_neuron_no_input_ablation_mask(\n        self._masked_model_twolayer, 
self._rng, 1.0)\n\n    with self.subTest(name='shuffled_full_mask_layer1'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_full_mask_values_layer1'):\n      self.assertEqual(jnp.count_nonzero(mask['MaskedModule_0']['kernel']),\n                       jnp.prod(jnp.array(self._input_dimensions)))\n\n    with self.subTest(name='shuffled_full_mask_not_masked_values_layer1'):\n      self.assertIsNone(mask['MaskedModule_0']['bias'])\n\n    with self.subTest(name='shuffled_full_no_input_ablation_layer1'):\n      # Check no row (neurons are columns) is completely ablated.\n      self.assertTrue((jnp.count_nonzero(\n          mask['MaskedModule_0']['kernel'], axis=0) != 0).all())\n\n    with self.subTest(name='shuffled_full_mask_layer2'):\n      self.assertIn('MaskedModule_1', mask)\n\n    with self.subTest(name='shuffled_full_mask_values_layer2'):\n      self.assertEqual(jnp.count_nonzero(mask['MaskedModule_1']['kernel']),\n                       jnp.prod(MaskedTwoLayerDense.NUM_FEATURES[0]))\n\n    with self.subTest(name='shuffled_full_mask_not_masked_values_layer2'):\n      self.assertIsNone(mask['MaskedModule_1']['bias'])\n\n    with self.subTest(name='shuffled_full_no_input_ablation_layer2'):\n      # Note: check no *inputs* are ablated, and inputs < num_neurons.\n      self.assertEqual(\n          jnp.sum(jnp.count_nonzero(mask['MaskedModule_1']['kernel'], axis=0)),\n          MaskedTwoLayerDense.NUM_FEATURES[0])\n\n    masked_output = self._masked_model_twolayer(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_full_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape,\n                               self._unmasked_output_twolayer.shape)\n\n  def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer(self):\n    \"\"\"Tests shuffled mask generation for two layers, for 0% sparsity.\"\"\"\n    mask = masked.shuffled_neuron_no_input_ablation_mask(\n        self._masked_model_twolayer, 
self._rng, 0.0)\n\n    with self.subTest(name='shuffled_empty_mask_layer1'):\n      self.assertIn('MaskedModule_0', mask)\n\n    with self.subTest(name='shuffled_empty_mask_values_layer1'):\n      self.assertTrue((mask['MaskedModule_0']['kernel'] == 1).all())\n\n    with self.subTest(name='shuffled_empty_mask_layer2'):\n      self.assertIn('MaskedModule_1', mask)\n\n    with self.subTest(name='shuffled_empty_mask_values_layer2'):\n      self.assertTrue((mask['MaskedModule_1']['kernel'] == 1).all())\n\n    masked_output = self._masked_model_twolayer(self._input, mask=mask)\n\n    with self.subTest(name='shuffled_empty_dense_values'):\n      self.assertTrue(\n          jnp.isclose(masked_output, self._unmasked_output_twolayer).all())\n\n    with self.subTest(name='shuffled_empty_mask_dense_shape'):\n      self.assertSequenceEqual(masked_output.shape,\n                               self._unmasked_output_twolayer.shape)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/pruning.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Functions for pruning FLAX masked models.\"\"\"\nfrom collections import abc\nfrom typing import Any, Callable, Mapping, Optional, Union\n\nimport flax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import masked\n\n\ndef weight_magnitude(weights):\n  \"\"\"Creates weight magnitude-based saliencies, given a weight matrix.\"\"\"\n  return jnp.absolute(weights)\n\n\ndef prune(\n    model,\n    pruning_rate,\n    saliency_fn = weight_magnitude,\n    mask = None,\n    compare_fn = jnp.greater):\n  \"\"\"Returns a mask for a model where the params in each layer are pruned using a saliency function.\n\n  Args:\n    model: The model to create a pruning mask for.\n    pruning_rate: The fraction of lowest magnitude saliency weights that are\n      pruned. 
If a float, the same rate is used for all layers; if it\n      is a mapping, it must contain a rate for every masked layer in the model.\n    saliency_fn: A function that returns a saliency value used to rank\n      the importance of individual weights in the layer.\n    mask: If the model has an existing mask, the mask will be applied before\n      pruning the model.\n    compare_fn: A pairwise operator to compare saliency with threshold, and\n      return True if the saliency indicates the value should not be masked.\n\n  Returns:\n    A pruned mask for the given model.\n  \"\"\"\n  if not mask:\n    mask = masked.simple_mask(model, jnp.ones, masked.WEIGHT_PARAM_NAMES)\n\n  if not isinstance(pruning_rate, abc.Mapping):\n    pruning_rate_dict = {}\n    for param_name, _ in masked.iterate_mask(mask):\n      # Get the layer name from the parameter's full name/path.\n      layer_name = param_name.split('/')[-2]\n      pruning_rate_dict[layer_name] = pruning_rate\n    pruning_rate = pruning_rate_dict\n\n  for param_path, param_mask in masked.iterate_mask(mask):\n    split_param_path = param_path.split('/')\n    layer_name = split_param_path[-2]\n    param_name = split_param_path[-1]\n\n    # If we don't have a pruning rate for the given layer, don't mask it.\n    if layer_name in pruning_rate and mask[layer_name][param_name] is not None:\n      param_value = model.params[layer_name][\n          masked.MaskedModule.UNMASKED][param_name]\n\n      # Any existing mask is first applied to the weight matrix.\n      # Note: need to check explicitly is not None for np array.\n      if param_mask is not None:\n        saliencies = saliency_fn(param_mask * param_value)\n      else:\n        saliencies = saliency_fn(param_value)\n\n      # TODO: Use partition here (partial sort) instead of sort,\n      # since it's O(N), not O(N log N); however, JAX doesn't support it.\n      sorted_param = jnp.sort(jnp.abs(saliencies.flatten()))\n\n      # Figure out the weight magnitude 
threshold.\n      threshold_index = jnp.round(pruning_rate[layer_name] *\n                                  sorted_param.size).astype(jnp.int32)\n      threshold = sorted_param[threshold_index]\n\n      mask[layer_name][param_name] = jnp.array(\n          compare_fn(saliencies, threshold), dtype=jnp.int32)\n\n  return mask\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/pruning_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.pruning.pruning.\"\"\"\nfrom typing import Mapping, Optional, Sequence\n\nfrom absl.testing import absltest\nimport flax\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.pruning import masked\nfrom rigl.experimental.jax.pruning import pruning\n\n\nclass MaskedDense(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_0'] if mask else None)\n\n\nclass MaskedTwoLayerDense(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Dense Masked Network.\"\"\"\n\n  NUM_FEATURES: Sequence[int] = (32, 64)\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[0],\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_0'] if mask else None)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[1],\n        wrapped_module=flax.deprecated.nn.Dense,\n        
mask=mask['MaskedModule_1'] if mask else None)\n\n\nclass MaskedConv(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Conv Masked Network.\"\"\"\n\n  NUM_FEATURES: int = 32\n\n  def apply(self,\n            inputs,\n            mask = None):\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        kernel_size=(3, 3),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n\n\nclass MaskedTwoLayerConv(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Conv Masked Network.\"\"\"\n\n  NUM_FEATURES: Sequence[int] = (16, 32)\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[0],\n        kernel_size=(5, 5),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[1],\n        kernel_size=(3, 3),\n        wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_1'] if mask is not None else None)\n\n\nclass PruningTest(absltest.TestCase):\n  \"\"\"Tests the flax layer pruning module.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._batch_size = 2\n    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)\n    self._input = jnp.ones(*self._input_shape)\n\n    _, initial_params = MaskedDense.init_by_shape(self._rng,\n                                                  (self._input_shape,))\n    self._masked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)\n\n    _, initial_params = MaskedTwoLayerDense.init_by_shape(\n        self._rng, (self._input_shape,))\n    self._masked_model_twolayer = flax.deprecated.nn.Model(\n        MaskedTwoLayerDense, initial_params)\n\n    _, initial_params = MaskedConv.init_by_shape(self._rng,\n                   
                              (self._input_shape,))\n    self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv,\n                                                       initial_params)\n\n    _, initial_params = MaskedTwoLayerConv.init_by_shape(\n        self._rng, (self._input_shape,))\n    self._masked_conv_model_twolayer = flax.deprecated.nn.Model(\n        MaskedTwoLayerConv, initial_params)\n\n  def test_prune_single_layer_dense_no_mask(self):\n    \"\"\"Tests pruning of single dense layer without an existing mask.\"\"\"\n    pruned_mask = pruning.prune(self._masked_model, 0.5)\n    mask_sparsity = masked.mask_sparsity(pruned_mask)\n\n    with self.subTest(name='test_mask_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_sparsity'):\n      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)\n\n  def test_prune_single_layer_local_pruning(self):\n    \"\"\"Test pruning of model with a single layer, and local pruning schedule.\"\"\"\n    pruned_mask = pruning.prune(self._masked_model, {\n        'MaskedModule_0': 0.5,\n    })\n    mask_sparsity = masked.mask_sparsity(pruned_mask)\n\n    with self.subTest(name='test_mask_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_sparsity'):\n      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)\n\n  def test_prune_single_layer_dense_with_mask(self):\n    \"\"\"Tests pruning of single dense layer with an existing mask.\"\"\"\n    pruned_mask = pruning.prune(\n        self._masked_model,\n        0.5,\n        mask=masked.shuffled_mask(self._masked_model, self._rng, 0.95))\n    mask_sparsity = masked.mask_sparsity(pruned_mask)\n\n    with self.subTest(name='test_mask_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_sparsity'):\n      self.assertAlmostEqual(mask_sparsity, 0.95, 
places=3)\n\n  def test_prune_two_layers_dense_no_mask(self):\n    \"\"\"Tests pruning of model with two dense layers without an existing mask.\"\"\"\n    pruned_mask = pruning.prune(self._masked_model_twolayer, 0.5)\n    mask_sparsity = masked.mask_sparsity(pruned_mask)\n\n    with self.subTest(name='test_mask_layer1_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_layer2_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel'])\n\n    with self.subTest(name='test_mask_sparsity'):\n      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)\n\n  def test_prune_two_layer_local_pruning_rate(self):\n    \"\"\"Test pruning of model with two layers, and a local pruning schedule.\"\"\"\n    pruned_mask = pruning.prune(self._masked_model_twolayer, {\n        'MaskedModule_1': 0.5,\n    })\n    mask_layer_0_sparsity = masked.mask_sparsity(pruned_mask['MaskedModule_0'])\n    mask_layer_1_sparsity = masked.mask_sparsity(pruned_mask['MaskedModule_1'])\n\n    with self.subTest(name='test_mask_layer1_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_layer2_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel'])\n\n    with self.subTest(name='test_mask_layer_0_sparsity'):\n      self.assertEqual(mask_layer_0_sparsity, 0.)\n\n    with self.subTest(name='test_mask_layer_1_sparsity'):\n      self.assertAlmostEqual(mask_layer_1_sparsity, 0.5, places=3)\n\n  def test_prune_one_layer_conv_no_mask(self):\n    \"\"\"Tests pruning of model with one conv. 
layer without an existing mask.\"\"\"\n    pruned_mask = pruning.prune(self._masked_conv_model, 0.5)\n    mask_sparsity = masked.mask_sparsity(pruned_mask)\n\n    with self.subTest(name='test_mask_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_sparsity'):\n      self.assertAlmostEqual(mask_sparsity, 0.5, places=1)\n\n  def test_prune_one_layer_conv_with_mask(self):\n    \"\"\"Tests pruning of model with one conv. layer with an existing mask.\"\"\"\n    pruned_mask = pruning.prune(\n        self._masked_conv_model,\n        0.5,\n        mask=masked.shuffled_mask(self._masked_conv_model, self._rng, 0.95))\n    mask_sparsity = masked.mask_sparsity(pruned_mask)\n\n    with self.subTest(name='test_mask_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_sparsity'):\n      self.assertAlmostEqual(mask_sparsity, 0.95, places=3)\n\n  def test_prune_two_layer_conv_no_mask(self):\n    \"\"\"Tests pruning of a model with two conv. layers and no existing mask.\"\"\"\n    pruned_mask = pruning.prune(self._masked_conv_model_twolayer, 0.5)\n    mask_sparsity = masked.mask_sparsity(pruned_mask)\n\n    with self.subTest(name='test_mask_layer1_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_0']['kernel'])\n\n    with self.subTest(name='test_mask_layer2_param_not_none'):\n      self.assertNotEmpty(pruned_mask['MaskedModule_1']['kernel'])\n\n    with self.subTest(name='test_mask_sparsity'):\n      self.assertAlmostEqual(mask_sparsity, 0.5, places=3)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/symmetry.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Code for analyzing symmetries in neural networks.\"\"\"\n\nimport functools\nimport math\nimport operator\nfrom typing import Dict, Optional, Union\n\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom rigl.experimental.jax.pruning import masked\nfrom rigl.experimental.jax.utils import utils\n\n\ndef count_permutations_mask_layer(\n    mask_layer,\n    next_mask_layer = None,\n    parameter_key = 'kernel'):\n  \"\"\"Calculates the number of permutations for a layer, given binary masks.\n\n  Args:\n   mask_layer: The binary weight mask of a dense/conv layer, where last\n     dimension is number of neurons/filters.\n   next_mask_layer: The binary weight mask of the following dense/conv layer,\n     or None if this is the last layer.\n   parameter_key: The name of the parameter to count the permutations of in each\n     layer.\n\n  Returns:\n   A dictionary with stats on the permutation structure of a mask, including\n   the number of symmetric permutations of the mask, number of unique mask\n   columns, count of the zeroed out (structurally pruned) neurons, and total\n   number of neurons/filters.\n  \"\"\"\n  # Have to check 'is None' since mask_layer[parameter_key] is jnp.array.\n  if not mask_layer or parameter_key not in mask_layer or mask_layer[\n      parameter_key] is None:\n    return {\n        'permutations': 1,\n        'zeroed_neurons': 0,\n   
     'total_neurons': 0,\n        'unique_neurons': 0,\n    }\n\n  mask = mask_layer[parameter_key]\n\n  num_neurons = mask.shape[-1]\n\n  # Initialize with stats for an empty mask.\n  mask_stats = {\n      'permutations': 0,\n      'zeroed_neurons': num_neurons,\n      'total_neurons': num_neurons,\n      'unique_neurons': 0,\n  }\n\n  # Re-shape masks as 1D, in case they are 2D (e.g. convolutional).\n  connection_mask = mask.reshape(-1, num_neurons)\n\n  # Only consider non-zero columns (in JAX neurons/filters are last index).\n  non_zero_neurons = ~jnp.all(connection_mask == 0, axis=0)\n\n  # Count only zeroed neurons in the current layer.\n  zeroed_count = num_neurons - jnp.count_nonzero(non_zero_neurons)\n\n  # Special case where all neurons in current layer are ablated.\n  if zeroed_count == num_neurons:\n    return mask_stats\n\n  # Have to check is None since next_mask_layer[parameter_key] is jnp.array.\n  if next_mask_layer and parameter_key in next_mask_layer and next_mask_layer[\n      parameter_key] is not None:\n    next_mask = next_mask_layer[parameter_key]\n\n    # Re-shape masks as 1D, in case they are 2D (e.g. 
convolutional).\n    next_connection_mask = next_mask.T.reshape(-1, num_neurons)\n\n    # Update with neurons that are non-zero in outgoing connections too.\n    non_zero_neurons &= ~jnp.all(next_connection_mask == 0, axis=0)\n\n    # Remove rows corresponding to neurons that are ablated.\n    next_connection_mask = next_connection_mask[:, non_zero_neurons]\n\n    connection_mask = connection_mask[:, non_zero_neurons]\n\n    # Combine the outgoing and incoming masks in one vector per-neuron.\n    connection_mask = jnp.concatenate(\n        (connection_mask, next_connection_mask), axis=0)\n\n  else:\n    connection_mask = connection_mask[:, non_zero_neurons]\n\n  # Effectively no connections between these two layers.\n  if not connection_mask.size:\n    return mask_stats\n\n  # Note: np.unique not implemented in JAX numpy yet.\n  _, unique_counts = np.unique(connection_mask, axis=-1, return_counts=True)\n\n  # Convert from device array.\n  mask_stats['zeroed_neurons'] = int(zeroed_count)\n\n  # Use math.factorial: the np.math alias is deprecated/removed in NumPy.\n  mask_stats['permutations'] = functools.reduce(\n      operator.mul, (math.factorial(t) for t in unique_counts))\n  mask_stats['unique_neurons'] = len(unique_counts)\n\n  return mask_stats\n\n\ndef count_permutations_mask(mask):\n  \"\"\"Calculates the number of permutations for a given model mask.\n\n  Args:\n    mask: Model masks to check, similar to Model.params.\n\n  Returns:\n   A dictionary with stats on the permutation structure of a mask, including\n   the number of symmetric permutations of the mask, number of unique mask\n   columns, count of the zeroed out (structurally pruned) neurons, and total\n   number of neurons/filters.\n  \"\"\"\n  sum_keys = ('total_neurons', 'unique_neurons', 'zeroed_neurons')\n  product_keys = ('permutations',)\n\n  # Count permutation stats for each pairwise set of layers.\n  # Note: I tried doing this with more_itertools.pairwise/itertools.chain, but\n  # there is a type conflict in passing iterators of different types to\n  # 
itertools.chain.\n  counts = [\n      count_permutations_mask_layer(layer, next_layer)\n      for layer, next_layer in utils.pairwise_longest(mask.values())\n  ]\n\n  sum_stats = {}\n  for key in sum_keys:\n    sum_stats[key] = functools.reduce(operator.add, (z[key] for z in counts))\n\n  product_stats = {}\n  for key in product_keys:\n    product_stats[key] = functools.reduce(operator.mul,\n                                          (z[key] for z in counts))\n\n  return {**sum_stats, **product_stats}\n\n\ndef get_mask_stats(mask):\n  \"\"\"Calculates an array of mask statistics.\n\n  Args:\n    mask: A model mask to calculate the statistics of.\n\n  Returns:\n    A dictionary, containing a set of mask statistics.\n  \"\"\"\n  mask_stats = count_permutations_mask(mask)\n  mask_stats.update({\n      'sparsity': masked.mask_sparsity(mask),\n      'permutation_num_digits': len(str(mask_stats['permutations'])),\n      'permutation_log10': math.log10(mask_stats['permutations'] + 1),\n  })\n\n  return mask_stats\n"
  },
  {
    "path": "rigl/experimental/jax/pruning/symmetry_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.pruning.symmetry.\"\"\"\nimport functools\nimport math\nimport operator\nfrom typing import Mapping, Optional, Sequence\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom rigl.experimental.jax.pruning import masked\nfrom rigl.experimental.jax.pruning import symmetry\n\n\nclass MaskedDense(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Dense Masked Network.\n\n  Attributes:\n    NUM_FEATURES: The number of neurons in the single dense layer.\n  \"\"\"\n\n  NUM_FEATURES: int = 16\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n\n\nclass MaskedConv(flax.deprecated.nn.Module):\n  \"\"\"Single-layer Conv Masked Network.\n\n  Attributes:\n    NUM_FEATURES: The number of filters in the single conv layer.\n  \"\"\"\n\n  NUM_FEATURES: int = 16\n\n  def apply(self,\n            inputs,\n            mask = None):\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES,\n        kernel_size=(3, 3),\n        
wrapped_module=flax.deprecated.nn.Conv,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n\n\nclass MaskedTwoLayerDense(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Dense Masked Network.\n\n  Attributes:\n    NUM_FEATURES: The number of neurons in the dense layers.\n  \"\"\"\n\n  NUM_FEATURES: Sequence[int] = (16, 32)\n\n  def apply(self,\n            inputs,\n            mask = None):\n    inputs = inputs.reshape(inputs.shape[0], -1)\n    inputs = masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[0],\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_0'] if mask is not None else None)\n    inputs = flax.deprecated.nn.relu(inputs)\n    return masked.MaskedModule(\n        inputs,\n        features=self.NUM_FEATURES[1],\n        wrapped_module=flax.deprecated.nn.Dense,\n        mask=mask['MaskedModule_1'] if mask is not None else None)\n\n\nclass SymmetryTest(parameterized.TestCase):\n  \"\"\"Tests symmetry analysis methods.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._rng = jax.random.PRNGKey(42)\n    self._batch_size = 2\n    self._input_shape = ((self._batch_size, 2, 2, 1), jnp.float32)\n    self._flat_input_shape = ((self._batch_size, 2 * 2 * 1), jnp.float32)\n\n    _, initial_params = MaskedDense.init_by_shape(self._rng,\n                                                  (self._flat_input_shape,))\n    self._masked_model = flax.deprecated.nn.Model(MaskedDense, initial_params)\n\n    _, initial_params = MaskedConv.init_by_shape(self._rng,\n                                                 (self._input_shape,))\n    self._masked_conv_model = flax.deprecated.nn.Model(MaskedConv,\n                                                       initial_params)\n\n    _, initial_params = MaskedTwoLayerDense.init_by_shape(\n        self._rng, (self._flat_input_shape,))\n    self._masked_two_layer_model = flax.deprecated.nn.Model(\n        MaskedTwoLayerDense, initial_params)\n\n  def 
test_count_permutations_layer_mask_full(self):\n    \"\"\"Tests count of weight permutations in a full mask.\"\"\"\n    mask_layer = {\n        'kernel':\n            jnp.ones(self._masked_model.params['MaskedModule_0']['unmasked']\n                     ['kernel'].shape),\n    }\n\n    stats = symmetry.count_permutations_mask_layer(mask_layer)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 1)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'],\n                       math.factorial(MaskedDense.NUM_FEATURES))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)\n\n  def test_count_permutations_layer_mask_empty(self):\n    \"\"\"Tests count of weight permutations in an empty mask.\"\"\"\n    mask_layer = {\n        'kernel':\n            jnp.zeros(self._masked_model.params['MaskedModule_0']['unmasked']\n                      ['kernel'].shape),\n    }\n\n    stats = symmetry.count_permutations_mask_layer(mask_layer)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 0)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], 0)\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)\n\n  def test_count_permutations_conv_layer_mask_full(self):\n    \"\"\"Tests count of weight permutations in a full mask for a conv. 
layer.\"\"\"\n    mask_layer = {\n        'kernel':\n            jnp.ones(self._masked_conv_model.params['MaskedModule_0']\n                     ['unmasked']['kernel'].shape),\n    }\n\n    stats = symmetry.count_permutations_mask_layer(mask_layer)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 1)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'],\n                       math.factorial(MaskedConv.NUM_FEATURES))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)\n\n  def test_count_permutations_conv_layer_mask_empty(self):\n    \"\"\"Tests count of weight permutations in an empty mask for a conv. layer.\"\"\"\n    mask_layer = {\n        'kernel':\n            jnp.zeros(self._masked_conv_model.params['MaskedModule_0']\n                      ['unmasked']['kernel'].shape),\n    }\n\n    stats = symmetry.count_permutations_mask_layer(mask_layer)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 0)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], 0)\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], MaskedConv.NUM_FEATURES)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedConv.NUM_FEATURES)\n\n  def test_count_permutations_layer_mask_known_perm(self):\n    \"\"\"Tests count of weight permutations in a mask with known # permutations.\"\"\"\n    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][\n        'kernel'].shape\n\n    # Create two unique random mask rows.\n    row_type_one = 
jax.random.bernoulli(\n        self._rng, p=0.3, shape=(param_shape[0],)).astype(jnp.int32)\n    row_type_two = jax.random.bernoulli(\n        self._rng, p=0.9, shape=(param_shape[0],)).astype(jnp.int32)\n\n    # Create mask by repeating the two unique rows.\n    repeat_one = param_shape[-1] // 3\n    repeat_two = param_shape[-1] - repeat_one\n    mask_layer = {'kernel': jnp.concatenate(\n        (jnp.repeat(row_type_one[:, jnp.newaxis], repeat_one, axis=-1),\n         jnp.repeat(row_type_two[:, jnp.newaxis], repeat_two, axis=-1)),\n        axis=-1)}\n\n    stats = symmetry.count_permutations_mask_layer(mask_layer)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 2)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'],\n                       math.factorial(repeat_one) * math.factorial(repeat_two))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], param_shape[-1])\n\n  def test_count_permutations_layer_mask_known_perm_zeros(self):\n    \"\"\"Tests count of weight permutations in a mask with zeroed neurons.\"\"\"\n    param_shape = self._masked_model.params['MaskedModule_0']['unmasked'][\n        'kernel'].shape\n\n    # Create two unique random mask rows.\n    row_type_one = jax.random.bernoulli(\n        self._rng, p=0.3, shape=(param_shape[0],)).astype(jnp.int32)\n    row_type_two = jnp.zeros(shape=(param_shape[0],), dtype=jnp.int32)\n\n    # Create mask by repeating the two unique rows.\n    repeat_one = param_shape[-1] // 3\n    repeat_two = param_shape[-1] - repeat_one\n    mask_layer = {'kernel': jnp.concatenate(\n        (jnp.repeat(row_type_one[:, jnp.newaxis], repeat_one, axis=-1),\n         jnp.repeat(row_type_two[:, jnp.newaxis], repeat_two, axis=-1)),\n        
axis=-1)}\n\n    stats = symmetry.count_permutations_mask_layer(mask_layer)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 1)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], math.factorial(repeat_one))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], repeat_two)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], param_shape[-1])\n\n  def test_count_permutations_shuffled_empty_mask(self):\n    \"\"\"Tests count of weight permutations on a generated empty (all-zero) mask.\"\"\"\n    mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=1)\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 0)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], 0)\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)\n\n  def test_count_permutations_shuffled_full_mask(self):\n    \"\"\"Tests count of weight permutations on a generated full (all-one) mask.\"\"\"\n    mask = masked.shuffled_mask(self._masked_model, rng=self._rng, sparsity=0)\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 1)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'],\n                       math.factorial(MaskedDense.NUM_FEATURES))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      
self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)\n\n  def test_count_permutations_mask_layer_twolayer_known_symmetric(self):\n    \"\"\"Tests count of permutations in a known mask with 2 permutations.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T,\n        },\n        'MaskedModule_1': {\n            'kernel': jnp.array(((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T,\n        },\n    }\n\n    stats = symmetry.count_permutations_mask_layer(mask['MaskedModule_0'],\n                                                   mask['MaskedModule_1'])\n\n    with self.subTest(name='count_permutations_unique'):\n      self.assertEqual(stats['unique_neurons'], 2)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], 2)\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'],\n                       mask['MaskedModule_0']['kernel'].shape[-1])\n\n  # Note: Can't pass jnp.array here since this is module-level code, which\n  # runs before InitGoogle() is called.\n  @parameterized.parameters(\n      # Tests mask with 1 permutation only if both layers are considered.\n      ({\n          'MaskedModule_0': {\n              'kernel': np.array(((1, 0), (1, 0), (0, 1))).T,\n          },\n          'MaskedModule_1': {\n              'kernel':\n                  np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T,\n          },\n      }, 3, 1, 0, 3),\n      # Tests mask zero count with an ablated neuron in first layer.\n      ({\n          'MaskedModule_0': {\n              'kernel': np.array(((1, 0), (1, 0), (0, 0))).T,\n          },\n          'MaskedModule_1': {\n              'kernel':\n                  
np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T,\n          },\n      }, 2, 1, 1, 3),\n      # Tests mask zero count with first layer completely ablated.\n      ({\n          'MaskedModule_0': {\n              'kernel': np.array(((0, 0), (0, 0), (0, 0))).T,\n          },\n          'MaskedModule_1': {\n              'kernel':\n                  np.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T,\n          },\n      }, 0, 0, 3, 3),\n      # Tests mask zero count with second layer completely ablated.\n      ({\n          'MaskedModule_0': {\n              'kernel': np.array(((1, 0), (1, 0), (0, 1))).T,\n          },\n          'MaskedModule_1': {\n              'kernel':\n                  np.array(((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0))).T,\n          },\n      }, 0, 0, 3, 3),\n      # Tests layer 1 permutation matrix mask, having only 1 permutation.\n      ({\n          'MaskedModule_0': {\n              'kernel': np.array(((1, 0, 0), (0, 1, 0), (0, 0, 1))).T,\n          },\n          'MaskedModule_1': {\n              'kernel':\n                  np.array(((1, 1, 1), (0, 1, 1), (1, 1, 1), (1, 1, 1))).T,\n          },\n      }, 3, 1, 0, 3),\n      )\n  def test_count_permutations_mask_layer_twolayer(self, mask, unique,\n                                                  permutations, zeroed, total):\n    \"\"\"Tests mask permutations when both layers are considered.\"\"\"\n    stats = symmetry.count_permutations_mask_layer(mask['MaskedModule_0'],\n                                                   mask['MaskedModule_1'])\n\n    with self.subTest(name='count_permutations_unique'):\n      self.assertEqual(stats['unique_neurons'], unique)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], permutations)\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], zeroed)\n\n    with self.subTest(name='count_permutations_total'):\n   
   self.assertEqual(stats['total_neurons'], total)\n\n  def test_count_permutations_mask_full(self):\n    \"\"\"Tests count of weight permutations in a full mask.\"\"\"\n    mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 1)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'],\n                       math.factorial(MaskedDense.NUM_FEATURES))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)\n\n  def test_count_permutations_mask_bn_layer_full(self):\n    \"\"\"Tests count of permutations on a mask for model with non-masked layers.\"\"\"\n    mask = masked.simple_mask(self._masked_model, jnp.ones, ['kernel'])\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 1)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'],\n                       math.factorial(MaskedDense.NUM_FEATURES))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)\n\n  def test_count_permutations_mask_empty(self):\n    \"\"\"Tests count of weight permutations in an empty mask.\"\"\"\n    mask = masked.simple_mask(self._masked_model, jnp.zeros, ['kernel'])\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      
self.assertEqual(stats['unique_neurons'], 0)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], 0)\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], MaskedDense.NUM_FEATURES)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'], MaskedDense.NUM_FEATURES)\n\n  def test_count_permutations_mask_twolayer_full(self):\n    \"\"\"Tests count of weight permutations in a full mask for 2 layers.\"\"\"\n    mask = masked.simple_mask(self._masked_two_layer_model, jnp.ones,\n                              ['kernel'])\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 2)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(\n          stats['permutations'],\n          functools.reduce(\n              operator.mul,\n              [math.factorial(x) for x in MaskedTwoLayerDense.NUM_FEATURES]))\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 0)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'],\n                       sum(MaskedTwoLayerDense.NUM_FEATURES))\n\n  def test_count_permutations_mask_twolayers_empty(self):\n    \"\"\"Tests count of weight permutations in an empty mask for 2 layers.\"\"\"\n    mask = masked.simple_mask(self._masked_two_layer_model, jnp.zeros,\n                              ['kernel'])\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 0)\n\n    with self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], 0)\n\n    with 
self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'],\n                       sum(MaskedTwoLayerDense.NUM_FEATURES))\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(stats['total_neurons'],\n                       sum(MaskedTwoLayerDense.NUM_FEATURES))\n\n  def test_count_permutations_mask_twolayer_known_symmetric(self):\n    \"\"\"Tests count of permutations in a known mask with 4 permutations.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T\n        },\n        'MaskedModule_1': {\n            'kernel': jnp.array(((1, 1, 0), (0, 0, 1), (0, 0, 1), (0, 0, 0))).T\n        }\n    }\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_full_mask_unique'):\n      self.assertEqual(stats['unique_neurons'], 4)\n\n    with self.subTest(name='count_permutations_full_mask_permutations'):\n      self.assertEqual(stats['permutations'], 4)\n\n    with self.subTest(name='count_permutations_full_mask_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 1)\n\n    with self.subTest(name='count_permutations_full_mask_total'):\n      self.assertEqual(\n          stats['total_neurons'], mask['MaskedModule_0']['kernel'].shape[-1] +\n          mask['MaskedModule_1']['kernel'].shape[-1])\n\n  def test_count_permutations_mask_twolayer_known_non_symmetric(self):\n    \"\"\"Tests mask with 1 permutation only if both layers are considered.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T\n        },\n        'MaskedModule_1': {\n            'kernel': jnp.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T\n        }\n    }\n\n    stats = symmetry.count_permutations_mask(mask)\n\n    with self.subTest(name='count_permutations_unique'):\n      self.assertEqual(stats['unique_neurons'], 6)\n\n    with 
self.subTest(name='count_permutations_permutations'):\n      self.assertEqual(stats['permutations'], 1)\n\n    with self.subTest(name='count_permutations_zeroed'):\n      self.assertEqual(stats['zeroed_neurons'], 1)\n\n    with self.subTest(name='count_permutations_total'):\n      self.assertEqual(\n          stats['total_neurons'], mask['MaskedModule_0']['kernel'].shape[-1] +\n          mask['MaskedModule_1']['kernel'].shape[-1])\n\n  def test_get_mask_stats_keys_values(self):\n    \"\"\"Tests the returned dict has the required keys, and value types/ranges.\"\"\"\n    mask = {\n        'MaskedModule_0': {\n            'kernel': jnp.array(((1, 0), (1, 0), (0, 1))).T\n        },\n        'MaskedModule_1': {\n            'kernel': jnp.array(((1, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, 0))).T\n        }\n    }\n\n    mask_stats = symmetry.get_mask_stats(mask)\n\n    with self.subTest(name='sparsity_exists'):\n      self.assertIn('sparsity', mask_stats)\n\n    with self.subTest(name='sparsity_value'):\n      self.assertBetween(mask_stats['sparsity'], 0.0, 1.0)\n\n    with self.subTest(name='permutation_num_digits_exists'):\n      self.assertIn('permutation_num_digits', mask_stats)\n\n    with self.subTest(name='permutation_num_digits_value'):\n      self.assertGreaterEqual(mask_stats['permutation_num_digits'], 0.0)\n\n    with self.subTest(name='permutation_log10_exists'):\n      self.assertIn('permutation_log10', mask_stats)\n\n    with self.subTest(name='permutation_log10_value'):\n      self.assertGreaterEqual(mask_stats['permutation_log10'], 0.0)\n\n    with self.subTest(name='unique_neurons_exists'):\n      self.assertIn('unique_neurons', mask_stats)\n\n    with self.subTest(name='unique_neurons_value'):\n      self.assertEqual(mask_stats['unique_neurons'], 6)\n\n    with self.subTest(name='permutations_exists'):\n      self.assertIn('permutations', mask_stats)\n\n    with self.subTest(name='permutations_value'):\n      self.assertEqual(mask_stats['permutations'], 
1)\n\n    with self.subTest(name='zeroed_neurons_exists'):\n      self.assertIn('zeroed_neurons', mask_stats)\n\n    with self.subTest(name='zeroed_neurons_value'):\n      self.assertEqual(mask_stats['zeroed_neurons'], 1)\n\n    with self.subTest(name='total_neurons_exists'):\n      self.assertIn('total_neurons', mask_stats)\n\n    with self.subTest(name='total_neurons_value'):\n      self.assertEqual(mask_stats['total_neurons'],\n                       mask['MaskedModule_0']['kernel'].shape[-1] +\n                       mask['MaskedModule_1']['kernel'].shape[-1])\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/random_mask.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Weight Symmetry: Train model with randomly sampled sparse mask.\"\"\"\nimport ast\nfrom os import path\nfrom typing import List, Sequence\nimport uuid\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nimport flax\nfrom flax.metrics import tensorboard\nfrom flax.training import lr_schedule\nimport jax\nimport jax.numpy as jnp\nfrom rigl.experimental.jax.datasets import dataset_factory\nfrom rigl.experimental.jax.models import model_factory\nfrom rigl.experimental.jax.pruning import mask_factory\nfrom rigl.experimental.jax.pruning import masked\nfrom rigl.experimental.jax.pruning import symmetry\nfrom rigl.experimental.jax.training import training\nfrom rigl.experimental.jax.utils import utils\n  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))\n\n  logging.info('Saving experimental results to %s', experiment_dir)\n\n  host_count = jax.host_count()\n  local_device_count = jax.local_device_count()\n  logging.info('Device count: %d, host count: %d, local device count: %d',\n               jax.device_count(), host_count, local_device_count)\n\n  if jax.host_id() == 0:\n    summary_writer = tensorboard.SummaryWriter(experiment_dir)\n\n  dataset = dataset_factory.create_dataset(\n      FLAGS.dataset,\n      FLAGS.batch_size,\n      FLAGS.batch_size_test,\n      
shuffle_buffer_size=FLAGS.shuffle_buffer_size)\n\n  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)\n\n  rng = jax.random.PRNGKey(FLAGS.random_seed)\n\n  input_shape = (1,) + dataset.shape\n  base_model, _ = model_factory.create_model(\n      FLAGS.model,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes,\n      masked_layer_indices=FLAGS.masked_layer_indices)\n\n  logging.info('Generating random mask based on model')\n\n  # Re-initialize the RNG to maintain same training pattern (as in prune code).\n  mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed)\n\n  mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng,\n                                  FLAGS.mask_sparsity)\n\n  if jax.host_id() == 0:\n    mask_stats = symmetry.get_mask_stats(mask)\n    logging.info('Mask stats: %s', str(mask_stats))\n\n\n    for label, value in mask_stats.items():\n      try:\n        summary_writer.scalar(f'mask/{label}', value, 0)\n      # This is needed because permutations (long int) can't be cast to float32.\n      except (OverflowError, ValueError):\n        summary_writer.text(f'mask/{label}', str(value), 0)\n        logging.error('Could not write mask/%s to tensorflow summary as float32'\n                      ', writing as string instead.', label)\n\n    if FLAGS.dump_json:\n      mask_stats['permutations'] = str(mask_stats['permutations'])\n      utils.dump_dict_json(\n          mask_stats, path.join(experiment_dir, 'mask_stats.json'))\n\n  mask = masked.propagate_masks(mask)\n\n  if jax.host_id() == 0:\n    mask_stats = symmetry.get_mask_stats(mask)\n    logging.info('Propagated mask stats: %s', str(mask_stats))\n\n\n    for label, value in mask_stats.items():\n      try:\n        summary_writer.scalar(f'propagated_mask/{label}', value, 0)\n      # This is needed because permutations (long int) can't be cast to float32.\n      except (OverflowError, ValueError):\n        
summary_writer.text(f'propagated_mask/{label}', str(value), 0)\n        logging.error('Could not write propagated_mask/%s to tensorflow summary'\n                      ' as float32, writing as string instead.', label)\n\n    if FLAGS.dump_json:\n      mask_stats['permutations'] = str(mask_stats['permutations'])\n      utils.dump_dict_json(\n          mask_stats, path.join(experiment_dir, 'propagated_mask_stats.json'))\n\n  model, initial_state = model_factory.create_model(\n      FLAGS.model,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes,\n      masks=mask)\n\n  if FLAGS.optimizer == 'Adam':\n    optimizer = flax.optim.Adam(\n        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)\n  elif FLAGS.optimizer == 'Momentum':\n    optimizer = flax.optim.Momentum(\n        learning_rate=FLAGS.lr,\n        beta=FLAGS.momentum,\n        weight_decay=FLAGS.weight_decay,\n        nesterov=False)\n  else:\n    raise ValueError(f'Unknown optimizer type {FLAGS.optimizer}')\n\n  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size\n\n  if FLAGS.lr_schedule == 'constant':\n    lr_fn = lr_schedule.create_constant_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch)\n  elif FLAGS.lr_schedule == 'stepped':\n    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)\n    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, lr_schedule_steps)\n  elif FLAGS.lr_schedule == 'cosine':\n    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, FLAGS.epochs)\n  else:\n    raise ValueError(f'Unknown LR schedule type {FLAGS.lr_schedule}')\n\n  if jax.host_id() == 0:\n    trainer = training.Trainer(\n        optimizer,\n        model,\n        initial_state,\n        dataset,\n        rng,\n        summary_writer=summary_writer,\n    )\n  else:\n    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)\n\n  _, best_metrics = trainer.train(\n      FLAGS.epochs,\n      lr_fn=lr_fn,\n      
update_iter=FLAGS.update_iterations,\n      update_epoch=FLAGS.update_epoch,\n  )\n\n  logging.info('Best metrics: %s', str(best_metrics))\n\n  if jax.host_id() == 0:\n    if FLAGS.dump_json:\n      utils.dump_dict_json(best_metrics,\n                           path.join(experiment_dir, 'best_metrics.json'))\n\n    for label, value in best_metrics.items():\n      summary_writer.scalar(f'best/{label}', value,\n                            FLAGS.epochs * steps_per_epoch)\n    summary_writer.close()\n\n\ndef main(argv: List[str]):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n  run_training()\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "rigl/experimental/jax/random_mask_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.random_mask.\"\"\"\nimport glob\nfrom os import path\nimport tempfile\n\nfrom absl.testing import absltest\nfrom absl.testing import flagsaver\nfrom rigl.experimental.jax import random_mask\n\n\nclass RandomMaskTest(absltest.TestCase):\n\n  def test_run_fc(self):\n    \"\"\"Test random mask driver with fully-connected model.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    self._eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        model='MNIST_FC',\n    )\n\n    with flagsaver.flagsaver(**self._eval_flags):\n      random_mask.main([])\n\n    with self.subTest(name='tf_summary_file_exists'):\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_conv(self):\n    \"\"\"Test random mask driver with CNN model.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    self._eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        model='MNIST_CNN',\n    )\n\n    with flagsaver.flagsaver(**self._eval_flags):\n      random_mask.main([])\n\n    with self.subTest(name='tf_summary_file_exists'):\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      
self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_random(self):\n    \"\"\"Test random mask driver with random sparsity.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    self._eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        mask_type='random',\n    )\n\n    with flagsaver.flagsaver(**self._eval_flags):\n      random_mask.main([])\n\n    with self.subTest(name='tf_summary_file_exists'):\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_per_neuron(self):\n    \"\"\"Test random mask driver with per-neuron sparsity.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    self._eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        mask_type='per_neuron',\n    )\n\n    with flagsaver.flagsaver(**self._eval_flags):\n      random_mask.main([])\n\n    with self.subTest(name='tf_summary_file_exists'):\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_symmetric(self):\n    \"\"\"Test random mask driver with symmetric sparsity.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    self._eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        mask_type='symmetric',\n    )\n\n    with flagsaver.flagsaver(**self._eval_flags):\n      random_mask.main([])\n\n    with self.subTest(name='tf_summary_file_exists'):\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/requirements.txt",
    "content": "absl-py>=0.10.0\nflax>=0.2.2\njax>=0.2.0\njaxlib>=0.1.55\ntensorboard>=2.3.0\ntensorflow>=2.3.1\ntensorflow_datasets>=3.2.1\n"
  },
  {
    "path": "rigl/experimental/jax/run.sh",
    "content": "#!/bin/bash\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nset -e\nset -x\n\nvirtualenv -p python3 .\nsource ./bin/activate\n\npip install -r weight_symmetry/requirements.txt\n\nTEST_NAMES='training.training_test\ntrain_test\nfixed_param_test\nshuffled_mask_test\nmodels.model_factory_test\nmodels.cifar10_cnn_test\nmodels.mnist_cnn_test\nmodels.mnist_fc_test\nutils.utils_test\nprune_test\nrandom_mask_test\npruning.mask_factory_test\npruning.init_test\npruning.symmetry_test\npruning.pruning_test\npruning.masked_test\ndatasets.dataset_factory_test\ndatasets.dataset_base_test\ndatasets.cifar10_test\ndatasets.mnist_test'\n\nIFS=$'\\n' readarray -t tests <<<\"$TEST_NAMES\"\n\nfor test in \"${tests[@]}\"; do\n  python3 -m \"weight_symmetry.${test}\"\ndone\n"
  },
  {
    "path": "rigl/experimental/jax/shuffled_mask.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Weight Symmetry: Train model with randomly shuffled sparse mask.\"\"\"\n# TODO: Refactor drivers to separate logic from flags/IO.\nimport ast\nfrom os import path\nfrom typing import List, Sequence\nimport uuid\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nimport flax\nfrom flax.metrics import tensorboard\nfrom flax.training import lr_schedule\nimport jax\nimport jax.numpy as jnp\nfrom rigl.experimental.jax.datasets import dataset_factory\nfrom rigl.experimental.jax.models import model_factory\nfrom rigl.experimental.jax.pruning import mask_factory\nfrom rigl.experimental.jax.pruning import masked\nfrom rigl.experimental.jax.pruning import symmetry\nfrom rigl.experimental.jax.training import training\nfrom rigl.experimental.jax.utils import utils\n\n\nFLAGS = flags.FLAGS\n\n\ndef run_training():\n  \"\"\"Trains a model with a randomly shuffled sparse mask.\"\"\"\n  work_unit_id = uuid.uuid4()\n  experiment_dir = '{}/{}/'.format(FLAGS.experiment_dir, work_unit_id)\n\n  logging.info('Saving experimental results to %s', experiment_dir)\n\n  host_count = jax.host_count()\n  local_device_count = jax.local_device_count()\n  logging.info('Device count: %d, host count: %d, local device count: %d',\n               jax.device_count(), host_count, local_device_count)\n\n  if jax.host_id() == 0:\n    summary_writer = tensorboard.SummaryWriter(experiment_dir)\n\n  dataset = dataset_factory.create_dataset(\n      FLAGS.dataset,\n      FLAGS.batch_size,\n      
FLAGS.batch_size_test,\n      shuffle_buffer_size=FLAGS.shuffle_buffer_size)\n\n  logging.info('Training %s on the %s dataset...', FLAGS.model, FLAGS.dataset)\n\n  rng = jax.random.PRNGKey(FLAGS.random_seed)\n\n  input_shape = (1,) + dataset.shape\n  base_model, _ = model_factory.create_model(\n      FLAGS.model,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes)\n\n  logging.info('Generating random mask based on model')\n\n  # Re-initialize the RNG to maintain same training pattern (as in prune code).\n  mask_rng = jax.random.PRNGKey(FLAGS.mask_randomseed)\n\n  mask = mask_factory.create_mask(FLAGS.mask_type, base_model, mask_rng,\n                                  FLAGS.mask_sparsity)\n\n  if jax.host_id() == 0:\n    mask_stats = symmetry.get_mask_stats(mask)\n    logging.info('Mask stats: %s', str(mask_stats))\n\n\n    for label, value in mask_stats.items():\n      try:\n        summary_writer.scalar(f'mask/{label}', value, 0)\n      # This is needed because permutations (long int) can't be cast to float32.\n      except (OverflowError, ValueError):\n        summary_writer.text(f'mask/{label}', str(value), 0)\n        logging.error('Could not write mask/%s to tensorflow summary as float32'\n                      ', writing as string instead.', label)\n\n    if FLAGS.dump_json:\n      mask_stats['permutations'] = str(mask_stats['permutations'])\n      utils.dump_dict_json(\n          mask_stats, path.join(experiment_dir, 'mask_stats.json'))\n\n  mask = masked.propagate_masks(mask)\n\n  if jax.host_id() == 0:\n    mask_stats = symmetry.get_mask_stats(mask)\n    logging.info('Propagated mask stats: %s', str(mask_stats))\n\n\n    for label, value in mask_stats.items():\n      try:\n        summary_writer.scalar(f'propagated_mask/{label}', value, 0)\n      # This is needed because permutations (long int) can't be cast to float32.\n      except (OverflowError, ValueError):\n        summary_writer.text(f'propagated_mask/{label}', 
str(value), 0)\n        logging.error('Could not write propagated_mask/%s to tensorflow summary'\n                      ' as float32, writing as string instead.', label)\n\n    if FLAGS.dump_json:\n      mask_stats['permutations'] = str(mask_stats['permutations'])\n      utils.dump_dict_json(\n          mask_stats, path.join(experiment_dir, 'propagated_mask_stats.json'))\n\n  model, initial_state = model_factory.create_model(\n      FLAGS.model,\n      rng, ((input_shape, jnp.float32),),\n      num_classes=dataset.num_classes,\n      masks=mask)\n\n  if FLAGS.optimizer == 'Adam':\n    optimizer = flax.optim.Adam(\n        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)\n  elif FLAGS.optimizer == 'Momentum':\n    optimizer = flax.optim.Momentum(\n        learning_rate=FLAGS.lr,\n        beta=FLAGS.momentum,\n        weight_decay=FLAGS.weight_decay,\n        nesterov=False)\n  else:\n    raise ValueError('Unknown optimizer type {}'.format(FLAGS.optimizer))\n\n  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size\n\n  if FLAGS.lr_schedule == 'constant':\n    lr_fn = lr_schedule.create_constant_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch)\n  elif FLAGS.lr_schedule == 'stepped':\n    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)\n    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, lr_schedule_steps)\n  elif FLAGS.lr_schedule == 'cosine':\n    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, FLAGS.epochs)\n  else:\n    raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))\n\n  if jax.host_id() == 0:\n    trainer = training.Trainer(\n        optimizer,\n        model,\n        initial_state,\n        dataset,\n        rng,\n        summary_writer=summary_writer,\n    )\n  else:\n    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)\n\n  _, best_metrics = trainer.train(\n      FLAGS.epochs,\n      lr_fn=lr_fn,\n      update_iter=FLAGS.update_iterations,\n      
update_epoch=FLAGS.update_epoch,\n  )\n\n  logging.info('Best metrics: %s', str(best_metrics))\n\n  if jax.host_id() == 0:\n    if FLAGS.dump_json:\n      utils.dump_dict_json(best_metrics,\n                           path.join(experiment_dir, 'best_metrics.json'))\n\n    for label, value in best_metrics.items():\n      summary_writer.scalar('best/{}'.format(label), value,\n                            FLAGS.epochs * steps_per_epoch)\n    summary_writer.close()\n\n\ndef main(argv: List[str]):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n  run_training()\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "rigl/experimental/jax/shuffled_mask_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.shuffled_mask.\"\"\"\nimport glob\nfrom os import path\nimport tempfile\n\nfrom absl.testing import absltest\nfrom absl.testing import flagsaver\nfrom rigl.experimental.jax import shuffled_mask\n\n\nclass ShuffledMaskTest(absltest.TestCase):\n\n  def test_run_fc(self):\n    \"\"\"Tests if the driver for shuffled training runs correctly with FC NN.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        model='MNIST_FC',\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      shuffled_mask.main([])\n\n    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n    files = glob.glob(outfile)\n\n    self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_conv(self):\n    \"\"\"Tests if the driver for shuffled training runs correctly with CNN.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        model='MNIST_CNN',\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      shuffled_mask.main([])\n\n    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n    files = glob.glob(outfile)\n\n    self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_random(self):\n    \"\"\"Tests if the shuffled mask driver runs with random sparsity.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        mask_type='random',\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      shuffled_mask.main([])\n\n    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n    files = glob.glob(outfile)\n\n    self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_per_neuron(self):\n    \"\"\"Tests if the shuffled mask driver runs with per-neuron sparsity.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        mask_type='per_neuron',\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      shuffled_mask.main([])\n\n    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n    files = glob.glob(outfile)\n\n    self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\n  def test_run_symmetric(self):\n    \"\"\"Tests if the shuffled mask driver runs with symmetric sparsity.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n        mask_type='symmetric',\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      shuffled_mask.main([])\n\n    outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n    files = glob.glob(outfile)\n\n    self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/train.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Weight Symmetry: Train Model.\n\nTrains a model from scratch, saving the relevant early weight snapshots.\n\"\"\"\nimport ast\nfrom os import path\nfrom typing import List, Sequence\nimport uuid\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nimport flax\nfrom flax.metrics import tensorboard\nfrom flax.training import lr_schedule\nimport jax\nimport jax.numpy as np\nfrom rigl.experimental.jax.datasets import dataset_factory\nfrom rigl.experimental.jax.models import model_factory\nfrom rigl.experimental.jax.training import training\n\n\nFLAGS = flags.FLAGS\n\nMODEL_LIST: Sequence[str] = tuple(model_factory.MODELS.keys())\nDATASET_LIST: Sequence[str] = tuple(dataset_factory.DATASETS.keys())\n\nflags.DEFINE_enum('model', MODEL_LIST[0], MODEL_LIST,\n                  'Model to train.')\nflags.DEFINE_enum('dataset', DATASET_LIST[0], DATASET_LIST,\n                  'Dataset to train on.')\nflags.DEFINE_enum('optimizer', 'Adam', ['Momentum', 'Adam'],\n                  'Optimizer to use.')\nflags.DEFINE_float('learning_rate', 0.01, 'Learning rate.', short_name='lr')\nflags.DEFINE_float('weight_decay', 1e-5, 'Weight decay penalty.',\n                   short_name='wd')\nflags.DEFINE_float('momentum', 0.9, 'Momentum weighting.')\nflags.DEFINE_string(\n    'lr_schedule', default='stepped',\n    help=('Learning rate schedule 
type; constant, stepped or cosine.'))\nflags.DEFINE_string(\n    'lr_schedule_steps', default='[[50, 0.01], [70, 0.001], [90, 0.0001]]',\n    help=('Learning rate schedule steps as a Python list; '\n          '[[step1_epoch, step1_lr_scale], '\n          '[step2_epoch, step2_lr_scale], ...]'))\nflags.DEFINE_integer(\n    'batch_size', 128, 'Training minibatch size.', lower_bound=1)\nflags.DEFINE_integer(\n    'batch_size_test',\n    128,\n    'Test minibatch size.',\n    lower_bound=1)\nflags.DEFINE_integer(\n    'epochs', 100, 'Number of epochs to train over.', lower_bound=1)\nflags.DEFINE_integer('random_seed', 42, 'Random seed.')\nflags.DEFINE_integer('shuffle_buffer_size', 1024,\n                     'Dataset shuffle buffer size.')\nflags.DEFINE_string(\n    'experiment_dir', '/tmp/experiments',\n    'Path to store experiment output in, i.e. models, snapshots.')\nflags.DEFINE_integer(\n    'update_iterations',\n    1000,\n    'Iteration interval after which to log per-batch metrics.',\n    lower_bound=1)\nflags.DEFINE_integer(\n    'update_epoch', 10, 'Epoch interval after which to evaluate test error.',\n    lower_bound=1)\n\n\ndef run_training():\n  \"\"\"Trains a model.\"\"\"\n  print('Logging to {}'.format(FLAGS.log_dir))\n  work_unit_id = uuid.uuid4()\n  experiment_dir = path.join(FLAGS.experiment_dir, str(work_unit_id))\n\n  logging.info('Saving experimental results to %s', experiment_dir)\n\n  host_count = jax.host_count()\n  local_device_count = jax.local_device_count()\n  logging.info('Device count: %d, host count: %d, local device count: %d',\n               jax.device_count(), host_count, local_device_count)\n\n  if jax.host_id() == 0:\n    summary_writer = tensorboard.SummaryWriter(experiment_dir)\n\n  dataset = dataset_factory.create_dataset(\n      FLAGS.dataset,\n      FLAGS.batch_size,\n      FLAGS.batch_size_test,\n      shuffle_buffer_size=FLAGS.shuffle_buffer_size)\n\n  logging.info('Training %s on the %s dataset...', FLAGS.model, 
FLAGS.dataset)\n\n  rng = jax.random.PRNGKey(FLAGS.random_seed)\n\n  input_shape = (1,) + dataset.shape\n  model, initial_state = model_factory.create_model(\n      FLAGS.model,\n      rng, ((input_shape, np.float32),),\n      num_classes=dataset.num_classes)\n\n  if FLAGS.optimizer == 'Adam':\n    optimizer = flax.optim.Adam(\n        learning_rate=FLAGS.lr, weight_decay=FLAGS.weight_decay)\n  elif FLAGS.optimizer == 'Momentum':\n    optimizer = flax.optim.Momentum(\n        learning_rate=FLAGS.lr,\n        beta=FLAGS.momentum,\n        weight_decay=FLAGS.weight_decay,\n        nesterov=False)\n\n  steps_per_epoch = dataset.get_train_len() // FLAGS.batch_size\n\n  if FLAGS.lr_schedule == 'constant':\n    lr_fn = lr_schedule.create_constant_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch)\n  elif FLAGS.lr_schedule == 'stepped':\n    lr_schedule_steps = ast.literal_eval(FLAGS.lr_schedule_steps)\n    lr_fn = lr_schedule.create_stepped_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, lr_schedule_steps)\n  elif FLAGS.lr_schedule == 'cosine':\n    lr_fn = lr_schedule.create_cosine_learning_rate_schedule(\n        FLAGS.lr, steps_per_epoch, FLAGS.epochs)\n  else:\n    raise ValueError('Unknown LR schedule type {}'.format(FLAGS.lr_schedule))\n\n  if jax.host_id() == 0:\n    trainer = training.Trainer(\n        optimizer,\n        model,\n        initial_state,\n        dataset,\n        rng,\n        summary_writer=summary_writer,\n    )\n  else:\n    trainer = training.Trainer(optimizer, model, initial_state, dataset, rng)\n\n  _, best_metrics = trainer.train(\n      FLAGS.epochs,\n      lr_fn=lr_fn,\n      update_iter=FLAGS.update_iterations,\n      update_epoch=FLAGS.update_epoch)\n\n  logging.info('Best metrics: %s', str(best_metrics))\n\n  if jax.host_id() == 0:\n    for label, value in best_metrics.items():\n      summary_writer.scalar('best/{}'.format(label), value,\n                            FLAGS.epochs * steps_per_epoch)\n    
summary_writer.close()\n\n\ndef main(argv):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n  run_training()\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "rigl/experimental/jax/train_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.train.\"\"\"\nimport glob\nfrom os import path\nimport tempfile\n\nfrom absl.testing import absltest\nfrom absl.testing import flagsaver\nfrom rigl.experimental.jax import train\n\n\nclass TrainTest(absltest.TestCase):\n\n  def test_train_driver_run(self):\n    \"\"\"Tests that the training driver runs, and outputs a TF summary.\"\"\"\n    experiment_dir = tempfile.mkdtemp()\n    eval_flags = dict(\n        epochs=1,\n        experiment_dir=experiment_dir,\n    )\n\n    with flagsaver.flagsaver(**eval_flags):\n      train.main([])\n\n    with self.subTest(name='tf_summary_file_exists'):\n      outfile = path.join(experiment_dir, '*', 'events.out.tfevents.*')\n      files = glob.glob(outfile)\n\n      self.assertTrue(len(files) == 1 and path.exists(files[0]))\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/training/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "rigl/experimental/jax/training/training.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Common training code.\n\nThis module contains utility functions for training NN.\n\nAttributes:\n  LABELKEY: The key used to retrieve a label from the batch dictionary.\n  DATAKEY: The key used to retrieve an input image from the batch dictionary.\n  PruningRateFnType: Typing alias for a valid pruning rate function.\n\"\"\"\nfrom collections import abc\nimport functools\nimport time\nfrom typing import Callable, Dict, Mapping, Optional, Tuple, Union\n\nfrom absl import logging\nimport flax\nfrom flax import jax_utils\nfrom flax.training import common_utils\nimport jax\nimport jax.numpy as jnp\nfrom rigl.experimental.jax.datasets import dataset_base\nfrom rigl.experimental.jax.models import model_factory\nfrom rigl.experimental.jax.pruning import masked\nfrom rigl.experimental.jax.pruning import pruning\nfrom rigl.experimental.jax.pruning import symmetry\nfrom rigl.experimental.jax.utils import utils\nimport tensorflow.compat.v2 as tf\n\nLABELKEY = dataset_base.ImageDataset.LABELKEY\nDATAKEY = dataset_base.ImageDataset.DATAKEY\n\nPruningRateFnType = Union[Mapping[str, Callable[[int], float]], Callable[[int],\n                                                                         float]]\n\n\ndef _shard_batch(xs):\n  \"\"\"Shards a batch for a pmap, based on the number of devices.\"\"\"\n  local_device_count = jax.local_device_count()\n\n  def 
_prepare(x):\n    return x.reshape((local_device_count, -1) + x.shape[1:])\n\n  return jax.tree_map(_prepare, xs)\n\n\ndef train_step(\n    optimizer: flax.optim.Optimizer, batch: Mapping[str, jnp.array],  # pytype: disable=module-attr\n    rng: Callable[[int], jnp.array], state: flax.deprecated.nn.Collection,\n    learning_rate_fn: Callable[[int], float]\n) -> Tuple[flax.optim.Optimizer, flax.deprecated.nn.Collection, float, float]:  # pytype: disable=module-attr\n  \"\"\"Performs training for one minibatch.\n\n  Args:\n    optimizer: Optimizer to use.\n    batch: Minibatch to train with.\n    rng: Random number generator, i.e. jax.random.PRNGKey, to use for model\n      training, e.g. dropout.\n    state: Model state.\n    learning_rate_fn: A function that returns the learning rate given the step.\n\n  Returns:\n    A tuple consisting of the new optimizer, new state, mini-batch loss, and\n    gradient norm.\n  \"\"\"\n\n  def loss_fn(\n      model: flax.deprecated.nn.Model\n  ) -> Tuple[float, Tuple[flax.deprecated.nn.Collection, jnp.array]]:\n    \"\"\"Evaluates the loss function.\n\n    Args:\n      model: The model with which to evaluate the loss.\n\n    Returns:\n      Tuple of the loss for the mini-batch, and model state.\n    \"\"\"\n    with flax.deprecated.nn.stateful(state) as new_state:\n      with flax.deprecated.nn.stochastic(rng):\n        logits = model(batch[DATAKEY])\n    loss = utils.cross_entropy_loss(logits, batch[LABELKEY])\n    return loss, new_state\n\n  lr = learning_rate_fn(optimizer.state.step)\n  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n  (loss, new_state), grad = grad_fn(optimizer.target)\n  grad = jax.lax.pmean(grad, 'batch')\n\n  new_opt = optimizer.apply_gradient(grad, learning_rate=lr)\n\n  grad_norm = jnp.linalg.norm(utils.param_as_array(grad))\n\n  return new_opt, new_state, loss, grad_norm\n\n\nclass Trainer:\n  \"\"\"Training class with the state and methods for training a neural network.\n\n  Attributes:\n    
optimizer: Optimizer used for training, None if training hasn't begun.\n    state: Model state used for training.\n  \"\"\"\n\n  def __init__(\n      self,\n      optimizer_def: flax.optim.OptimizerDef,  # pytype: disable=module-attr\n      initial_model: flax.deprecated.nn.Model,\n      initial_state: flax.deprecated.nn.Collection,\n      dataset: jnp.array,\n      rng: Callable[[int], jnp.array] = None,\n      summary_writer: Optional[tf.summary.SummaryWriter] = None,\n  ):\n    \"\"\"Creates a Trainer object.\n\n    Args:\n      optimizer_def: The flax optimizer def (i.e. not instantiated with a model\n        using .create) to use for training.\n      initial_model: The initial model to train.\n      initial_state: The initial state of the model.\n      dataset: The training dataset.\n      rng: Random number generator, i.e. jax.random.PRNGKey, to use for model\n        training, e.g. dropout.\n      summary_writer: An optional tensorboard summary writer for logging.\n    \"\"\"\n    self._optimizer_def = optimizer_def\n    self._initial_model = initial_model\n    self.state = initial_state\n    self._dataset = dataset\n    self._summary_writer = summary_writer\n    self.optimizer = None\n    self._rng = rng\n\n    if self._rng is None:\n      self._rng = jax.random.PRNGKey(42)\n\n  def _update_optimizer(self, model: flax.deprecated.nn.Model):\n    \"\"\"Updates the optimizer based on the given model.\"\"\"\n    self.optimizer = jax_utils.replicate(\n        self._optimizer_def.create(model))\n\n  def train(\n      self,\n      num_epochs: int,\n      lr_fn: Optional[Callable[[int], float]] = None,\n      pruning_rate_fn: Optional[PruningRateFnType] = None,\n      update_iter: int = 100,\n      update_epoch: int = 10\n  ) -> Tuple[flax.deprecated.nn.Model, Mapping[str, Union[int, float, Mapping[\n      str, float]]]]:\n    \"\"\"Trains the model over the given number of epochs.\n\n    Args:\n      num_epochs: The total number of epochs to train over.\n      lr_fn: The learning rate function, takes the current iteration/step as an\n        argument and returns the current learning rate, uses constant learning\n        rate if no function is provided.\n      pruning_rate_fn: The 
pruning rate function, takes the current epoch as an\n        argument and returns the current pruning rate, no further pruning is\n        performed during training if not set. Can be a dictionary, containing\n        the pruning rate schedule functions for each layer, or a single function\n        for all layers.\n      update_iter: Period of iterations in which to log/update per-batch\n        metrics.\n      update_epoch: Period of epochs in which to log/update full training/test\n        metrics.\n\n    Returns:\n      Tuple consisting of the best model found during training, and metrics.\n\n    Raises:\n      ValueError: If the batch size of the data set is not evenly divisible by\n                  number of devices, or the model batch size is not the training\n                  data batch size/number of jax devices.\n    \"\"\"\n    best_test_acc = 0\n    best_train_loss = None\n    best_iter = None\n\n    if lr_fn is None:\n      lr_fn = lambda _: self.optimizer.optimizer_def.hyper_params.learning_rate\n\n    host_count = jax.host_count()\n    device_count = jax.device_count()\n    local_device_count = jax.local_device_count()\n    logging.info('JAX hosts %d, devices: %d, local devices: %d', host_count,\n                 device_count, local_device_count)\n\n    # TODO: Implement multi-host training.\n    if host_count > 1:\n      raise NotImplementedError('Multi-host training is not supported yet, '\n                                'see b/155550457.')\n\n    if self._dataset.batch_size % device_count > 0:\n      raise ValueError(\n          'Train batch size ({}) must be divisible by number of devices '\n          '({})'.format(self._dataset.batch_size, device_count))\n\n    if self._dataset.batch_size_test % device_count > 0:\n      raise ValueError(\n          'Test batch size ({}) must be divisible by number of devices '\n          '({})'.format(self._dataset.batch_size_test, device_count))\n\n    # Required to use state and 
optimizer with jax.pmap.\n    state = jax_utils.replicate(self.state)\n    self._update_optimizer(self._initial_model)\n\n    p_train_step = jax.pmap(\n        functools.partial(train_step, learning_rate_fn=lr_fn),\n        axis_name='batch')\n\n    # Function to sync the batch statistics across replicas.\n    p_synchronized_batch_stats = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')\n\n    p_cosine_similarity = functools.partial(utils.cosine_similarity_model,\n                                            self._initial_model)\n    p_vector_difference_norm = functools.partial(\n        utils.vector_difference_norm_model, self._initial_model)\n\n    pruning_rate = None\n    mask = None\n\n    cumulative_grad_norm = 0\n\n    start_time = time.time()\n\n    # Main training loop.\n    for epoch in range(num_epochs):\n      if epoch % update_epoch == 0 or epoch == num_epochs - 1:\n        epoch_start_time = time.time()\n\n      # If we get different schedules for different layers.\n      if isinstance(pruning_rate_fn, abc.Mapping):\n        next_pruning_rate = {\n            layer: layer_fn(epoch)\n            for layer, layer_fn in pruning_rate_fn.items()\n        }\n      elif pruning_rate_fn:\n        next_pruning_rate = pruning_rate_fn(epoch)\n\n      # If pruning rate has changed/is first epoch, we need to update mask.\n      # Note: pruning_rate could be zero, so must explicitly check it's None.\n      if pruning_rate_fn and (pruning_rate is None or\n                              pruning_rate != next_pruning_rate):\n\n        pruning_rate = next_pruning_rate\n\n        logging.info('[%d] Pruning Rate: %s', epoch, str(pruning_rate))\n\n        # Unreplicate optimizer/current model, and mask.\n        self.optimizer = jax_utils.unreplicate(self.optimizer)\n        mask = jax_utils.unreplicate(mask) if mask else None\n\n        # Performs pruning to get updated mask.\n        mask = pruning.prune(self.optimizer.target, pruning_rate, mask=mask)\n\n        
logging.info('[%d] Mask Sparsity: %0.3f', epoch,\n                     masked.mask_sparsity(mask))\n\n        for layer, layer_mask in sorted(mask.items()):\n          if layer_mask:\n            logging.info('[%d] Layer: %s, Mask Sparsity: %0.3f', epoch, layer,\n                         masked.mask_layer_sparsity(layer_mask))\n\n        if jax.host_id() == 0:\n          mask_stats = symmetry.get_mask_stats(mask)\n          logging.info('Mask stats: %s', str(mask_stats))\n\n\n          if self._summary_writer:\n            for label, value in mask_stats.items():\n              try:\n                self._summary_writer.scalar(f'mask_{epoch}/{label}', value, 0)\n              # Needed when permutations (long int) can't be cast to float32.\n              except (OverflowError, ValueError):\n                self._summary_writer.text(f'mask_{epoch}/{label}', str(value),\n                                          0)\n                logging.error(\n                    'Could not write mask_%d/%s to tensorflow summary as float32'\n                    ', writing as string instead.', epoch, label)\n\n        # Creates a new optimizer, based on a new model with new mask.\n        self._update_optimizer(\n            model_factory.update_model(self.optimizer.target, masks=mask))\n\n      # Begins epoch.\n      for batch in self._dataset.get_train():\n        # Note: Because of replicate, step has # device identical vals.\n        step = jax_utils.unreplicate(self.optimizer.state.step)\n\n        if step % update_iter == 0:\n          batch_start_time = time.time()\n\n        # These are required for pmap call.\n        self._rng, step_key = jax.random.split(self._rng)\n        batch = _shard_batch(batch)\n        sharded_keys = common_utils.shard_prng_key(step_key)\n\n        (self.optimizer, state, opt_loss,\n         grad_norm) = p_train_step(self.optimizer, batch, sharded_keys, state)\n\n        if state.state:\n          state = p_synchronized_batch_stats(state)\n\n      
  grad_norm = jax_utils.unreplicate(grad_norm)\n\n        cumulative_grad_norm += grad_norm\n\n        # Per-iteration status/metrics update.\n        if jax.host_id() == 0 and step % update_iter == 0:\n          batch_time = time.time() - batch_start_time\n\n          if self._summary_writer is not None:\n            self._summary_writer.scalar('training/train_batch_loss',\n                                        jnp.mean(opt_loss),\n                                        step)\n            self._summary_writer.scalar('training/gradient_norm', grad_norm,\n                                        step)\n          logging.info('[epoch %d] %d, loss %0.5f, lr %0.3f, %0.3f sec', epoch,\n                       step, jnp.mean(opt_loss), lr_fn(step), batch_time)\n\n      # Per-epoch status/metrics update.\n      if (jax.host_id() == 0 and\n          (epoch % update_epoch == 0 or epoch == num_epochs - 1)):\n        epoch_time = time.time() - epoch_start_time\n\n        cosine_distance = p_cosine_similarity(\n            jax_utils.unreplicate(self.optimizer.target))\n        vector_difference_norm = p_vector_difference_norm(\n            jax_utils.unreplicate(self.optimizer.target))\n\n        train_metrics = eval_model(self.optimizer.target, state,\n                                   self._dataset.get_train())\n        test_metrics = eval_model(self.optimizer.target, state,\n                                  self._dataset.get_test())\n\n        train_loss = train_metrics['loss']\n        train_acc = train_metrics['accuracy']\n\n        test_loss = test_metrics['loss']\n        test_acc = test_metrics['accuracy']\n\n        if jax.host_id() == 0:\n          metrics = {\n              'wallclock_time':\n                  float(epoch_time),\n              'train_accuracy':\n                  float(train_acc),\n              'train_avg_loss':\n                  float(train_loss),\n              'test_accuracy':\n                  float(test_acc),\n              
'test_avg_loss':\n                  float(test_loss),\n              'lr':\n                  float(lr_fn(step)),\n              'cosine_distance':\n                  float(cosine_distance),\n              'cumulative_gradient_norm':\n                  float(cumulative_grad_norm),\n              'vector_difference_norm':\n                  float(vector_difference_norm),\n          }\n\n\n          if self._summary_writer is not None:\n            for label, value in metrics.items():\n              self._summary_writer.scalar('training/{}'.format(label), value,\n                                          step)\n\n        if test_acc >= best_test_acc:\n          best_model = self.optimizer.target\n\n          best_test_acc = test_acc\n          best_test_metrics = {\n              'train_avg_loss': float(train_loss),\n              'train_accuracy': float(train_acc),\n              'test_avg_loss': float(test_loss),\n              'test_accuracy': float(test_acc),\n              'step': int(step),\n              'cosine_distance': float(cosine_distance),\n              'cumulative_gradient_norm': float(cumulative_grad_norm),\n              'vector_difference_norm': float(vector_difference_norm),\n          }\n          best_iter = step\n\n        if best_train_loss is None or train_loss <= best_train_loss:\n          best_train_loss = train_loss\n          best_train_metrics = {\n              'train_avg_loss': float(train_loss),\n              'train_accuracy': float(train_acc),\n              'test_avg_loss': float(test_loss),\n              'test_accuracy': float(test_acc),\n              'step': int(step),\n              'cosine_distance': float(cosine_distance),\n              'cumulative_gradient_norm': float(cumulative_grad_norm),\n              'vector_difference_norm': float(vector_difference_norm),\n          }\n\n        log_format_str = (\n            '[epoch %d] train avg. loss %0.4f, train acc. %0.4f, test avg. '\n            'loss %0.4f, test acc. 
%0.4f, %0.4f sec, cosine sim.: %0.3f, cum. '\n            'grad. norm: %0.3f, vector diff: %0.3f')\n        log_vars = [\n            epoch, train_loss, train_acc, test_loss, test_acc, epoch_time,\n            float(cosine_distance),\n            float(cumulative_grad_norm),\n            float(vector_difference_norm)\n        ]\n        logging.info(log_format_str, *log_vars)\n      # End epoch.\n\n\n    training_time = time.time() - start_time\n    logging.info('Training finished, total wallclock time: %0.2f sec',\n                 training_time)\n\n    if jax.host_id() == 0 and self._summary_writer is not None:\n      for label, value in best_test_metrics.items():\n        self._summary_writer.scalar('best_test_acc/{}'.format(label), value,\n                                    best_iter)\n    logging.info('Best Test Accuracy: iteration %d, test acc. %0.5f',\n                 best_test_metrics['step'], best_test_acc)\n\n    if jax.host_id() == 0 and self._summary_writer is not None:\n      for label, value in best_train_metrics.items():\n        self._summary_writer.scalar(\n            'best_train_loss/{}'.format(label),\n            value,\n            step=best_train_metrics['step'])\n    logging.info('Best Train Loss: iteration %d, train loss 
%0.5f',\n                 best_train_metrics['step'], best_train_loss)\n\n    return (best_model, best_test_metrics)\n\n\ndef _eval_step(model: flax.deprecated.nn.Model,\n               state: flax.deprecated.nn.Collection,\n               batch: Mapping[str, jnp.array]) -> Dict[str, jnp.array]:\n  \"\"\"Evaluates a mini-batch of data.\n\n  Args:\n    model: The model to use for evaluation.\n    state: Model state containing state for stateful flax.deprecated.nn\n      functions, such as batch normalization.\n    batch: Mini-batch of data to evaluate on.\n\n  Returns:\n    Dictionary containing the mini-batch loss and accuracy.\n  \"\"\"\n  state = jax.lax.pmean(state, 'batch')\n  with flax.deprecated.nn.stateful(state, mutable=False):\n    logits = model(batch[DATAKEY], train=False)\n  metrics = utils.compute_metrics(logits, batch[LABELKEY])\n  return metrics\n\n\ndef eval_model(model: flax.deprecated.nn.Model,\n               state: flax.deprecated.nn.Collection,\n               eval_dataset: jnp.array) -> Dict[str, float]:\n  \"\"\"Evaluates the given model using the given dataset.\n\n  Args:\n    model: The model to evaluate.\n    state: Model state containing state for stateful flax.deprecated.nn\n      functions, such as batch normalization.\n    eval_dataset: Dataset to evaluate the model over.\n\n  Returns:\n    Dictionary containing the average loss and accuracy of the model on the\n    given dataset.\n  \"\"\"\n  p_eval_step = jax.pmap(_eval_step, axis_name='batch')\n\n  batch_sizes = []\n  metrics = []\n  for batch in eval_dataset:\n    batch_size = len(batch[LABELKEY])\n\n    # These are required for pmap call.\n    batch = _shard_batch(batch)\n    batch_metrics = p_eval_step(model, state, batch)\n\n    batch_sizes.append(batch_size)\n    metrics.append(batch_metrics)\n\n  # Note: use weighted mean, since we do mean of means with potentially\n  # different batch sizes otherwise.\n  batch_sizes = jnp.array(batch_sizes)\n  weights = batch_sizes / 
jnp.sum(batch_sizes)\n  eval_metrics = common_utils.get_metrics(metrics)\n  return jax.tree_map(lambda x: (weights * x).sum(), eval_metrics)\n"
  },
  {
    "path": "rigl/experimental/jax/training/training_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for weight_symmetry.training.training.\"\"\"\nimport functools\nimport math\n\nfrom absl.testing import absltest\nimport flax\nfrom flax import jax_utils\nfrom flax.metrics import tensorboard\nfrom flax.training import common_utils\nimport jax\nimport jax.numpy as jnp\n\nfrom rigl.experimental.jax.datasets import dataset_factory\nfrom rigl.experimental.jax.models import model_factory\nfrom rigl.experimental.jax.training import training\n\n\nclass TrainingTest(absltest.TestCase):\n  \"\"\"Tests functions for training loop and training convenience functions.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n\n    self._batch_size = 128  # Note: Tests are run on GPU/TPU.\n    self._batch_size_test = 128\n    self._shuffle_buffer_size = 1024\n    self._rng = jax.random.PRNGKey(42)\n    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)\n    self._num_classes = 10\n    self._num_epochs = 1\n\n    self._learning_rate_fn = lambda _: 0.01\n    self._weight_decay = 0.0001\n    self._momentum = 0.9\n    self._rng = jax.random.PRNGKey(42)\n\n    self._min_loss = jnp.finfo(float).eps\n    self._max_loss = 2.0 * math.log(self._num_classes)\n\n    self._dataset_name = 'MNIST'\n    self._model_name = 'MNIST_CNN'\n\n    self._summarywriter = tensorboard.SummaryWriter('/tmp/')\n\n    self._dataset = dataset_factory.create_dataset(\n        
self._dataset_name,\n        self._batch_size,\n        self._batch_size_test,\n        shuffle_buffer_size=self._shuffle_buffer_size)\n\n    self._model, self._state = model_factory.create_model(\n        self._model_name,\n        self._rng, (self._input_shape,),\n        num_classes=self._num_classes)\n\n    self._optimizer = flax.optim.Momentum(  # pytype: disable=module-attr\n        learning_rate=self._learning_rate_fn(0),\n        beta=self._momentum,\n        weight_decay=self._weight_decay)\n\n  def test_train_one_step(self):\n    \"\"\"Tests training loop over one step.\"\"\"\n    iterator = self._dataset.get_train()\n    batch = next(iterator)\n\n    state = jax_utils.replicate(self._state)\n    optimizer = jax_utils.replicate(self._optimizer.create(self._model))\n\n    self._rng, step_key = jax.random.split(self._rng)\n    batch = training._shard_batch(batch)\n    sharded_keys = common_utils.shard_prng_key(step_key)\n\n    p_train_step = jax.pmap(\n        functools.partial(\n            training.train_step, learning_rate_fn=self._learning_rate_fn),\n        axis_name='batch')\n    _, _, loss, gradient_norm = p_train_step(optimizer, batch, sharded_keys,\n                                             state)\n\n    loss = jnp.mean(loss)\n    gradient_norm = jax_utils.unreplicate(gradient_norm)\n\n    with self.subTest(name='test_loss_range'):\n      self.assertBetween(loss, self._min_loss, self._max_loss)\n\n    with self.subTest(name='test_gradient_norm'):\n      self.assertGreaterEqual(gradient_norm, 0)\n\n  def test_train_one_epoch(self):\n    \"\"\"Tests training loop over one epoch.\"\"\"\n    trainer = training.Trainer(self._optimizer, self._model, self._state,\n                               self._dataset)\n\n    with self.subTest(name='trainer_instantiation'):\n      self.assertIsInstance(trainer, training.Trainer)\n\n    best_model, best_metrics = trainer.train(self._num_epochs)\n\n    with self.subTest(name='best_model_type'):\n      
self.assertIsInstance(best_model, flax.deprecated.nn.Model)\n\n    with self.subTest(name='train_accuracy'):\n      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)\n\n    with self.subTest(name='train_avg_loss'):\n      self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,\n                         self._max_loss)\n\n    with self.subTest(name='step'):\n      self.assertGreater(best_metrics['step'], 0)\n\n    with self.subTest(name='cosine_distance'):\n      self.assertBetween(best_metrics['cosine_distance'], 0., 1.)\n\n    with self.subTest(name='cumulative_gradient_norm'):\n      self.assertGreater(best_metrics['cumulative_gradient_norm'], 0)\n\n    with self.subTest(name='test_accuracy'):\n      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)\n\n    with self.subTest(name='test_avg_loss'):\n      self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,\n                         self._max_loss)\n\n  def test_train_one_epoch_tensorboard(self):\n    \"\"\"Tests training loop over one epoch, with tensorboard.\"\"\"\n\n    trainer = training.Trainer(\n        self._optimizer,\n        self._model,\n        self._state,\n        self._dataset,\n        summary_writer=self._summarywriter)\n\n    with self.subTest(name='TrainerInstantiation'):\n      self.assertIsInstance(trainer, training.Trainer)\n\n    best_model, best_metrics = trainer.train(self._num_epochs)\n    with self.subTest(name='best_model_type'):\n      self.assertIsInstance(best_model, flax.deprecated.nn.Model)\n\n    with self.subTest(name='train_accuracy'):\n      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)\n\n    with self.subTest(name='train_avg_loss'):\n      self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,\n                         self._max_loss)\n\n    with self.subTest(name='step'):\n      self.assertGreater(best_metrics['step'], 0)\n\n    with self.subTest(name='cosine_distance'):\n      
self.assertBetween(best_metrics['cosine_distance'], 0., 1.)\n\n    with self.subTest(name='cumulative_gradient_norm'):\n      self.assertGreater(best_metrics['cumulative_gradient_norm'], 0)\n\n    with self.subTest(name='test_accuracy'):\n      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)\n\n    with self.subTest(name='test_avg_loss'):\n      self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,\n                         self._max_loss)\n\n  def test_train_one_epoch_pruning_global_schedule(self):\n    \"\"\"Tests training loop over one epoch with global pruning rate schedule.\"\"\"\n    trainer = training.Trainer(self._optimizer, self._model, self._state,\n                               self._dataset)\n\n    with self.subTest(name='trainer_instantiation'):\n      self.assertIsInstance(trainer, training.Trainer)\n\n    best_model, best_metrics = trainer.train(self._num_epochs,\n                                             pruning_rate_fn=lambda _: 0.5)\n\n    with self.subTest(name='best_model_type'):\n      self.assertIsInstance(best_model, flax.deprecated.nn.Model)\n\n    with self.subTest(name='train_accuracy'):\n      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)\n\n    with self.subTest(name='train_avg_loss'):\n      self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,\n                         self._max_loss)\n\n    with self.subTest(name='step'):\n      self.assertGreater(best_metrics['step'], 0)\n\n    with self.subTest(name='cosine_distance'):\n      self.assertBetween(best_metrics['cosine_distance'], 0., 1.)\n\n    with self.subTest(name='cumulative_gradient_norm'):\n      self.assertGreater(best_metrics['cumulative_gradient_norm'], 0.)\n\n    with self.subTest(name='test_accuracy'):\n      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)\n\n    with self.subTest(name='test_avg_loss'):\n      self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,\n                         
self._max_loss)\n\n  def test_train_one_epoch_pruning_local_schedule(self):\n    \"\"\"Tests training loop over one epoch with local pruning rate schedule.\"\"\"\n    trainer = training.Trainer(self._optimizer, self._model, self._state,\n                               self._dataset)\n\n    with self.subTest(name='trainer_instantiation'):\n      self.assertIsInstance(trainer, training.Trainer)\n\n    best_model, best_metrics = trainer.train(\n        self._num_epochs, pruning_rate_fn={'MaskedModule_0': lambda _: 0.5})\n\n    with self.subTest(name='best_model_type'):\n      self.assertIsInstance(best_model, flax.deprecated.nn.Model)\n\n    with self.subTest(name='train_accuracy'):\n      self.assertBetween(best_metrics['train_accuracy'], 0., 1.)\n\n    with self.subTest(name='train_avg_loss'):\n      self.assertBetween(best_metrics['train_avg_loss'], self._min_loss,\n                         self._max_loss)\n\n    with self.subTest(name='step'):\n      self.assertGreater(best_metrics['step'], 0)\n\n    with self.subTest(name='cosine_distance'):\n      self.assertBetween(best_metrics['cosine_distance'], 0., 1.)\n\n    with self.subTest(name='cumulative_gradient_norm'):\n      self.assertGreater(best_metrics['cumulative_gradient_norm'], 0.)\n\n    with self.subTest(name='test_accuracy'):\n      self.assertBetween(best_metrics['test_accuracy'], 0., 1.)\n\n    with self.subTest(name='test_avg_loss'):\n      self.assertBetween(best_metrics['test_avg_loss'], self._min_loss,\n                         self._max_loss)\n\n  def test_eval_batch(self):\n    \"\"\"Tests model per-batch evaluation function.\"\"\"\n    state = jax_utils.replicate(self._state)\n    optimizer = jax_utils.replicate(self._optimizer.create(self._model))\n\n    iterator = self._dataset.get_test()\n    batch = next(iterator)\n    batch = training._shard_batch(batch)\n\n    metrics = jax.pmap(training._eval_step, axis_name='batch')(\n        optimizer.target, state, batch)\n\n    loss = 
jnp.mean(metrics['loss'])\n    accuracy = jnp.mean(metrics['accuracy'])\n\n    with self.subTest(name='test_eval_batch_loss'):\n      self.assertBetween(loss, self._min_loss, self._max_loss)\n\n    with self.subTest(name='test_eval_batch_accuracy'):\n      self.assertBetween(accuracy, 0., 1.)\n\n  def test_eval(self):\n    \"\"\"Tests model evaluation function.\"\"\"\n    state = jax_utils.replicate(self._state)\n    optimizer = jax_utils.replicate(self._optimizer.create(self._model))\n\n    metrics = training.eval_model(optimizer.target, state,\n                                  self._dataset.get_test())\n\n    loss = metrics['loss']\n    accuracy = metrics['accuracy']\n\n    with self.subTest(name='test_eval_loss'):\n      self.assertBetween(loss, 0., 2.0*math.log(self._num_classes))\n\n    with self.subTest(name='test_eval_accuracy'):\n      self.assertBetween(accuracy, 0., 1.)\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/experimental/jax/utils/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "rigl/experimental/jax/utils/utils.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Convenience Functions for NN training.\n\nMisc. common functions used in training NN models.\n\"\"\"\nimport functools\nimport itertools\nimport json\nimport operator\nfrom typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Tuple, TypeVar\n\nimport flax\nfrom flax.training import common_utils\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\n\n\ndef cross_entropy_loss(log_softmax_logits,\n                       labels):\n  \"\"\"Returns the cross-entropy classification loss.\n\n  Args:\n    log_softmax_logits: The log of the softmax of the logits for the mini-batch,\n      e.g. 
as output by jax.nn.log_softmax(logits).\n    labels: The labels for the mini-batch.\n  \"\"\"\n  num_classes = log_softmax_logits.shape[-1]\n  one_hot_labels = common_utils.onehot(labels, num_classes)\n  return -jnp.sum(one_hot_labels * log_softmax_logits) / labels.size\n\n\ndef compute_metrics(logits,\n                    labels):\n  \"\"\"Computes the classification loss and accuracy for a mini-batch.\n\n  Args:\n    logits: NN model's logit outputs for the mini-batch.\n    labels: The classification labels for the mini-batch.\n\n  Returns:\n    Metrics dictionary where 'loss' is the mini-batch loss and 'accuracy' is\n    the classification accuracy.\n\n  Raises:\n    ValueError: If the given logits array is not of the correct shape.\n  \"\"\"\n  if len(logits.shape) != 2:\n    raise ValueError(\n        'Expected an array of (BATCHSIZE, NUM_CLASSES), but got {}'.format(\n            logits.shape))\n\n  metrics = {\n      'loss': cross_entropy_loss(logits, labels),\n      'accuracy': jnp.mean(jnp.argmax(logits, -1) == labels)\n  }\n\n  return jax.lax.pmean(metrics, 'batch')\n\n\ndef _np_converter(obj):\n  \"\"\"Explicitly cast Numpy types not recognized by JSON serializer.\"\"\"\n  if isinstance(obj, jnp.integer) or isinstance(obj, np.integer):\n    return int(obj)\n  elif isinstance(obj, jnp.floating) or isinstance(obj, np.floating):\n    return float(obj)\n  elif isinstance(obj, jnp.ndarray) or isinstance(obj, np.ndarray):\n    return obj.tolist()\n\n\ndef dump_dict_json(data_dict, path):\n  \"\"\"Dumps a dictionary to a JSON file, ensuring Numpy types are cast correctly.\n\n  Args:\n    data_dict: A metrics dictionary.\n    path: Path of the JSON file to save.\n  \"\"\"\n\n  with open(path, 'w') as json_file:\n    json.dump(data_dict, json_file, default=_np_converter)\n\n\ndef count_param(model,\n                param_names):\n  \"\"\"Counts the number of parameters in the given model.\n\n  Args:\n    model: The model for which to count the 
parameters.\n    param_names: The parameters in each layer which should be accounted for.\n\n  Returns:\n    The total number of parameters of the given names in the model.\n  \"\"\"\n\n  param_traversal = flax.optim.ModelParamTraversal(  # pytype: disable=module-attr\n      lambda path, _: any(param_name in path for param_name in param_names))\n\n  return functools.reduce(\n      operator.add, [param.size for param in param_traversal.iterate(model)], 0)\n\n\n@jax.jit\ndef cosine_similarity(a, b):\n  \"\"\"Calculates the cosine similarity between two tensors of same shape.\"\"\"\n  a = a.flatten()\n  b = b.flatten()\n  return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))\n\n\ndef param_as_array(params):\n  \"\"\"Returns a Flax parameter pytree as a single numpy weight vector.\"\"\"\n  params_flat = jax.tree_util.tree_leaves(params)\n  return jnp.concatenate([param.flatten() for param in params_flat])\n\n\ndef cosine_similarity_model(initial_model,\n                            current_model):\n  \"\"\"Calculates the cosine similarity between two model's parameters.\"\"\"\n  initial_params = param_as_array(initial_model.params)\n  params = param_as_array(current_model.params)\n\n  return cosine_similarity(initial_params, params)\n\n\ndef vector_difference_norm_model(initial_model,\n                                 current_model):\n  \"\"\"Calculates norm of the difference between two model's parameter vectors.\"\"\"\n  initial_params = param_as_array(initial_model.params)\n  params = param_as_array(current_model.params)\n\n  return jnp.linalg.norm(params - initial_params)\n\n# Use typevar to hint that we expect unspecified types to match.\nT = TypeVar('T')\n\n\ndef pairwise_longest(iterable):\n  \"\"\"Creates a meta-iterator to iterate over current/next values concurrently.\n\n  This is different from itertools pairwise recipe in that it returns the final\n  element as (final, None).\n\n  Args:\n    iterable: An Iterable of any type.\n  Returns:\n    An 
iterable which returns the current and next items in the iterable, or\n    None if there is no next. For example, for an iterator over the list\n    (1, 2, 3, 4), this would return an iterator as\n    ((1, 2), (2, 3), (3, 4), (4, None)).\n  \"\"\"\n  # From itertools example documentation.\n  a, b = itertools.tee(iterable)\n  next(b, None)\n  return itertools.zip_longest(a, b)\n"
  },
  {
    "path": "rigl/experimental/jax/utils/utils_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for rigl.experimental.jax.utils.utils.\"\"\"\nimport functools\nimport json\nimport operator\nimport tempfile\nfrom typing import Optional, Sequence, TypeVar\n\nfrom absl.testing import absltest\nfrom absl.testing import parameterized\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom rigl.experimental.jax.training import training\nfrom rigl.experimental.jax.utils import utils\n\n\nclass TwoLayerDense(flax.deprecated.nn.Module):\n  \"\"\"Two-layer Dense Network.\"\"\"\n\n  NUM_FEATURES: Sequence[int] = (32, 64)\n\n  def apply(self, inputs):\n    # If inputs are in image dimensions, flatten image.\n    inputs = inputs.reshape(inputs.shape[0], -1)\n\n    inputs = flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[0])\n    return flax.deprecated.nn.Dense(inputs, features=self.NUM_FEATURES[1])\n\n\nclass UtilsTest(parameterized.TestCase):\n  \"\"\"Test functions for NN convenience functions.\"\"\"\n\n  def setUp(self):\n    \"\"\"Common setup for test cases.\"\"\"\n    super().setUp()\n    self._batch_size = 2\n    self._num_classes = 10\n    self._true_logit = 0.5\n    self._input_shape = ((self._batch_size, 28, 28, 1), jnp.float32)\n    self._input = jnp.ones(*self._input_shape)\n\n    self._rng = jax.random.PRNGKey(42)\n    _, initial_params = TwoLayerDense.init_by_shape(self._rng,\n                       
                             (self._input_shape,))\n    self._model = flax.deprecated.nn.Model(TwoLayerDense, initial_params)\n    _, initial_params = TwoLayerDense.init_by_shape(self._rng,\n                                                    (self._input_shape,))\n    self._model_diff_init = flax.deprecated.nn.Model(TwoLayerDense,\n                                                     initial_params)\n\n  def _create_logits_labels(self, correct):\n    \"\"\"Creates a set of logits/labels for correct or incorrect classification.\n\n    Args:\n      correct: If True, creates labels for a correct classification, otherwise\n        creates labels for an incorrect classification.\n    Returns:\n      A tuple of logits, labels.\n    \"\"\"\n    logits = np.full((self._batch_size, self._num_classes),\n                     (1.0 - self._true_logit) / self._num_classes,\n                     dtype=np.float32)\n\n    # Diagonal over batch will be true.\n    for i in range(self._batch_size):\n      logits[i, i % self._num_classes] = self._true_logit\n\n    labels = np.zeros(self._batch_size, dtype=jnp.int32)\n\n    # Diagonal over batch will be true.\n    for i in range(self._batch_size):\n      labels[i] = (i if correct else i + 1) % self._num_classes\n\n    return jnp.array(logits), jnp.array(labels)\n\n  def test_compute_metrics_correct(self):\n    \"\"\"Tests output when logit outputs indicate correct classification.\"\"\"\n    logits, labels_correct = self._create_logits_labels(True)\n    logits = training._shard_batch(logits)\n    labels_correct = training._shard_batch(labels_correct)\n\n    p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')\n    metrics = p_compute_metrics(logits, labels_correct)\n    loss = metrics['loss']\n    accuracy = metrics['accuracy']\n\n    with self.subTest(name='loss_type'):\n      self.assertIsInstance(loss, jnp.ndarray)\n\n    with self.subTest(name='loss_len'):\n      self.assertEqual(loss.size, 1)\n\n    with 
self.subTest(name='loss_values'):\n      self.assertGreaterEqual(loss.all(), 0)\n\n    with self.subTest(name='accuracy_type'):\n      self.assertIsInstance(accuracy, jnp.ndarray)\n\n    with self.subTest(name='accuracy_len'):\n      self.assertEqual(accuracy.size, 1)\n\n    with self.subTest(name='accuracy_values'):\n      self.assertAlmostEqual(accuracy.all(), 1.0)\n\n  def test_compute_metrics_incorrect(self):\n    \"\"\"Tests output when logit outputs indicate incorrect classification.\"\"\"\n    logits, labels_incorrect = self._create_logits_labels(False)\n    logits = training._shard_batch(logits)\n    labels_incorrect = training._shard_batch(labels_incorrect)\n\n    p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')\n    metrics = p_compute_metrics(logits, labels_incorrect)\n    loss = metrics['loss']\n    accuracy = metrics['accuracy']\n\n    with self.subTest(name='loss_type'):\n      self.assertIsInstance(loss, jnp.ndarray)\n\n    with self.subTest(name='loss_len'):\n      self.assertEqual(loss.size, 1)\n\n    with self.subTest(name='loss_values'):\n      self.assertGreaterEqual(loss.all(), 0)\n\n    with self.subTest(name='accuracy_type'):\n      self.assertIsInstance(accuracy, jnp.ndarray)\n\n    with self.subTest(name='accuracy_len'):\n      self.assertEqual(accuracy.size, 1)\n\n    with self.subTest(name='accuracy_values'):\n      self.assertAlmostEqual(accuracy.all(), 0.0)\n\n  def test_compute_metrics_equal_logits(self):\n    \"\"\"Tests output when the logit outputs are equal for all classes.\"\"\"\n    logits, labels_correct = self._create_logits_labels(True)\n    logits = training._shard_batch(logits)\n    labels_correct = training._shard_batch(labels_correct)\n\n    p_compute_metrics = jax.pmap(utils.compute_metrics, axis_name='batch')\n    metrics = p_compute_metrics(logits, labels_correct)\n    loss = metrics['loss']\n    accuracy = metrics['accuracy']\n\n    with self.subTest(name='loss_type'):\n      
self.assertIsInstance(loss, jnp.ndarray)\n\n    with self.subTest(name='loss_len'):\n      self.assertEqual(loss.size, 1)\n\n    with self.subTest(name='loss_values'):\n      self.assertGreaterEqual(loss.all(), 0)\n\n    with self.subTest(name='accuracy_type'):\n      self.assertIsInstance(accuracy, jnp.ndarray)\n\n    with self.subTest(name='accuracy_len'):\n      self.assertEqual(accuracy.size, 1)\n\n    with self.subTest(name='accuracy_values'):\n      self.assertAlmostEqual(accuracy.all(), 1.0)\n\n  def test_dump_dict_json(self):\n    \"\"\"Tests JSON dumping function.\"\"\"\n    data_dict = {\n        'np_float': np.dtype('float32').type(1.0),\n        'jnp_float': jnp.dtype('float32').type(1.0),\n        'np_int': np.dtype('int32').type(1),\n        'jnp_int': jnp.dtype('int32').type(1),\n        'np_array': np.array(1.0, dtype=np.float32),\n        'jnp_array': jnp.array(1.0, dtype=jnp.float32),\n    }\n    converted_dict = {\n        key: utils._np_converter(value) for key, value in data_dict.items()\n    }\n    json_path = tempfile.NamedTemporaryFile()\n    utils.dump_dict_json(data_dict, json_path.name)\n\n    with open(json_path.name, 'r') as input_file:\n      loaded_dict = json.load(input_file)\n    self.assertDictEqual(loaded_dict, converted_dict)\n\n  def test_count_param_two_layer_dense(self):\n    \"\"\"Tests model parameter counting on small FC model.\"\"\"\n    count = utils.count_param(self._model, ('kernel',))\n\n    self.assertEqual(\n        count,\n        self._input.size / self._batch_size * TwoLayerDense.NUM_FEATURES[0] +\n        TwoLayerDense.NUM_FEATURES[0] * TwoLayerDense.NUM_FEATURES[1])\n\n  def test_count_invalid_param(self):\n    \"\"\"Tests model parameter counting for a non-existent parameter name.\"\"\"\n    count = utils.count_param(self._model, ('not_kernel',))\n\n    self.assertEqual(count, 0)\n\n  def test_model_param_as_array(self):\n    \"\"\"Tests method for returning single parameter vector for model.\"\"\"\n    
param_array = utils.param_as_array(self._model.params)\n\n    with self.subTest(name='test_param_is_vector'):\n      self.assertLen(param_array.shape, 1)\n\n    param_sizes = [param.size for param in jax.tree_leaves(self._model.params)]\n    model_size = functools.reduce(operator.add, param_sizes)\n\n    with self.subTest(name='test_param_size'):\n      self.assertEqual(param_array.size, model_size)\n\n  def test_cosine_similarity_random(self):\n    \"\"\"Tests cosine similarity for two random weight matrices.\"\"\"\n    a = jax.random.normal(self._rng, (3, 4))\n    b = jax.random.normal(self._rng, (3, 4))\n\n    cosine_similarity = utils.cosine_similarity(a, b)\n\n    with self.subTest(name='test_cosine_distance_range'):\n      self.assertBetween(cosine_similarity, 0., 1.)\n\n  def test_cosine_similarity_same(self):\n    \"\"\"Tests cosine similarity for the same weight matrix.\"\"\"\n    a = jax.random.normal(self._rng, (3, 4))\n\n    cosine_similarity = utils.cosine_similarity(a, a)\n\n    with self.subTest(name='test_cosine_distance_range'):\n      self.assertAlmostEqual(cosine_similarity, 1., places=5)\n\n  def test_cosine_similarity_same_model(self):\n    \"\"\"Tests cosine similarity for the same model.\"\"\"\n    cosine_dist = utils.cosine_similarity_model(self._model, self._model)\n\n    self.assertAlmostEqual(cosine_dist, 1., places=5)\n\n  def test_vector_difference_norm_diff_model(self):\n    \"\"\"Tests vector difference norm for different models.\"\"\"\n    vector_diff_norm = utils.vector_difference_norm_model(\n        self._model, self._model_diff_init)\n\n    self.assertGreaterEqual(vector_diff_norm, 0.)\n\n  def test_vector_difference_norm_same_model(self):\n    \"\"\"Tests vector difference norm for the same model.\"\"\"\n    vector_diff_norm = utils.vector_difference_norm_model(\n        self._model, self._model)\n\n    self.assertAlmostEqual(vector_diff_norm, 0., places=5)\n\n  T = TypeVar('T')\n  @parameterized.parameters(\n\n      # Tests 
pairwise longest iterator convenience function with list.\n      ((1, 2, 3, 4), ((1, 2), (2, 3), (3, 4), (4, None))),\n      # Tests pairwise longest iterator with empty input iterator.\n      (iter(()), ()),\n      # Tests pairwise longest iterator with single element iterator.\n      ((1,), ((1, None),))\n  )\n  def test_pairwise_longest_list_iterator(\n      self, input_sequence,\n      output_sequence):\n    \"\"\"Tests pairwise longest iterator with list iterators.\"\"\"\n    output = list(utils.pairwise_longest(iter(input_sequence)))\n\n    self.assertSequenceEqual(output, output_sequence)\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "rigl/imagenet_resnet/colabs/MobileNet_Counting.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"e5O1UdsY202_\"\n      },\n      \"source\": [\n        \"##### Copyright 2020 Google LLC.\\n\",\n        \"\\n\",\n        \"Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Wtx39-f76KsC\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Download necessary libraries.\\n\",\n        \"%%bash \\n\",\n        \"test -d rigl || git clone https://github.com/google-research/rigl rigl_repo \\u0026\\u0026 mv rigl_repo/rigl ./ \\n\",\n        \"test -d gresearch || git clone https://github.com/google-research/google-research google_research\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"i25HTaVl6LAI\"\n      },\n      \"source\": [\n        \"## Parameter and FLOPs Counting for MobileNetv1 \"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"gAkFMbjrNCww\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"import tensorflow as tf\\n\",\n        \"from google_research.micronet_challenge import counting\\n\",\n        \"from rigl import sparse_utils\\n\",\n        \"tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 34\n        },\n        \"executionInfo\": {\n          \"elapsed\": 2458,\n          \"status\": \"ok\",\n          \"timestamp\": 1593006846761,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n     
     \"user_tz\": 240\n        },\n        \"id\": \"dYm9k-Q47PXe\",\n        \"outputId\": \"db7fc195-6e0b-4c04-b695-5670128503d7\"\n      },\n      \"outputs\": [\n        {\n          \"data\": {\n            \"text/plain\": [\n              \"\\u003ctf.Tensor 'mobilenet_1.00_224/act_softmax/Softmax:0' shape=(2, 1000) dtype=float32\\u003e\"\n            ]\n          },\n          \"execution_count\": 2,\n          \"metadata\": {\n            \"tags\": []\n          },\n          \"output_type\": \"execute_result\"\n        }\n      ],\n      \"source\": [\n        \"tf.compat.v1.reset_default_graph()\\n\",\n        \"model=tf.keras.applications.MobileNet(input_shape=(224,224,3), weights=None)\\n\",\n        \"model(tf.ones((2,224,224,3)))\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"RNS1s5Wm7U8-\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"masked_layers = []\\n\",\n        \"dw_layers = []\\n\",\n        \"for layer in model.layers:\\n\",\n        \"  if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense, tf.keras.layers.DepthwiseConv2D)):     \\n\",\n        \"    masked_layers.append(layer)\\n\",\n        \"    if 'conv_dw' in layer.name:\\n\",\n        \"      dw_layers.append(layer)\\n\",\n        \"    # print(layer.name, sparse_utils._get_kernel(layer).shape)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"QtD03TrBSDzV\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"PARAM_SIZE=32\\n\",\n        \"import functools\\n\",\n        \"\\n\",\n        \"get_stats = functools.partial(\\n\",\n        \"    sparse_utils.get_stats, first_layer_name='conv1',\\n\",\n        \"    last_layer_name='conv_preds', param_size=PARAM_SIZE)\\n\",\n        \"\\n\",\n        \"def print_stats(masked_layers, default_sparsity=0.8, 
method='erdos_renyi',\\n\",\n        \"                custom_sparsities=None, is_debug=False, width=1.):\\n\",\n        \"  print('Method: %s, Sparsity: %f' % (method, default_sparsity))\\n\",\n        \"  total_flops, total_param_bits, sparsity = get_stats(\\n\",\n        \"      masked_layers, default_sparsity=default_sparsity, method=method,\\n\",\n        \"      custom_sparsities=custom_sparsities, is_debug=is_debug, width=width)\\n\",\n        \"  print('Total Flops: %.3f MFlops' % (total_flops/1e6))\\n\",\n        \"  print('Total Size: %.3f Mbytes' % (total_param_bits/8e6))\\n\",\n        \"  print('Real Sparsity: %.3f' % (sparsity))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"FvqtfXePpgdb\"\n      },\n      \"source\": [\n        \"### Printing sparse network stats\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 218\n        },\n        \"executionInfo\": {\n          \"elapsed\": 548,\n          \"status\": \"ok\",\n          \"timestamp\": 1593006940695,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"qupDcQOlTxDk\",\n        \"outputId\": \"f59b39d2-eedb-4e45-db93-f52958f24a45\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Method: erdos_renyi_kernel, Sparsity: 0.750000\\n\",\n            \"Total Flops: 599.144 MFlops\\n\",\n            \"Total Size: 4.888 Mbytes\\n\",\n            \"Real Sparsity: 0.742\\n\",\n            \"Method: random, Sparsity: 0.750000\\n\",\n            \"Total Flops: 330.769 MFlops\\n\",\n            \"Total Size: 4.894 Mbytes\\n\",\n            \"Real Sparsity: 0.742\\n\",\n            \"Method: random, 
Sparsity: 0.000000\\n\",\n            \"Total Flops: 1141.544 MFlops\\n\",\n            \"Total Size: 16.864 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"c_sparsities = {'%s/depthwise_kernel:0' % l.name: 0. for l in dw_layers}\\n\",\n        \"c_sparsities_uniform = c_sparsities.copy()\\n\",\n        \"c_sparsities_uniform['conv1/kernel:0'] = 0.\\n\",\n        \"# c_sparsities_uniform['conv_preds/kernel:0'] = 0.\\n\",\n        \"# First layer has sparsity 0 by default.\\n\",\n        \"print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0.75, 'random', c_sparsities_uniform, is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0, 'random', is_debug=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 151\n        },\n        \"executionInfo\": {\n          \"elapsed\": 529,\n          \"status\": \"ok\",\n          \"timestamp\": 1593028091210,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"qvagZCnX31yP\",\n        \"outputId\": \"542832bb-7b59-4f43-d216-73260a9a3a56\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Method: erdos_renyi_kernel, Sparsity: 0.850000\\n\",\n            \"Total Flops: 439.152 MFlops\\n\",\n            \"Total Size: 3.224 Mbytes\\n\",\n            \"Real Sparsity: 0.841\\n\",\n            \"Method: random, Sparsity: 0.850000\\n\",\n            \"Total Flops: 222.666 MFlops\\n\",\n            \"Total Size: 3.229 Mbytes\\n\",\n            \"Real Sparsity: 0.841\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n    
    \"print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0.85, 'random', c_sparsities_uniform, is_debug=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 151\n        },\n        \"executionInfo\": {\n          \"elapsed\": 840,\n          \"status\": \"ok\",\n          \"timestamp\": 1593006957962,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"t3L8WlYJOhku\",\n        \"outputId\": \"e5d4709b-984e-4e6d-ded4-8bdd81071267\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Method: erdos_renyi_kernel, Sparsity: 0.900000\\n\",\n            \"Total Flops: 334.134 MFlops\\n\",\n            \"Total Size: 2.392 Mbytes\\n\",\n            \"Real Sparsity: 0.890\\n\",\n            \"Method: random, Sparsity: 0.900000\\n\",\n            \"Total Flops: 168.614 MFlops\\n\",\n            \"Total Size: 2.396 Mbytes\\n\",\n            \"Real Sparsity: 0.890\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0.9, 'random', c_sparsities_uniform, is_debug=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 153\n        },\n        \"executionInfo\": {\n          \"elapsed\": 567,\n          \"status\": \"ok\",\n          \"timestamp\": 1582843606223,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n  
        \"user_tz\": 480\n        },\n        \"id\": \"Ge1Ct0YjUME1\",\n        \"outputId\": \"7144ccdc-eae9-47d8-8a5c-b74aad94187c\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"Method: erdos_renyi_kernel, Sparsity: 0.950000\\n\",\n            \"Total Flops: 205.281 MFlops\\n\",\n            \"Total Size: 1.560 Mbytes\\n\",\n            \"Real Sparsity: 0.940\\n\",\n            \"Method: random, Sparsity: 0.950000\\n\",\n            \"Total Flops: 114.563 MFlops\\n\",\n            \"Total Size: 1.563 Mbytes\\n\",\n            \"Real Sparsity: 0.940\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', c_sparsities, is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0.95, 'random', c_sparsities_uniform, is_debug=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"2RnZ9BCDVJ2P\"\n      },\n      \"source\": [\n        \"## Finding the width Multiplier for small dense model\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 173\n        },\n        \"executionInfo\": {\n          \"elapsed\": 536,\n          \"status\": \"ok\",\n          \"timestamp\": 1569942238017,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"-qQMOoNqURfs\",\n        \"outputId\": \"4edf8c57-c3ab-45a1-f19d-13be5da23368\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"0.9933069386323201\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 
266.539 MFlops\\n\",\n            \"Total Size: 4.789 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.750000\\n\",\n            \"Total Flops: 588.355 MFlops\\n\",\n            \"Total Size: 4.757 Mbytes\\n\",\n            \"Real Sparsity: 0.750\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.47)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.47)\\n\",\n        \"print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 173\n        },\n        \"executionInfo\": {\n          \"elapsed\": 536,\n          \"status\": \"ok\",\n          \"timestamp\": 1569942242149,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"P5mS-6h3ZChX\",\n        \"outputId\": \"b722e40b-2797-454e-a2bb-91cdaef4a79d\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"0.9998127484496482\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 154.770 MFlops\\n\",\n            \"Total Size: 3.076 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.850000\\n\",\n            \"Total Flops: 422.419 MFlops\\n\",\n            \"Total Size: 3.075 Mbytes\\n\",\n            \"Real Sparsity: 0.850\\n\"\n  
        ]\n        }\n      ],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.353)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.353)\\n\",\n        \"print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 168\n        },\n        \"executionInfo\": {\n          \"elapsed\": 656,\n          \"status\": \"ok\",\n          \"timestamp\": 1569028742267,\n          \"user\": {\n            \"displayName\": \"Utku Evci\",\n            \"photoUrl\": \"https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64\",\n            \"userId\": \"01088181649958641579\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"wY2Uc8RlVkRb\",\n        \"outputId\": \"03535606-8b6f-4eb9-ca48-ef235d69994f\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"0.9996546850118981\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 103.825 MFlops\\n\",\n            \"Total Size: 2.236 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.900000\\n\",\n            \"Total Flops: 312.956 MFlops\\n\",\n            \"Total Size: 2.235 Mbytes\\n\",\n            \"Real Sparsity: 0.900\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\\n\",\n        \"_, bits, _ = 
get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.285)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.285)\\n\",\n        \"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 168\n        },\n        \"executionInfo\": {\n          \"elapsed\": 574,\n          \"status\": \"ok\",\n          \"timestamp\": 1569089855290,\n          \"user\": {\n            \"displayName\": \"Utku Evci\",\n            \"photoUrl\": \"https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64\",\n            \"userId\": \"01088181649958641579\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"TUfPAjO5Cryq\",\n        \"outputId\": \"c528942a-f531-48df-a46e-d94d5dae0a89\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"0.9982463429660301\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 56.617 MFlops\\n\",\n            \"Total Size: 1.396 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.950000\\n\",\n            \"Total Flops: 180.359 MFlops\\n\",\n            \"Total Size: 1.393 Mbytes\\n\",\n            \"Real Sparsity: 0.950\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.})\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.204)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, 
width=0.204)\\n\",\n        \"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"f8sqZWZYpoqa\"\n      },\n      \"source\": [\n        \"### Big-Sparse Networks\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 242\n        },\n        \"executionInfo\": {\n          \"elapsed\": 631,\n          \"status\": \"ok\",\n          \"timestamp\": 1569285091631,\n          \"user\": {\n            \"displayName\": \"Utku Evci\",\n            \"photoUrl\": \"https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64\",\n            \"userId\": \"01088181649958641579\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"f-eD8zoFY_-U\",\n        \"outputId\": \"0341ebde-cff6-497e-afaf-65e4a39ac438\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"1.0084815029856933\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.750000\\n\",\n            \"Total Flops: 2180.140 MFlops\\n\",\n            \"Total Size: 16.723 Mbytes\\n\",\n            \"Real Sparsity: 0.742\\n\",\n            \"Method: random, Sparsity: 0.750000\\n\",\n            \"Total Flops: 1122.572 MFlops\\n\",\n            \"Total Size: 15.863 Mbytes\\n\",\n            \"Real Sparsity: 0.757\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 1141.544 MFlops\\n\",\n            \"Total Size: 16.864 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"# BIGGER\\n\",\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = 
get_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=1.98)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0.75, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=1.98)\\n\",\n        \"print_stats(masked_layers, 0.75, 'random', {'conv_preds/kernel:0':0.8, 'conv1/kernel:0':0.}, is_debug=False, width=1.98)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 168\n        },\n        \"executionInfo\": {\n          \"elapsed\": 581,\n          \"status\": \"ok\",\n          \"timestamp\": 1569029822060,\n          \"user\": {\n            \"displayName\": \"Utku Evci\",\n            \"photoUrl\": \"https://lh3.googleusercontent.com/a-/AAuE7mAWXSVCykm6kPzLHt5KN6jYg31_w1lnqRpCfWt35A=s64\",\n            \"userId\": \"01088181649958641579\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"z_rW4hO0ZwIG\",\n        \"outputId\": \"efe0e3cd-4ed1-49eb-db6b-d673b01cc020\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"1.0032864697591513\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.850000\\n\",\n            \"Total Flops: 2442.726 MFlops\\n\",\n            \"Total Size: 16.809 Mbytes\\n\",\n            \"Real Sparsity: 0.846\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 1141.544 MFlops\\n\",\n            \"Total Size: 16.864 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0.85, 'erdos_renyi_kernel', 
{'conv1/kernel:0':0.}, width=2.52)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0.85, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=2.52)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 242\n        },\n        \"executionInfo\": {\n          \"elapsed\": 558,\n          \"status\": \"ok\",\n          \"timestamp\": 1569939161351,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"MHhuiXGlaQEi\",\n        \"outputId\": \"74db692f-bc1d-4f42-acc9-3848f4b2d21c\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"1.0120353164650686\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.900000\\n\",\n            \"Total Flops: 2452.785 MFlops\\n\",\n            \"Total Size: 16.664 Mbytes\\n\",\n            \"Real Sparsity: 0.899\\n\",\n            \"Method: random, Sparsity: 0.900000\\n\",\n            \"Total Flops: 1058.478 MFlops\\n\",\n            \"Total Size: 17.833 Mbytes\\n\",\n            \"Real Sparsity: 0.890\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 1141.544 MFlops\\n\",\n            \"Total Size: 16.864 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=3.)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 
0.9, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=3.)\\n\",\n        \"print_stats(masked_layers, 0.9, 'random', {'conv_preds/kernel:0':0.8, 'conv1/kernel:0':0.}, is_debug=False, width=3.)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"colab\": {\n          \"height\": 173\n        },\n        \"executionInfo\": {\n          \"elapsed\": 523,\n          \"status\": \"ok\",\n          \"timestamp\": 1569939157037,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": 240\n        },\n        \"id\": \"wENtmNUGaXwj\",\n        \"outputId\": \"dab1f1c2-b647-4a67-b486-5ec5dcfcf4af\"\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stdout\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"1.0031304863290271\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.950000\\n\",\n            \"Total Flops: 2132.954 MFlops\\n\",\n            \"Total Size: 16.812 Mbytes\\n\",\n            \"Real Sparsity: 0.954\\n\",\n            \"Method: erdos_renyi_kernel, Sparsity: 0.000000\\n\",\n            \"Total Flops: 1141.544 MFlops\\n\",\n            \"Total Size: 16.864 Mbytes\\n\",\n            \"Real Sparsity: 0.000\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, width=3.98)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', {'conv1/kernel:0':0.}, is_debug=False, width=3.98)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)\"\n   
   ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"klQNdBJIqm3E\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [],\n      \"last_runtime\": {\n        \"build_target\": \"//learning/brain/python/client:colab_notebook\",\n        \"kind\": \"private\"\n      },\n      \"name\": \"MobileNet v1: Param/Flops Counting [OPEN_SOURCE].ipynb\"\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 2\",\n      \"name\": \"python2\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "rigl/imagenet_resnet/colabs/Resnet_50_Param_Flops_Counting.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"e5O1UdsY202_\"\n      },\n      \"source\": [\n        \"##### Copyright 2020 Google LLC.\\n\",\n        \"\\n\",\n        \"Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"P5p1fkA3rgL_\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"%%bash \\n\",\n        \"# Download the official ResNet50 implementation and other libraries.\\n\",\n        \"# Move the ResNet50 module so that we can use its model builders for our counting.\\n\",\n        \"test -d tpu || git clone https://github.com/tensorflow/tpu tpu \\u0026\\u0026 mv tpu/models/experimental/resnet50_keras ./ \\n\",\n        \"test -d rigl || git clone https://github.com/google-research/rigl rigl_repo \\u0026\\u0026 mv rigl_repo/rigl ./ \\n\",\n        \"test -d gresearch || git clone https://github.com/google-research/google-research google_research\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"tmr3djWe1rKj\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"import tensorflow as tf\\n\",\n        \"from micronet_challenge import counting\\n\",\n        \"from resnet50_keras import resnet_model as resnet_keras\\n\",\n        \"from rigl import sparse_utils\\n\",\n        \"tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"dYm9k-Q47PXe\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"tf.compat.v1.reset_default_graph()\\n\",\n        \"model = resnet_keras.ResNet50(1000)\"\n      ]\n    },\n    
{\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"RNS1s5Wm7U8-\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"masked_layers = []\\n\",\n        \"for layer in model.layers:\\n\",\n        \"  if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):\\n\",\n        \"    masked_layers.append(layer)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"QtD03TrBSDzV\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"PARAM_SIZE=32 # bits\\n\",\n        \"import functools\\n\",\n        \"get_stats = functools.partial(\\n\",\n        \"    sparse_utils.get_stats, first_layer_name='conv1', last_layer_name='fc1000',\\n\",\n        \"    param_size=PARAM_SIZE)\\n\",\n        \"def print_stats(masked_layers, default_sparsity=0.8, method='erdos_renyi',\\n\",\n        \"                custom_sparsities={}, is_debug=False, width=1., **kwargs):\\n\",\n        \"  print('Method: %s, Sparsity: %f' % (method, default_sparsity))\\n\",\n        \"  total_flops, total_param_bits, sparsity = get_stats(\\n\",\n        \"      masked_layers, default_sparsity=default_sparsity, method=method,\\n\",\n        \"      custom_sparsities=custom_sparsities, is_debug=is_debug, width=width, **kwargs)\\n\",\n        \"  print('Total Flops: %.3f MFlops' % (total_flops/1e6))\\n\",\n        \"  print('Total Size: %.3f Mbytes' % (total_param_bits/8e6))\\n\",\n        \"  print('Real Sparsity: %.3f' % (sparsity))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"C_2kH9dsrUqu\"\n      },\n      \"source\": [\n        \"# Pruning FLOPs\\n\",\n        \"We calculate theoretical FLOPs for pruning; that is, we start counting sparse FLOPs once pruning begins.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n  
    \"metadata\": {\n        \"id\": \"yHmbXdMyT2c8\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"p_start, p_end, p_freq = 10000,25000,1000\\n\",\n        \"target_sparsity = 0.8\\n\",\n        \"total_flops = []\\n\",\n        \"for i in range(0,32001,1000):\\n\",\n        \"  if i \\u003c p_start:\\n\",\n        \"    sparsity = 0.\\n\",\n        \"  elif p_end \\u003c i:\\n\",\n        \"    sparsity = target_sparsity\\n\",\n        \"  else:\\n\",\n        \"    sparsity = (1-(1-(i-p_start)/float(p_end-p_start))**3)*target_sparsity\\n\",\n        \"  # print(i, sparsity)\\n\",\n        \"  c_flops, _, _ = get_stats(\\n\",\n        \"      masked_layers, default_sparsity=sparsity, method='random', custom_sparsities={'conv1/kernel:0':0, 'fc1000/kernel:0':0.8})\\n\",\n        \"  # print(i, c_flops, sparsity)\\n\",\n        \"  total_flops.append(c_flops)\\n\",\n        \"avg_flops = sum(total_flops) / len(total_flops)\\n\",\n        \"print('Average Flops: ', avg_flops, avg_flops/total_flops[0])\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"xUU10hxxsZX-\"\n      },\n      \"source\": [\n        \"### Printing sparse network stats.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"qupDcQOlTxDk\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=True, erk_power_scale=0.2)\\n\",\n        \"print_stats(masked_layers, 0.8, 'erdos_renyi')\\n\",\n        \"print_stats(masked_layers, 0.8, 'random', {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8}, is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0, 'random', is_debug=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"AI1HIlLrzuED\"\n      },\n      \"outputs\": [],\n      \"source\": [\n  
      \"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0.9, 'erdos_renyi')\\n\",\n        \"print_stats(masked_layers, 0.9, 'random', {'conv1/kernel:0':0., 'fc1000/kernel:0':0.9}, is_debug=False)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"oX5klsS4_vy-\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0.95, 'erdos_renyi')\\n\",\n        \"print_stats(masked_layers, 0.95, 'random', {'conv1/kernel:0':0}, is_debug=False)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"fe2FHmPfzS7S\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"print_stats(masked_layers, 0.965, 'erdos_renyi_kernel', {'conv1/kernel:0':0}, is_debug=False)\\n\",\n        \"print_stats(masked_layers, 0.965, 'erdos_renyi')\\n\",\n        \"print_stats(masked_layers, 0.965, 'random', {'conv1/kernel:0':0}, is_debug=False)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Yc2EeP_YWUfA\"\n      },\n      \"source\": [\n        \"## Finding the width multiplier for the small dense model\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"p8NJFEo9Se2S\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.8, 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.465)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.465)\\n\",\n        \"print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', 
is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Gjk8Z2g2TOKq\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.34)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.34)\\n\",\n        \"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Sa1zoC-bT-Qk\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.95, 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.26)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.26)\\n\",\n        \"print_stats(masked_layers, 0.95, 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"f_IugJP5URFa\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0.965, 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel', width=0.231)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=0.231)\\n\",\n        \"print_stats(masked_layers, 0.965, 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fXd4Mx90sc9Q\"\n  
    },\n      \"source\": [\n        \"### Printing the Big-Sparse Results\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"BtpJ3LvKYCNn\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# BIGGER\\n\",\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0.8, 'erdos_renyi_kernel', width=2.1)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0.8, 'erdos_renyi_kernel', is_debug=False, width=2.1)\\n\",\n        \"print_stats(masked_layers, 0.8, 'random',  {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8},\\n\",\n        \"            is_debug=False, width=2.1)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=2.1)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"kRcOlrf4YG7K\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"_, sparse_bits, _ = get_stats(masked_layers, 0., 'erdos_renyi_kernel')\\n\",\n        \"_, bits, _ = get_stats(masked_layers, 0.9, 'erdos_renyi_kernel', width=2.8)\\n\",\n        \"print(sparse_bits/bits)\\n\",\n        \"print_stats(masked_layers, 0.9, 'erdos_renyi_kernel', is_debug=False, width=2.8)\\n\",\n        \"print_stats(masked_layers, 0.9, 'random',  {'conv1/kernel:0':0, 'fc1000/kernel:0':0.8}, is_debug=False, width=2.8)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=2.8)\\n\",\n        \"print_stats(masked_layers, 0., 'erdos_renyi_kernel', is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"BN8qfasQWva2\"\n      },\n      \"source\": [\n        \"## [BONUS] DSR 
FLOPs\\n\",\n        \"Obtained from figure https://arxiv.org/abs/1902.05967; exact values are probably slightly different.\\n\",\n        \"\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"RwI5aRe-SH0n\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"resnet_layers=['conv1/kernel:0',\\n\",\n        \"'res2a_branch2a/kernel:0',\\n\",\n        \"'res2a_branch2b/kernel:0',\\n\",\n        \"'res2a_branch2c/kernel:0',\\n\",\n        \"'res2a_branch1/kernel:0',\\n\",\n        \"'res2b_branch2a/kernel:0',\\n\",\n        \"'res2b_branch2b/kernel:0',\\n\",\n        \"'res2b_branch2c/kernel:0',\\n\",\n        \"'res2c_branch2a/kernel:0',\\n\",\n        \"'res2c_branch2b/kernel:0',\\n\",\n        \"'res2c_branch2c/kernel:0',\\n\",\n        \"'res3a_branch2a/kernel:0',\\n\",\n        \"'res3a_branch2b/kernel:0',\\n\",\n        \"'res3a_branch2c/kernel:0',\\n\",\n        \"'res3a_branch1/kernel:0',\\n\",\n        \"'res3b_branch2a/kernel:0',\\n\",\n        \"'res3b_branch2b/kernel:0',\\n\",\n        \"'res3b_branch2c/kernel:0',\\n\",\n        \"'res3c_branch2a/kernel:0',\\n\",\n        \"'res3c_branch2b/kernel:0',\\n\",\n        \"'res3c_branch2c/kernel:0',\\n\",\n        \"'res3d_branch2a/kernel:0',\\n\",\n        \"'res3d_branch2b/kernel:0',\\n\",\n        \"'res3d_branch2c/kernel:0',\\n\",\n        \"'res4a_branch2a/kernel:0',\\n\",\n        \"'res4a_branch2b/kernel:0',\\n\",\n        \"'res4a_branch2c/kernel:0',\\n\",\n        \"'res4a_branch1/kernel:0',\\n\",\n        \"'res4b_branch2a/kernel:0',\\n\",\n        \"'res4b_branch2b/kernel:0',\\n\",\n        \"'res4b_branch2c/kernel:0',\\n\",\n        \"'res4c_branch2a/kernel:0',\\n\",\n        \"'res4c_branch2b/kernel:0',\\n\",\n        \"'res4c_branch2c/kernel:0',\\n\",\n        \"'res4d_branch2a/kernel:0',\\n\",\n        \"'res4d_branch2b/kernel:0',\\n\",\n        \"'res4d_branch2c/kernel:0',\\n\",\n        
\"'res4e_branch2a/kernel:0',\\n\",\n        \"'res4e_branch2b/kernel:0',\\n\",\n        \"'res4e_branch2c/kernel:0',\\n\",\n        \"'res4f_branch2a/kernel:0',\\n\",\n        \"'res4f_branch2b/kernel:0',\\n\",\n        \"'res4f_branch2c/kernel:0',\\n\",\n        \"'res5a_branch2a/kernel:0',\\n\",\n        \"'res5a_branch2b/kernel:0',\\n\",\n        \"'res5a_branch2c/kernel:0',\\n\",\n        \"'res5a_branch1/kernel:0',\\n\",\n        \"'res5b_branch2a/kernel:0',\\n\",\n        \"'res5b_branch2b/kernel:0',\\n\",\n        \"'res5b_branch2c/kernel:0',\\n\",\n        \"'res5c_branch2a/kernel:0',\\n\",\n        \"'res5c_branch2b/kernel:0',\\n\",\n        \"'res5c_branch2c/kernel:0',\\n\",\n        \"'fc1000/kernel:0']\\n\",\n        \"dsr_sparsities8=[0,\\n\",\n        \"            0., .15, .5, .425, .575, .55, .425, .32, .44, .15,\\n\",\n        \"            0., .15, .55, .6, .8, .65, .75, .65, .65, .65, .55, .65, .7,\\n\",\n        \"            0., .35, .65, .85, .9, .8, .85, .85, .8, .85, .85, .85, .85, .8, .8, .9, .75, .8, .85,\\n\",\n        \"            0., .65, .85, .95, .85, .8, .9, .65, .9, .8,\\n\",\n        \"            .8]\\n\",\n        \"dsr_sparsities9=[0,\\n\",\n        \"            0., .4, .6, .65, .65, .6, .6, .5, .6, .45,\\n\",\n        \"            0., .4, .7, .8, .9, .8, .85, .8, .75, .8, .7, .8, .8,\\n\",\n        \"            0., .6, .8, .95, .95, .9, .95, .9, .9, .95, .9, .9, .95, .9, .9, .95, .85, .85, .9,\\n\",\n        \"            0., 0.8, .95, .95, .9, .9, .95, .8, .95, .9,\\n\",\n        \"            .9] \"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"P6i-jjz6OLBH\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dsr_map = dict(zip(resnet_layers, dsr_sparsities8))\\n\",\n        \"print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      
\"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"xeGqdHtYYlZT\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"dsr_map = dict(zip(resnet_layers, dsr_sparsities9))\\n\",\n        \"print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"Pf3qqLKrG67e\"\n      },\n      \"source\": [\n        \"# [BONUS] STR FLOPs\\n\",\n        \"Layerwise sparsities are obtained from the [STR paper](https://arxiv.org/abs/2002.03231).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"MIwBmu0NHOuI\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"str_sparsities = \\\"\\\"\\\"\\n\",\n        \"Layer 1 - conv1 9408 118013952 51.46 51.40 63.02 59.80 59.83 64.87 67.36 66.96 72.11 69.46 73.29 73.47 72.05 75.12 76.12 77.75\\n\",\n        \"Layer 2 - layer1.0.conv1 4096 12845056 69.36 73.24 87.57 83.28 85.18 89.60 91.41 91.11 92.38 91.75 94.46 94.51 94.60 95.95 96.53 96.51\\n\",\n        \"Layer 3 - layer1.0.conv2 36864 115605504 77.85 76.26 90.87 89.48 87.31 94.79 94.27 95.04 95.69 96.07 97.36 97.77 98.35 98.51 98.59 98.84\\n\",\n        \"Layer 4 - layer1.0.conv3 16384 51380224 74.81 74.65 86.52 85.80 85.25 91.85 92.78 93.67 94.13 94.69 96.61 97.03 97.37 98.04 98.21 98.47\\n\",\n        \"Layer 5 - layer1.0.downsample.0 16384 51380224 70.95 72.96 83.53 83.34 82.56 89.13 90.62 90.17 91.83 92.69 95.48 94.89 95.68 96.98 97.56 97.72\\n\",\n        \"Layer 6 - layer1.1.conv1 16384 51380224 80.27 79.58 89.82 89.89 88.51 94.56 96.64 95.78 95.81 96.81 98.79 98.90 98.98 99.13 99.62 99.47\\n\",\n        \"Layer 7 - layer1.1.conv2 36864 115605504 81.36 80.95 91.75 90.60 89.61 94.70 95.78 96.18 96.42 97.26 98.65 99.07 99.40 99.11 99.31 99.56\\n\",\n        \"Layer 8 - layer1.1.conv3 16384 51380224 84.45 80.11 91.22 91.70 90.21 95.17 97.05 
95.81 96.34 97.23 98.68 98.76 98.90 99.16 99.57 99.46\\n\",\n        \"Layer 9 - layer1.2.conv1 16384 51380224 78.23 79.79 90.12 88.07 89.36 94.62 95.94 94.74 96.23 96.75 97.96 98.41 98.72 99.38 99.35 99.46\\n\",\n        \"Layer 10 - layer1.2.conv2 36864 115605504 76.01 81.53 91.06 87.03 88.27 93.90 95.63 94.26 96.24 96.11 97.54 98.27 98.44 99.32 99.19 99.39\\n\",\n        \"Layer 11 - layer1.2.conv3 16384 51380224 84.47 83.28 94.95 90.99 92.64 95.76 96.95 96.01 96.87 97.31 98.38 98.60 98.72 99.38 99.27 99.51\\n\",\n        \"Layer 12 - layer2.0.conv1 32768 102760448 73.74 73.96 86.78 85.95 85.90 92.32 94.79 93.86 94.62 95.64 97.19 98.22 98.52 98.48 98.84 98.92\\n\",\n        \"Layer 13 - layer2.0.conv2 147456 115605504 82.56 85.70 91.31 93.91 94.03 97.54 97.43 97.65 98.38 98.62 99.24 99.23 99.40 99.61 99.67 99.63\\n\",\n        \"Layer 14 - layer2.0.conv3 65536 51380224 84.70 83.55 93.04 93.13 92.13 96.61 97.37 97.21 97.59 98.14 98.80 98.95 99.18 99.29 99.47 99.43\\n\",\n        \"Layer 15 - layer2.0.downsample.0 131072 102760448 85.10 87.66 92.78 94.96 95.13 98.07 97.97 98.15 98.70 98.88 99.37 99.35 99.40 99.69 99.68 99.71\\n\",\n        \"Layer 16 - layer2.1.conv1 65536 51380224 85.42 85.79 94.04 95.31 94.94 97.92 98.53 98.21 98.84 99.06 99.46 99.53 99.72 99.78 99.81 99.80\\n\",\n        \"Layer 17 - layer2.1.conv2 147456 115605504 76.95 82.75 87.63 91.50 91.76 95.59 97.22 96.07 97.32 97.80 98.24 98.24 98.60 99.24 99.66 99.33\\n\",\n        \"Layer 18 - layer2.1.conv3 65536 51380224 84.76 84.71 93.10 93.66 93.23 97.00 98.18 97.35 98.06 98.41 98.96 99.21 99.32 99.55 99.58 99.59\\n\",\n        \"Layer 19 - layer2.2.conv1 65536 51380224 84.30 85.34 92.70 94.61 94.76 97.72 97.91 98.21 98.54 98.98 99.24 99.35 99.50 99.62 99.63 99.77\\n\",\n        \"Layer 20 - layer2.2.conv2 147456 115605504 84.28 85.43 92.99 94.86 94.90 97.52 97.21 98.11 98.19 99.04 99.28 99.37 99.46 99.63 99.59 99.72\\n\",\n        \"Layer 21 - layer2.2.conv3 65536 51380224 82.19 84.21 91.12 93.38 
93.53 96.89 97.14 97.59 97.77 98.66 98.96 99.15 99.25 99.49 99.51 99.57\\n\",\n        \"Layer 22 - layer2.3.conv1 65536 51380224 83.37 84.41 90.46 93.26 93.50 96.71 97.89 96.99 98.14 98.36 99.10 99.23 99.33 99.53 99.75 99.60\\n\",\n        \"Layer 23 - layer2.3.conv2 147456 115605504 82.83 84.03 91.44 93.21 93.25 96.83 98.02 96.96 98.45 98.30 98.97 99.06 99.26 99.31 99.81 99.68\\n\",\n        \"Layer 24 - layer2.3.conv3 65536 51380224 82.93 85.65 91.02 94.14 93.56 97.20 97.97 97.04 98.16 98.36 98.88 98.97 99.20 99.32 99.67 99.62\\n\",\n        \"Layer 25 - layer3.0.conv1 131072 102760448 76.63 77.98 85.99 88.85 88.60 94.26 95.07 94.97 96.21 96.59 97.75 98.04 98.30 98.72 99.11 99.06\\n\",\n        \"Layer 26 - layer3.0.conv2 589824 115605504 87.35 88.68 94.39 96.14 96.19 98.51 98.77 98.72 99.11 99.23 99.53 99.59 99.64 99.73 99.80 99.81\\n\",\n        \"Layer 27 - layer3.0.conv3 262144 51380224 81.22 83.22 90.58 93.19 93.05 96.82 97.38 97.32 97.98 98.28 98.88 99.03 99.16 99.39 99.55 99.53\\n\",\n        \"Layer 28 - layer3.0.downsample.0 524288 102760448 89.75 90.99 96.05 97.20 97.16 98.96 99.21 99.20 99.50 99.58 99.78 99.82 99.86 99.91 99.94 99.93\\n\",\n        \"Layer 29 - layer3.1.conv1 262144 51380224 85.88 87.35 93.43 95.36 96.12 98.64 98.77 98.87 99.22 99.33 99.64 99.67 99.72 99.82 99.88 99.84\\n\",\n        \"Layer 30 - layer3.1.conv2 589824 115605504 85.06 86.24 92.74 95.06 95.30 98.09 98.28 98.36 98.75 99.08 99.46 99.48 99.54 99.69 99.76 99.76\\n\",\n        \"Layer 31 - layer3.1.conv3 262144 51380224 84.34 86.79 92.15 94.84 94.90 97.75 98.15 98.11 98.56 98.94 99.30 99.36 99.45 99.65 99.79 99.70\\n\",\n        \"Layer 32 - layer3.2.conv1 262144 51380224 87.51 89.15 94.15 96.77 96.46 98.81 98.83 98.96 99.19 99.44 99.67 99.71 99.74 99.82 99.85 99.89\\n\",\n        \"Layer 33 - layer3.2.conv2 589824 115605504 87.15 88.67 94.09 95.59 96.14 98.86 98.69 98.91 99.21 99.20 99.64 99.72 99.76 99.85 99.84 99.90\\n\",\n        \"Layer 34 - layer3.2.conv3 262144 
51380224 84.86 86.90 92.40 94.99 94.99 98.19 98.19 98.42 98.76 98.97 99.42 99.56 99.62 99.76 99.75 99.88\\n\",\n        \"Layer 35 - layer3.3.conv1 262144 51380224 86.62 89.46 94.06 96.08 95.88 98.70 98.71 98.77 99.01 99.27 99.58 99.66 99.69 99.83 99.87 99.87\\n\",\n        \"Layer 36 - layer3.3.conv2 589824 115605504 86.52 87.97 93.56 96.10 96.11 98.70 98.82 98.89 99.19 99.31 99.68 99.73 99.77 99.88 99.87 99.93\\n\",\n        \"Layer 37 - layer3.3.conv3 262144 51380224 84.19 86.81 92.32 94.94 94.91 98.20 98.37 98.43 98.82 99.00 99.51 99.57 99.64 99.81 99.81 99.87\\n\",\n        \"Layer 38 - layer3.4.conv1 262144 51380224 85.85 88.40 93.55 95.49 95.86 98.35 98.44 98.55 98.79 98.96 99.54 99.59 99.60 99.82 99.86 99.87\\n\",\n        \"Layer 39 - layer3.4.conv2 589824 115605504 85.96 87.38 93.27 95.66 95.63 98.41 98.58 98.56 99.19 99.26 99.64 99.69 99.67 99.87 99.90 99.92\\n\",\n        \"Layer 40 - layer3.4.conv3 262144 51380224 83.45 85.76 91.75 94.49 94.35 97.67 98.09 97.99 98.65 98.94 99.49 99.52 99.48 99.77 99.86 99.85\\n\",\n        \"Layer 41 - layer3.5.conv1 262144 51380224 83.33 85.77 91.79 95.09 94.24 97.46 97.89 97.92 98.71 98.90 99.35 99.52 99.58 99.76 99.79 99.83\\n\",\n        \"Layer 42 - layer3.5.conv2 589824 115605504 84.98 86.67 92.48 94.92 95.13 97.88 98.14 98.32 98.91 99.00 99.44 99.58 99.69 99.80 99.83 99.87\\n\",\n        \"Layer 43 - layer3.5.conv3 262144 51380224 79.78 82.23 89.39 93.14 92.76 96.59 97.04 97.30 98.10 98.41 99.03 99.25 99.44 99.61 99.71 99.75\\n\",\n        \"Layer 44 - layer4.0.conv1 524288 102760448 77.83 79.61 87.11 90.32 90.64 95.39 95.84 95.92 97.17 97.35 98.36 98.60 98.83 99.20 99.37 99.42\\n\",\n        \"Layer 45 - layer4.0.conv2 2359296 115605504 86.18 88.00 93.53 95.66 95.78 98.31 98.47 98.55 99.08 99.16 99.54 99.63 99.69 99.81 99.85 99.86\\n\",\n        \"Layer 46 - layer4.0.conv3 1048576 51380224 78.43 80.48 87.85 91.14 91.27 96.00 96.40 96.47 97.53 97.92 98.81 99.00 99.15 99.45 99.57 99.61\\n\",\n        \"Layer 47 - 
layer4.0.downsample.0 2097152 102760448 88.49 89.98 95.03 96.79 96.90 98.91 99.06 99.11 99.45 99.51 99.77 99.82 99.85 99.92 99.94 99.94\\n\",\n        \"Layer 48 - layer4.1.conv1 1048576 51380224 82.07 84.02 90.34 93.69 93.72 97.15 97.56 97.76 98.45 98.75 99.27 99.36 99.54 99.67 99.76 99.80\\n\",\n        \"Layer 49 - layer4.1.conv2 2359296 115605504 83.42 85.23 91.16 93.98 93.93 97.26 97.58 97.71 98.36 98.67 99.25 99.34 99.50 99.68 99.76 99.80\\n\",\n        \"Layer 50 - layer4.1.conv3 1048576 51380224 78.08 79.96 86.66 90.48 90.22 95.22 95.76 95.89 96.88 97.65 98.70 98.85 99.13 99.45 99.58 99.66\\n\",\n        \"Layer 51 - layer4.2.conv1 1048576 51380224 76.34 77.93 84.98 87.57 88.47 93.90 93.87 94.16 95.55 95.91 97.66 97.97 98.15 98.88 99.08 99.22\\n\",\n        \"Layer 52 - layer4.2.conv2 2359296 115605504 73.57 74.97 82.32 84.37 86.01 91.92 91.66 92.22 94.02 94.16 96.65 97.13 97.29 98.44 98.74 99.00\\n\",\n        \"Layer 53 - layer4.2.conv3 1048576 51380224 68.78 70.38 78.11 80.29 81.73 89.64 89.43 89.65 91.40 92.65 96.02 96.72 96.93 98.47 98.83 99.15\\n\",\n        \"Layer 54 - fc 2048000 2048000 50.65 52.46 60.48 64.50 65.12 75.20 75.73 75.80 78.57 80.69 85.96 87.26 88.03 91.11 92.15 92.87\\\"\\\"\\\"\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"gSFw1eH1G8zh\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"resnet_layers=['conv1/kernel:0',\\n\",\n        \"'res2a_branch2a/kernel:0',\\n\",\n        \"'res2a_branch2b/kernel:0',\\n\",\n        \"'res2a_branch2c/kernel:0',\\n\",\n        \"'res2a_branch1/kernel:0',\\n\",\n        \"'res2b_branch2a/kernel:0',\\n\",\n        \"'res2b_branch2b/kernel:0',\\n\",\n        \"'res2b_branch2c/kernel:0',\\n\",\n        \"'res2c_branch2a/kernel:0',\\n\",\n        \"'res2c_branch2b/kernel:0',\\n\",\n        \"'res2c_branch2c/kernel:0',\\n\",\n        \"'res3a_branch2a/kernel:0',\\n\",\n        
\"'res3a_branch2b/kernel:0',\\n\",\n        \"'res3a_branch2c/kernel:0',\\n\",\n        \"'res3a_branch1/kernel:0',\\n\",\n        \"'res3b_branch2a/kernel:0',\\n\",\n        \"'res3b_branch2b/kernel:0',\\n\",\n        \"'res3b_branch2c/kernel:0',\\n\",\n        \"'res3c_branch2a/kernel:0',\\n\",\n        \"'res3c_branch2b/kernel:0',\\n\",\n        \"'res3c_branch2c/kernel:0',\\n\",\n        \"'res3d_branch2a/kernel:0',\\n\",\n        \"'res3d_branch2b/kernel:0',\\n\",\n        \"'res3d_branch2c/kernel:0',\\n\",\n        \"'res4a_branch2a/kernel:0',\\n\",\n        \"'res4a_branch2b/kernel:0',\\n\",\n        \"'res4a_branch2c/kernel:0',\\n\",\n        \"'res4a_branch1/kernel:0',\\n\",\n        \"'res4b_branch2a/kernel:0',\\n\",\n        \"'res4b_branch2b/kernel:0',\\n\",\n        \"'res4b_branch2c/kernel:0',\\n\",\n        \"'res4c_branch2a/kernel:0',\\n\",\n        \"'res4c_branch2b/kernel:0',\\n\",\n        \"'res4c_branch2c/kernel:0',\\n\",\n        \"'res4d_branch2a/kernel:0',\\n\",\n        \"'res4d_branch2b/kernel:0',\\n\",\n        \"'res4d_branch2c/kernel:0',\\n\",\n        \"'res4e_branch2a/kernel:0',\\n\",\n        \"'res4e_branch2b/kernel:0',\\n\",\n        \"'res4e_branch2c/kernel:0',\\n\",\n        \"'res4f_branch2a/kernel:0',\\n\",\n        \"'res4f_branch2b/kernel:0',\\n\",\n        \"'res4f_branch2c/kernel:0',\\n\",\n        \"'res5a_branch2a/kernel:0',\\n\",\n        \"'res5a_branch2b/kernel:0',\\n\",\n        \"'res5a_branch2c/kernel:0',\\n\",\n        \"'res5a_branch1/kernel:0',\\n\",\n        \"'res5b_branch2a/kernel:0',\\n\",\n        \"'res5b_branch2b/kernel:0',\\n\",\n        \"'res5b_branch2c/kernel:0',\\n\",\n        \"'res5c_branch2a/kernel:0',\\n\",\n        \"'res5c_branch2b/kernel:0',\\n\",\n        \"'res5c_branch2c/kernel:0',\\n\",\n        \"'fc1000/kernel:0']\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"31sg-lNhHN7D\"\n      },\n      
\"outputs\": [],\n      \"source\": [\n        \"from collections import defaultdict\\n\",\n        \"str_sparsities_parsed = defaultdict(list)\\n\",\n        \"for j, l in enumerate(str_sparsities.strip().split('\\\\n')):\\n\",\n        \"  l = l.split('-')[1].strip().split(' ')\\n\",\n        \"  if l[0] == 'Overall':\\n\",\n        \"    # list() so the sparsities can be indexed below (py3 map is lazy).\\n\",\n        \"    overall_sparsities = list(map(float, l[3:]))\\n\",\n        \"  else:\\n\",\n        \"    for i, ls in enumerate(l[3:]):\\n\",\n        \"      s = overall_sparsities[i]\\n\",\n        \"      # Accuracies are reported as percentages, so divide by 100.\\n\",\n        \"      str_sparsities_parsed[s].append(float(ls) / 100.)\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Xrjtum-4HgAT\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"for k in str_sparsities_parsed:\\n\",\n        \"  print(k)\\n\",\n        \"  dsr_map = dict(zip(resnet_layers, str_sparsities_parsed[k]))\\n\",\n        \"  print_stats(masked_layers, 0., 'random', dsr_map, is_debug=False, width=1)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [],\n      \"last_runtime\": {\n        \"build_target\": \"//research/colab/notebook:notebook_backend\",\n        \"kind\": \"private\"\n      },\n      \"name\": \"Resnet-50: Param/Flops Counting [OpenSource].ipynb\"\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n
  },
  {
    "path": "rigl/imagenet_resnet/imagenet_train_eval.py",
"content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"This script trains a ResNet model that implements various pruning methods.\n\n\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nimport os\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nfrom rigl import sparse_optimizers\nfrom rigl import sparse_utils\nfrom rigl.imagenet_resnet import mobilenetv1_model\nfrom rigl.imagenet_resnet import mobilenetv2_model\nfrom rigl.imagenet_resnet import resnet_model\nfrom rigl.imagenet_resnet import utils\nfrom rigl.imagenet_resnet import vgg\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.compat.v1 import estimator as tf_estimator\nfrom official.resnet import imagenet_input\nfrom tensorflow.contrib import estimator as contrib_estimator\nfrom tensorflow.contrib import tpu as contrib_tpu\nfrom tensorflow.contrib.model_pruning.python import pruning\nfrom tensorflow.contrib.tpu.python.tpu import tpu_config\nfrom tensorflow.contrib.tpu.python.tpu import tpu_estimator\nfrom tensorflow.contrib.training.python.training import evaluation\nfrom tensorflow_estimator.python.estimator import estimator\n\nDST_METHODS = [\n    'set',\n    'momentum',\n    'rigl',\n    'static'\n]\n\nALL_METHODS = tuple(['scratch', 'baseline', 'snip', 'dnw'] + DST_METHODS)\nNO_MASK_INIT_METHODS = ('snip', 'dnw', 'baseline')\n\nflags.DEFINE_string(\n    'precision',\n    default='float32',\n    help=('Precision to 
use; one of: {bfloat16, float32}'))\nflags.DEFINE_integer('num_workers', 1, 'Number of training workers.')\nflags.DEFINE_float(\n    'base_learning_rate',\n    default=0.1,\n    help=('Base learning rate when train batch size is 256.'))\n\nflags.DEFINE_float(\n    'momentum',\n    default=0.9,\n    help=('Momentum parameter used in the MomentumOptimizer.'))\nflags.DEFINE_integer('ps_task', 0,\n                     'Task id of the replica running the training.')\nflags.DEFINE_float(\n    'weight_decay',\n    default=1e-4,\n    help=('Weight decay coefficient for l2 regularization.'))\nflags.DEFINE_string('master', '', 'Master job.')\nflags.DEFINE_string('tpu_job_name', None, 'For complicated TensorFlowFlock')\nflags.DEFINE_integer(\n    'steps_per_checkpoint',\n    default=1000,\n    help=('Controls how often checkpoints are generated. More steps per '\n          'checkpoint = higher utilization of TPU and generally higher '\n          'steps/sec'))\nflags.DEFINE_integer(\n    'keep_checkpoint_max', default=0, help=('Number of checkpoints to hold.'))\nflags.DEFINE_integer(\n    'seed', default=0, help=('Sets the random seed.'))\nflags.DEFINE_string(\n    'data_directory', None, 'The location of the sstable used for training.')\nflags.DEFINE_string('eval_once_ckpt_prefix', '',\n                    'File name of the eval checkpoint used for evaluation.')\nflags.DEFINE_string(\n    'data_format',\n    default='channels_last',\n    help=('A flag to override the data format used in the model. The value'\n          ' is either channels_first or channels_last. To run the network on'\n          ' CPU or TPU, channels_last should be used. For GPU, channels_first'\n          ' will improve performance.'))\nflags.DEFINE_bool(\n    'transpose_input',\n    default=False,\n    help='Use TPU double transpose optimization')\nflags.DEFINE_bool(\n    'log_mask_imgs_each_iteration',\n    default=False,\n    help='Use to log a few masks as images. Be careful when using. 
This is'\n    ' very likely to slow down your training and create huge logs.')\nflags.DEFINE_string(\n    'mask_init_method',\n    default='',\n    help='If not empty string and mask is not loaded from a checkpoint, '\n    'indicates the method used for mask initialization. One of the following: '\n    '`random`, `erdos_renyi`.')\nflags.DEFINE_integer(\n    'resnet_depth',\n    default=50,\n    help=('Depth of ResNet model to use. Must be one of {18, 34, 50, 101, 152,'\n          ' 200}. ResNet-18 and 34 use the pre-activation residual blocks'\n          ' without bottleneck layers. The other models use pre-activation'\n          ' bottleneck layers. Deeper models require more training time and'\n          ' more memory and may require reducing --train_batch_size to prevent'\n          ' running out of memory.'))\nflags.DEFINE_float('label_smoothing', 0.1,\n                   'Relax confidence in the labels by (1-label_smoothing).')\nflags.DEFINE_float(\n    'erk_power_scale', 1.0,\n    'Softens the ERK distribution. Value 0 means uniform, '\n    '1 means regular ERK.')\nflags.DEFINE_integer(\n    'train_steps',\n    default=112590,\n    help=('The number of steps to use for training. Default is 112590 steps'\n          ' which is approximately 90 epochs at batch size 1024. This flag'\n          ' should be adjusted according to the --train_batch_size flag.'))\nflags.DEFINE_integer(\n    'train_batch_size', default=1024, help='Batch size for training.')\nflags.DEFINE_integer(\n    'eval_batch_size', default=1000, help='Batch size for evaluation.')\nflags.DEFINE_integer(\n    'num_train_images', default=1281167, help='Size of training data set.')\nflags.DEFINE_integer(\n    'num_eval_images', default=50000, help='Size of evaluation data set.')\nflags.DEFINE_integer(\n    'num_label_classes', default=1000, help='Number of classes, at least 2')\nflags.DEFINE_integer(\n    'steps_per_eval',\n    default=1251,\n    help=('Controls how often evaluation is performed. 
Since evaluation is'\n          ' fairly expensive, it is advised to evaluate as infrequently as'\n          ' possible (i.e. up to --train_steps, which evaluates the model only'\n          ' after finishing the entire training regime).'))\nflags.DEFINE_bool(\n    'use_tpu',\n    default=False,\n    help=('Use TPU to execute the model for training and evaluation. If'\n          ' --use_tpu=false, will use whatever devices are available to'\n          ' TensorFlow by default (e.g. CPU and GPU)'))\nflags.DEFINE_integer(\n    'iterations_per_loop',\n    default=1251,\n    help=('Number of steps to run on TPU before outfeeding metrics to the CPU.'\n          ' If the number of iterations in the loop would exceed the number of'\n          ' train steps, the loop will exit before reaching'\n          ' --iterations_per_loop. The larger this value is, the higher the'\n          ' utilization on the TPU.'))\nflags.DEFINE_integer(\n    'num_parallel_calls',\n    default=64,\n    help=('Number of parallel threads in CPU for the input pipeline'))\nflags.DEFINE_integer(\n    'num_cores',\n    default=8,\n    help=('Number of TPU cores. For a single TPU device, this is 8 because each'\n          ' TPU has 4 chips each with 2 cores.'))\nflags.DEFINE_string('output_dir', '/tmp/imagenet/',\n                    'Directory where to write event logs and checkpoint.')\nflags.DEFINE_bool('use_folder_stub', True,\n                  'If True the output_dir is extended with some parameters.')\nflags.DEFINE_bool('use_batch_statistics', False,\n                  'If True the forward pass is made in training mode. 
')\nflags.DEFINE_bool('eval_on_train', False,\n                  'If True the evaluation is made on training set.')\nflags.DEFINE_enum(\n    'mode', 'train', ('train_and_eval', 'train', 'eval', 'eval_once'),\n    'One of {\"train_and_eval\", \"train\", \"eval\", \"eval_once\"}.')\nflags.DEFINE_integer('export_model_freq', 2502,\n                     'The rate at which estimator exports the model.')\n\nflags.DEFINE_enum(\n    'training_method', 'scratch', ALL_METHODS,\n    'Method used for training sparse network. `scratch` means initial mask is '\n    'kept during training. `set` is for sparse evolutionary training and '\n    '`baseline` is for dense baseline.')\nflags.DEFINE_enum(\n    'init_method', 'baseline', ('baseline', 'sparse'),\n    'Method for initialization.  If sparse and training_method=scratch, then '\n    'use initializers that take into account starting sparsity.')\n# flags.DEFINE_enum(\n#     'mask_init_method', 'baseline', ('default'),\n#     'Method for initializing masks. If not default, end_sparsities are used'\n#     ' to define the layer wise random sparse connectivity.')\n\nflags.DEFINE_bool(\n    'is_warm_up',\n    default=True,\n    help=('Boolean for whether to scale weight of regularizer.'))\n\nflags.DEFINE_float(\n    'width', -1., 'Multiplier for the number of channels in each layer')\n# first and last layer are somewhat special.  First layer has almost no\n# parameters, but 3% of the total flops.  Last layer has only .05% of the total\n# flops but 10% of the total parameters.  
Depending on whether the goal is max\n# compression or max acceleration, pruning goals will be different.\nflags.DEFINE_bool('use_adam', False,\n                  'Whether to use Adam or not.')\nflags.DEFINE_bool('use_sgdr', False,\n                  'Whether to use SGDR for learning rate schedule.')\nflags.DEFINE_float('sgdr_decay_step', 5, 'Initial cycle length for SGDR.')\nflags.DEFINE_float('sgdr_t_mul', 1.5, 'Cycle length multiplier for SGDR.')\nflags.DEFINE_float('sgdr_m_mul', .5,\n                   'Learning rate drop at each restart cycle.')\nflags.DEFINE_float('end_sparsity', 0.9,\n                   'Target sparsity desired by end of training.')\nflags.DEFINE_float('drop_fraction', 0.3,\n                   'When changing mask dynamically, this fraction decides how '\n                   'much of the connections are dropped at each mask update.')\nflags.DEFINE_string('drop_fraction_anneal', 'constant',\n                    'If not empty the drop fraction is annealed during sparse'\n                    ' training. One of the following: `constant`, `cosine` or '\n                    '`exponential_(\\\\d*\\\\.?\\\\d*)$`. For example: '\n                    '`exponential_3`, `exponential_.3`, `exponential_0.3`. '\n                    'The number after `exponential` defines the exponent.')\nflags.DEFINE_string('grow_init', 'zeros',\n                    'Passed to the SparseInitializer, one of: zeros, '\n                    'initial_value, random_normal, random_uniform.')\nflags.DEFINE_float('s_momentum', 0.9,\n                   'Momentum values for exponential moving average of '\n                   'gradients. 
Used when training_method=\"momentum\".')\nflags.DEFINE_float('rigl_acc_scale', 0.,\n                   'Used to scale initial accumulated gradients for new '\n                   'connections.')\nflags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin pruning at.')\nflags.DEFINE_integer('maskupdate_end_step', 25000, 'Step to end pruning at.')\nflags.DEFINE_integer('maskupdate_frequency', 100,\n                     'Step interval between pruning.')\nflags.DEFINE_float(\n    'first_layer_sparsity', 0.,\n    'Sparsity to use for the first layer. Overrides default end_sparsity '\n    'if greater than 0. If -1, default sparsity is applied. If 0, layer is not '\n    'pruned or masked.')\nflags.DEFINE_float(\n    'last_layer_sparsity', -1,\n    'Sparsity to use for the last layer. Overrides default end_sparsity '\n    'if greater than 0. If -1, default sparsity is applied. If 0, layer is not '\n    'pruned or masked.')\nflags.DEFINE_string(\n    'load_mask_dir', '',\n    'Directory of a trained model from which to load only the mask')\nflags.DEFINE_string(\n    'initial_value_checkpoint', '',\n    'Directory of a model from which to load only the parameters')\nflags.DEFINE_string(\n    'model_architecture', 'resnet',\n    'Which architecture to use. 
Options: resnet, mobilenet_v1, mobilenet_v2, '\n    'vgg_16, vgg_a, vgg_19.')\nflags.DEFINE_float('expansion_factor', 6.,\n                   'how much to expand filters before depthwise conv')\nflags.DEFINE_float('training_steps_multiplier', 1.0,\n                   'Training schedule is shortened or extended with the '\n                   'multiplier, if it is not 1.')\nflags.DEFINE_integer('block_width', 1, 'width of block')\nflags.DEFINE_integer('block_height', 1, 'height of block')\nFLAGS = flags.FLAGS\nLR_SCHEDULE = []\nPARAM_SUFFIXES = ('gamma', 'beta', 'weights', 'biases')\nMASK_SUFFIX = 'mask'\n\n\n# Learning rate schedule (multiplier, epoch to start) tuples\ndef set_lr_schedule():\n  \"\"\"Sets the learning schedule: LR_SCHEDULE for the training.\"\"\"\n  global LR_SCHEDULE\n  if FLAGS.model_architecture == 'mobilenet_v2' or FLAGS.model_architecture == 'mobilenet_v1':\n    LR_SCHEDULE = [(1.0, 8), (0.1, 40), (0.01, 75), (0.001, 95), (.0003, 120)]\n  elif (FLAGS.model_architecture == 'resnet' or\n        FLAGS.model_architecture.startswith('vgg')):\n    LR_SCHEDULE = [(1.0, 0), (0.1, 30), (0.01, 70), (0.001, 90), (.0001, 120)]\n  else:\n    raise ValueError('Unknown architecture ' + FLAGS.model_architecture)\n  if FLAGS.training_steps_multiplier != 1.0:\n    multiplier = FLAGS.training_steps_multiplier\n    LR_SCHEDULE = [(x, y * multiplier) for x, y in LR_SCHEDULE]\n    FLAGS.train_steps = int(FLAGS.train_steps * multiplier)\n    FLAGS.maskupdate_begin_step = int(FLAGS.maskupdate_begin_step * multiplier)\n    FLAGS.maskupdate_end_step = int(FLAGS.maskupdate_end_step * multiplier)\n    tf.logging.info(\n        'Training schedule is updated with multiplier: %.2f' % multiplier)\n  tf.logging.info('LR schedule: %s' % LR_SCHEDULE)\n  tf.logging.info('Training Steps: %d' % FLAGS.train_steps)\n# The input tensor is in the range of [0, 255]; normalize it to zero mean and\n# unit variance using ImageNet statistics scaled to that range.\nMEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]\nSTDDEV_RGB = [0.229 * 
255, 0.224 * 255, 0.225 * 255]\n\nCUSTOM_SPARSITY_MAP = {}\n\n\ndef set_custom_sparsity_map():\n  if FLAGS.first_layer_sparsity > 0.:\n    CUSTOM_SPARSITY_MAP[\n        'resnet_model/initial_conv'] = FLAGS.first_layer_sparsity\n  if FLAGS.last_layer_sparsity > 0.:\n    CUSTOM_SPARSITY_MAP[\n        'resnet_model/final_dense'] = FLAGS.last_layer_sparsity\n\n\ndef lr_schedule(current_epoch):\n  \"\"\"Computes learning rate schedule.\"\"\"\n  scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)\n  if FLAGS.use_sgdr:\n    decay_rate = tf.train.cosine_decay_restarts(\n        scaled_lr, current_epoch, FLAGS.sgdr_decay_step,\n        t_mul=FLAGS.sgdr_t_mul, m_mul=FLAGS.sgdr_m_mul)\n  else:\n    decay_rate = (\n        scaled_lr * LR_SCHEDULE[0][0] * current_epoch / LR_SCHEDULE[0][1])\n    for mult, start_epoch in LR_SCHEDULE:\n      decay_rate = tf.where(current_epoch < start_epoch, decay_rate,\n                            scaled_lr * mult)\n  return decay_rate\n\n\ndef train_function(training_method, loss, cross_loss, reg_loss, output_dir,\n                   use_tpu):\n  \"\"\"Training script for resnet model.\n\n  Args:\n   training_method: string indicating pruning method used to compress model.\n   loss: tensor float32 of the cross entropy + regularization losses.\n   cross_loss: tensor, only cross entropy loss, passed for logging.\n   reg_loss: tensor, only regularization loss, passed for logging.\n   output_dir: string tensor indicating the directory to save summaries.\n   use_tpu: boolean indicating whether to run script on a tpu.\n\n  Returns:\n    host_call: summary tensors to be computed at each training step.\n    train_op: the optimization term.\n  \"\"\"\n\n  global_step = tf.train.get_global_step()\n\n  steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size\n  current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)\n  learning_rate = lr_schedule(current_epoch)\n  if FLAGS.use_adam:\n    # We don't use step 
decrease for the learning rate.\n    learning_rate = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)\n    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n  else:\n    optimizer = tf.train.MomentumOptimizer(\n        learning_rate=learning_rate, momentum=FLAGS.momentum, use_nesterov=True)\n\n  if use_tpu:\n    # use CrossShardOptimizer when using TPU.\n    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)\n\n  if training_method == 'set':\n    # We override the train op to also update the mask.\n    optimizer = sparse_optimizers.SparseSETOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal,\n        stateless_seed_offset=FLAGS.seed)\n  elif training_method == 'static':\n    # We override the train op to also update the mask.\n    optimizer = sparse_optimizers.SparseStaticOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal,\n        stateless_seed_offset=FLAGS.seed)\n  elif training_method == 'momentum':\n    # We override the train op to also update the mask.\n    optimizer = sparse_optimizers.SparseMomentumOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        grow_init=FLAGS.grow_init, stateless_seed_offset=FLAGS.seed,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=use_tpu)\n  elif training_method == 'rigl':\n    # We override the train op to also update the mask.\n    optimizer = 
sparse_optimizers.SparseRigLOptimizer(\n        optimizer, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency,\n        drop_fraction=FLAGS.drop_fraction, stateless_seed_offset=FLAGS.seed,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal,\n        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=use_tpu)\n\n  elif training_method == 'snip':\n    optimizer = sparse_optimizers.SparseSnipOptimizer(\n        optimizer, mask_init_method=FLAGS.mask_init_method,\n        custom_sparsity_map=CUSTOM_SPARSITY_MAP,\n        default_sparsity=FLAGS.end_sparsity, use_tpu=use_tpu)\n  elif training_method == 'dnw':\n    optimizer = sparse_optimizers.SparseDNWOptimizer(\n        optimizer,\n        mask_init_method=FLAGS.mask_init_method,\n        custom_sparsity_map=CUSTOM_SPARSITY_MAP,\n        default_sparsity=FLAGS.end_sparsity,\n        use_tpu=use_tpu)\n  elif training_method in ('scratch', 'baseline'):\n    pass\n  else:\n    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)\n  # UPDATE_OPS needs to be added as a dependency due to batch norm\n  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n  with tf.control_dependencies(update_ops), tf.name_scope('train'):\n    grads_and_vars = optimizer.compute_gradients(loss)\n\n    vars_with_grad = [v for g, v in grads_and_vars if g is not None]\n    if not vars_with_grad:\n      raise ValueError(\n          'No gradients provided for any variable, check your graph for ops'\n          ' that do not support gradients, between variables %s and loss %s.' 
%\n          ([str(v) for _, v in grads_and_vars], loss))\n\n    train_op = optimizer.apply_gradients(\n        grads_and_vars, global_step=global_step)\n  metrics = {\n      'global_step': tf.train.get_or_create_global_step(),\n      'loss': loss,\n      'cross_loss': cross_loss,\n      'reg_loss': reg_loss,\n      'learning_rate': learning_rate,\n      'current_epoch': current_epoch,\n  }\n\n  # Logging drop_fraction if dynamic sparse training.\n  is_dst_method = training_method in DST_METHODS\n  if is_dst_method:\n    metrics['drop_fraction'] = optimizer.drop_fraction\n\n  def flatten_list_of_vars(var_list):\n    flat_vars = [tf.reshape(v, [-1]) for v in var_list]\n    return tf.concat(flat_vars, axis=-1)\n\n  if use_tpu:\n    reduced_grads = [tf.tpu.cross_replica_sum(g) for g, _ in grads_and_vars]\n  else:\n    reduced_grads = [g for g, _ in grads_and_vars]\n  metrics['grad_norm'] = tf.norm(flatten_list_of_vars(reduced_grads))\n  metrics['var_norm'] = tf.norm(\n      flatten_list_of_vars([v for _, v in grads_and_vars]))\n  # Let's log some statistics from a single parameter-mask couple.\n  # This is useful for debugging.\n  test_var = pruning.get_weights()[0]\n  test_var_mask = pruning.get_masks()[0]\n  metrics.update({\n      'fw_nz_weight': tf.count_nonzero(test_var),\n      'fw_nz_mask': tf.count_nonzero(test_var_mask),\n      'fw_l1_weight': tf.reduce_sum(tf.abs(test_var))\n  })\n\n  masks = pruning.get_masks()\n  global_sparsity = sparse_utils.calculate_sparsity(masks)\n  metrics['global_sparsity'] = global_sparsity\n  metrics.update(\n      utils.mask_summaries(masks, with_img=FLAGS.log_mask_imgs_each_iteration))\n\n  host_call = (functools.partial(utils.host_call_fn, output_dir),\n               utils.format_tensors(metrics))\n\n  return host_call, train_op\n\n\ndef resnet_model_fn_w_pruning(features, labels, mode, params):\n  \"\"\"The model_fn for ResNet-50 with pruning.\n\n  Args:\n    features: A float32 batch of images.\n    labels: A int32 batch of 
labels.\n    mode: Specifies whether training or evaluation.\n    params: Dictionary of parameters passed to the model.\n\n  Returns:\n    A TPUEstimatorSpec for the model\n  \"\"\"\n\n  width = 1. if FLAGS.width <= 0 else FLAGS.width\n\n  if isinstance(features, dict):\n    features = features['feature']\n\n  if FLAGS.data_format == 'channels_first':\n    assert not FLAGS.transpose_input  # channels_first only for GPU\n    features = tf.transpose(features, [0, 3, 1, 2])\n\n  if FLAGS.transpose_input and mode != tf_estimator.ModeKeys.PREDICT:\n    features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC\n\n  # Normalize the image to zero mean and unit variance.\n  features -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=features.dtype)\n  features /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=features.dtype)\n\n  training_method = params['training_method']\n  use_tpu = params['use_tpu']\n\n  def build_network():\n    \"\"\"Construct the network in the graph.\"\"\"\n    if FLAGS.model_architecture == 'mobilenet_v2':\n      network_func = functools.partial(\n          mobilenetv2_model.mobilenet_v2,\n          expansion_factor=FLAGS.expansion_factor)\n    elif FLAGS.model_architecture == 'mobilenet_v1':\n      network_func = functools.partial(mobilenetv1_model.mobilenet_v1)\n    elif FLAGS.model_architecture == 'resnet':\n      prune_first_layer = FLAGS.first_layer_sparsity != 0.\n      network_func = functools.partial(\n          resnet_model.resnet_v1_,\n          resnet_depth=FLAGS.resnet_depth,\n          init_method=FLAGS.init_method,\n          end_sparsity=FLAGS.end_sparsity,\n          prune_first_layer=prune_first_layer)\n    elif FLAGS.model_architecture.startswith('vgg'):\n      network_func = functools.partial(\n          vgg.vgg,\n          vgg_type=FLAGS.model_architecture,\n          init_method=FLAGS.init_method,\n          end_sparsity=FLAGS.end_sparsity)\n    else:\n      raise ValueError('Unknown architecture ' + FLAGS.model_architecture)\n    
prune_last_layer = FLAGS.last_layer_sparsity != 0.\n    network = network_func(\n        num_classes=FLAGS.num_label_classes,\n        # TODO remove the pruning_method option.\n        pruning_method='threshold',\n        width=width,\n        prune_last_layer=prune_last_layer,\n        data_format=FLAGS.data_format,\n        weight_decay=FLAGS.weight_decay)\n\n    is_training = (mode == tf_estimator.ModeKeys.TRAIN)\n    if FLAGS.use_batch_statistics:\n      is_training = True\n    return network(inputs=features, is_training=is_training)\n\n  if FLAGS.precision == 'bfloat16':\n    with contrib_tpu.bfloat16_scope():\n      logits = build_network()\n    logits = tf.cast(logits, tf.float32)\n  elif FLAGS.precision == 'float32':\n    logits = build_network()\n\n  if mode == tf_estimator.ModeKeys.PREDICT:\n    predictions = {\n        'classes': tf.argmax(logits, axis=1),\n        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')\n    }\n    return tf_estimator.EstimatorSpec(\n        mode=mode,\n        predictions=predictions,\n        export_outputs={\n            'classify': tf_estimator.export.PredictOutput(predictions)\n        })\n  output_dir = params['output_dir']\n  # Calculate loss, which includes softmax cross entropy and L2 regularization.\n  one_hot_labels = tf.one_hot(labels, FLAGS.num_label_classes)\n\n  # Make sure we reuse the same label smoothing parameter if we're doing\n  # scratch / lottery ticket experiments.\n  label_smoothing = FLAGS.label_smoothing\n  if FLAGS.training_method == 'scratch' and FLAGS.load_mask_dir:\n    scratch_stripped = FLAGS.load_mask_dir.replace('/scratch', '')\n    label_smoothing = float(scratch_stripped.split('/')[15])\n    tf.logging.info('LABEL SMOOTHING USED: %.2f' % label_smoothing)\n  cross_loss = tf.losses.softmax_cross_entropy(\n      logits=logits,\n      onehot_labels=one_hot_labels,\n      label_smoothing=label_smoothing)\n  # Add regularization loss term\n  reg_loss = 
tf.losses.get_regularization_loss()\n  loss = cross_loss + reg_loss\n\n  host_call = None\n  if mode == tf_estimator.ModeKeys.TRAIN:\n    host_call, train_op = train_function(training_method, loss, cross_loss,\n                                         reg_loss, output_dir, use_tpu)\n  else:\n    train_op = None\n\n  eval_metrics = None\n  if mode == tf_estimator.ModeKeys.EVAL:\n\n    def metric_fn(labels, logits, cross_loss, reg_loss):\n      \"\"\"Calculate eval metrics.\"\"\"\n      logging.info('In metric function')\n      eval_metrics = {}\n      predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)\n      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)\n      eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)\n      eval_metrics['cross_loss'] = tf.metrics.mean(cross_loss)\n      eval_metrics['reg_loss'] = tf.metrics.mean(reg_loss)\n      eval_metrics['eval_accuracy'] = tf.metrics.accuracy(\n          labels=labels, predictions=predictions)\n\n      # If evaluating once, let's also calculate sparsities.\n      if FLAGS.mode == 'eval_once':\n        sparsity_summaries = utils.mask_summaries(pruning.get_masks())\n        # We call mean on a scalar to create tensor, update_op pairs.\n        sparsity_summaries = {k: tf.metrics.mean(v) for k, v\n                              in sparsity_summaries.items()}\n        eval_metrics.update(sparsity_summaries)\n      return eval_metrics\n\n    tensors = [labels, logits,\n               tf.broadcast_to(cross_loss, tf.shape(labels)),\n               tf.broadcast_to(reg_loss, tf.shape(labels))]\n\n    eval_metrics = (metric_fn, tensors)\n\n  if (FLAGS.load_mask_dir and\n      FLAGS.training_method not in NO_MASK_INIT_METHODS):\n\n    def scaffold_fn():\n      \"\"\"For initialization, passed to the estimator.\"\"\"\n      utils.initialize_parameters_from_ckpt(FLAGS.load_mask_dir,\n                                            FLAGS.output_dir, MASK_SUFFIX)\n      if 
FLAGS.initial_value_checkpoint:\n        utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,\n                                              FLAGS.output_dir, PARAM_SUFFIXES)\n      return tf.train.Scaffold()\n  elif (FLAGS.mask_init_method and\n        FLAGS.training_method not in NO_MASK_INIT_METHODS):\n\n    def scaffold_fn():\n      \"\"\"For initialization, passed to the estimator.\"\"\"\n      if FLAGS.initial_value_checkpoint:\n        utils.initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint,\n                                              FLAGS.output_dir, PARAM_SUFFIXES)\n      all_masks = pruning.get_masks()\n      assigner = sparse_utils.get_mask_init_fn(\n          all_masks,\n          FLAGS.mask_init_method,\n          FLAGS.end_sparsity,\n          CUSTOM_SPARSITY_MAP,\n          erk_power_scale=FLAGS.erk_power_scale)\n      def init_fn(scaffold, session):\n        \"\"\"A callable for restoring variable from a checkpoint.\"\"\"\n        del scaffold  # Unused.\n        session.run(assigner)\n      return tf.train.Scaffold(init_fn=init_fn)\n  else:\n    assert FLAGS.training_method in NO_MASK_INIT_METHODS\n    scaffold_fn = None\n    tf.logging.info('No mask is set, starting dense.')\n\n  return contrib_tpu.TPUEstimatorSpec(\n      mode=mode,\n      loss=loss,\n      train_op=train_op,\n      host_call=host_call,\n      eval_metrics=eval_metrics,\n      scaffold_fn=scaffold_fn)\n\n\nclass ExportModelHook(tf.train.SessionRunHook):\n  \"\"\"Train hooks called after each session run for exporting the model.\"\"\"\n\n  def __init__(self, classifier, export_dir):\n    self.classifier = classifier\n    self.global_step = None\n    self.export_dir = export_dir\n    self.last_export = 0\n    self.supervised_input_receiver_fn = (\n        contrib_estimator.build_raw_supervised_input_receiver_fn(\n            {\n                'feature':\n                    tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3])\n            
}, tf.placeholder(dtype=tf.int32, shape=[None])))\n\n  def begin(self):\n    self.global_step = tf.train.get_or_create_global_step()\n\n  def after_run(self, run_context, run_values):\n    # export saved model\n    global_step = run_context.session.run(self.global_step)\n    if global_step - self.last_export >= FLAGS.export_model_freq:\n      tf.logging.info(\n          'Export model for prediction (step={}) ...'.format(global_step))\n\n      self.last_export = global_step\n      contrib_estimator.export_all_saved_models(\n          self.classifier, os.path.join(self.export_dir, str(global_step)), {\n              tf_estimator.ModeKeys.EVAL:\n                  self.supervised_input_receiver_fn,\n              tf_estimator.ModeKeys.PREDICT:\n                  imagenet_input.image_serving_input_fn\n          })\n\n\ndef main(argv):\n  del argv  # Unused.\n\n  tf.enable_resource_variables()\n  tf.set_random_seed(FLAGS.seed)\n  set_lr_schedule()\n  set_custom_sparsity_map()\n  folder_stub = os.path.join(FLAGS.training_method, str(FLAGS.end_sparsity),\n                             str(FLAGS.maskupdate_begin_step),\n                             str(FLAGS.maskupdate_end_step),\n                             str(FLAGS.maskupdate_frequency),\n                             str(FLAGS.drop_fraction),\n                             str(FLAGS.label_smoothing),\n                             str(FLAGS.weight_decay))\n\n  output_dir = FLAGS.output_dir\n  if FLAGS.use_folder_stub:\n    output_dir = os.path.join(output_dir, folder_stub)\n\n  export_dir = os.path.join(output_dir, 'export_dir')\n\n  # we pass the updated eval and train string to the params dictionary.\n  params = {}\n  params['output_dir'] = output_dir\n  params['training_method'] = FLAGS.training_method\n  params['use_tpu'] = FLAGS.use_tpu\n\n  dataset_func = functools.partial(\n      imagenet_input.ImageNetInput, data_dir=FLAGS.data_directory,\n      transpose_input=False, num_parallel_calls=FLAGS.num_parallel_calls,\n  
    use_bfloat16=False)\n  imagenet_train, imagenet_eval = [dataset_func(is_training=is_training)\n                                   for is_training in [True, False]]\n\n  run_config = tpu_config.RunConfig(\n      master=FLAGS.master,\n      model_dir=output_dir,\n      save_checkpoints_steps=FLAGS.steps_per_checkpoint,\n      keep_checkpoint_max=FLAGS.keep_checkpoint_max,\n      session_config=tf.ConfigProto(\n          allow_soft_placement=True, log_device_placement=False),\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=FLAGS.iterations_per_loop,\n          num_shards=FLAGS.num_cores,\n          tpu_job_name=FLAGS.tpu_job_name))\n\n  classifier = tpu_estimator.TPUEstimator(\n      use_tpu=FLAGS.use_tpu,\n      model_fn=resnet_model_fn_w_pruning,\n      params=params,\n      config=run_config,\n      train_batch_size=FLAGS.train_batch_size,\n      eval_batch_size=FLAGS.eval_batch_size)\n\n  cpu_classifier = tpu_estimator.TPUEstimator(\n      use_tpu=FLAGS.use_tpu,\n      model_fn=resnet_model_fn_w_pruning,\n      params=params,\n      config=run_config,\n      train_batch_size=FLAGS.train_batch_size,\n      export_to_tpu=False,\n      eval_batch_size=FLAGS.eval_batch_size)\n\n  if FLAGS.num_eval_images % FLAGS.eval_batch_size != 0:\n    raise ValueError(\n        'eval_batch_size (%d) must evenly divide num_eval_images (%d)!' 
%\n        (FLAGS.eval_batch_size, FLAGS.num_eval_images))\n\n  eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size\n  if FLAGS.mode == 'eval_once':\n    ckpt_path = os.path.join(output_dir, FLAGS.eval_once_ckpt_prefix)\n    dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval\n    classifier.evaluate(\n        input_fn=dataset.input_fn,\n        steps=eval_steps,\n        checkpoint_path=ckpt_path,\n        name='{0}'.format(FLAGS.eval_once_ckpt_prefix))\n  elif FLAGS.mode == 'eval':\n    # Run evaluation when there's a new checkpoint\n    for ckpt in evaluation.checkpoints_iterator(output_dir):\n      tf.logging.info('Starting to evaluate.')\n      try:\n        dataset = imagenet_train if FLAGS.eval_on_train else imagenet_eval\n        classifier.evaluate(\n            input_fn=dataset.input_fn,\n            steps=eval_steps,\n            checkpoint_path=ckpt,\n            name='eval')\n        # Terminate eval job when final checkpoint is reached\n        global_step = int(os.path.basename(ckpt).split('-')[1])\n        if global_step >= FLAGS.train_steps:\n          tf.logging.info(\n              'Evaluation finished after training step %d' % global_step)\n          break\n\n      except tf.errors.NotFoundError:\n        tf.logging.info('Checkpoint no longer exists, skipping checkpoint.')\n\n  else:\n    global_step = estimator._load_global_step_from_checkpoint_dir(output_dir)\n    # Session run hooks to export model for prediction\n    export_hook = ExportModelHook(cpu_classifier, export_dir)\n    hooks = [export_hook]\n\n    if FLAGS.mode == 'train':\n      tf.logging.info('start training...')\n      classifier.train(\n          input_fn=imagenet_train.input_fn,\n          hooks=hooks,\n          max_steps=FLAGS.train_steps)\n    else:\n      assert FLAGS.mode == 'train_and_eval'\n      tf.logging.info('start training and eval...')\n      while global_step < FLAGS.train_steps:\n        next_checkpoint = min(global_step + FLAGS.steps_per_eval,\n  
                            FLAGS.train_steps)\n        classifier.train(\n            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)\n        global_step = next_checkpoint\n        tf.logging.info('Completed training up to step %d', global_step)\n        classifier.evaluate(input_fn=imagenet_eval.input_fn, steps=eval_steps)\n\n\nif __name__ == '__main__':\n  app.run(main)\n"
  },
  {
    "path": "rigl/imagenet_resnet/mobilenetv1_model.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Straightforward MobileNet v1 for inputs of size 224x224.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nfrom absl import flags\nfrom rigl.imagenet_resnet import resnet_model\nfrom rigl.imagenet_resnet.pruning_layers import sparse_conv2d\nfrom rigl.imagenet_resnet.pruning_layers import sparse_fully_connected\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib import layers as contrib_layers\n\nFLAGS = flags.FLAGS\n\n\ndef _make_divisible(v, divisor=8, min_value=None):\n  if min_value is None:\n    min_value = divisor\n  new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n  # Make sure that round down does not go down by more than 10%.\n  if new_v < 0.9 * v:\n    new_v += divisor\n  return new_v\n\n\ndef depthwise_conv2d_fixed_padding(inputs,\n                                   kernel_size,\n                                   stride,\n                                   data_format='channels_first',\n                                   name=None):\n  \"\"\"Depthwise Strided 2-D convolution with explicit padding.\n\n  The padding is consistent and is based only on `kernel_size`, not on the\n  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).\n\n  Args:\n    inputs:  Input tensor, float32 or bfloat16 of size 
[batch, channels, height,\n      width].\n    kernel_size: Int designating size of kernel to be used in the convolution.\n    stride: Int specifying the stride. If stride >1, the input is downsampled.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height,width] or \"channels_last\" for [batch, height, width,\n      channels].\n    name: String that specifies name for model layer.\n\n  Returns:\n    The output activation tensor of size [batch, filters, height_out, width_out]\n\n  Raises:\n    ValueError: If the data_format provided is not a valid string.\n  \"\"\"\n  if stride > 1:\n    inputs = resnet_model.fixed_padding(\n        inputs, kernel_size, data_format=data_format)\n  padding = 'SAME' if stride == 1 else 'VALID'\n\n  if data_format == 'channels_last':\n    data_format_channels = 'NHWC'\n  elif data_format == 'channels_first':\n    data_format_channels = 'NCHW'\n  else:\n    raise ValueError('Not a valid channel string:', data_format)\n\n  return contrib_layers.separable_conv2d(\n      inputs=inputs,\n      num_outputs=None,\n      kernel_size=kernel_size,\n      stride=stride,\n      padding=padding,\n      data_format=data_format_channels,\n      activation_fn=None,\n      weights_regularizer=None,\n      biases_initializer=None,\n      biases_regularizer=None,\n      scope=name)\n\n\ndef conv2d_fixed_padding(inputs,\n                         filters,\n                         kernel_size,\n                         strides,\n                         pruning_method='baseline',\n                         data_format='channels_first',\n                         weight_decay=0.,\n                         name=None):\n  \"\"\"Strided 2-D convolution with explicit padding.\n\n  The padding is consistent and is based only on `kernel_size`, not on the\n  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).\n\n  Args:\n    inputs:  Input tensor, float32 or bfloat16 of size [batch, channels, 
height,\n      width].\n    filters: Int specifying number of filters for the first two convolutions.\n    kernel_size: Int designating size of kernel to be used in the convolution.\n    strides: Int specifying the stride. If stride >1, the input is downsampled.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height,width] or \"channels_last\" for [batch, height, width,\n      channels].\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies name for model layer.\n\n  Returns:\n    The output activation tensor of size [batch, filters, height_out, width_out]\n\n  Raises:\n    ValueError: If the data_format provided is not a valid string.\n  \"\"\"\n  if strides > 1:\n    inputs = resnet_model.fixed_padding(\n        inputs, kernel_size, data_format=data_format)\n    padding = 'VALID'\n  else:\n    padding = 'SAME'\n\n  kernel_initializer = tf.variance_scaling_initializer()\n\n  kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n  return sparse_conv2d(\n      x=inputs,\n      units=filters,\n      activation=None,\n      kernel_size=[kernel_size, kernel_size],\n      use_bias=False,\n      kernel_initializer=kernel_initializer,\n      kernel_regularizer=kernel_regularizer,\n      bias_initializer=None,\n      biases_regularizer=None,\n      sparsity_technique=pruning_method,\n      normalizer_fn=None,\n      strides=[strides, strides],\n      padding=padding,\n      data_format=data_format,\n      name=name)\n\n\ndef mbv1_block_(inputs,\n                filters,\n                is_training,\n                stride,\n                width=1.,\n                block_id=0,\n                pruning_method='baseline',\n                data_format='channels_first',\n                weight_decay=0.):\n  \"\"\"Standard building block for mobilenetv1 networks.\n\n  
Args:\n    inputs:  Input tensor, float32 or bfloat16 of size [batch, channels, height,\n      width].\n    filters: Int specifying number of filters for the first two convolutions.\n    is_training: Boolean specifying whether the model is training.\n    stride: Int specifying the stride. If stride >1, the input is downsampled.\n    width: multiplier for channel dimensions\n    block_id: which block this is\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height,width] or \"channels_last\" for [batch, height, width,\n      channels].\n    weight_decay: Weight for the l2 regularization loss.\n\n  Returns:\n    The output activation tensor.\n  \"\"\"\n\n  # separable_conv_2d followed by contracting 1x1 conv.\n\n  end_point = 'depthwise_nxn_%s' % block_id\n  # Depthwise\n  depthwise_out = depthwise_conv2d_fixed_padding(\n      inputs=inputs,\n      kernel_size=3,\n      stride=stride,\n      data_format=data_format,\n      name=end_point)\n\n  depthwise_out = resnet_model.batch_norm_relu(\n      depthwise_out, is_training, relu=True, data_format=data_format)\n\n  # Contraction\n  end_point = 'contraction_1x1_%s' % block_id\n  divisible_by = 8\n  if block_id == 0:\n    divisible_by = 1\n  out_filters = _make_divisible(int(width * filters), divisor=divisible_by)\n\n  contraction_out = conv2d_fixed_padding(\n      inputs=depthwise_out,\n      filters=out_filters,\n      kernel_size=1,\n      strides=1,\n      pruning_method=pruning_method,\n      data_format=data_format,\n      weight_decay=weight_decay,\n      name=end_point)\n  contraction_out = resnet_model.batch_norm_relu(\n      contraction_out, is_training, relu=True, data_format=data_format)\n\n  output = contraction_out\n  return output\n\n\ndef mobilenet_v1_generator(num_classes=1000,\n                           pruning_method='baseline',\n                   
        width=1.,\n                           prune_last_layer=False,\n                           data_format='channels_first',\n                           weight_decay=0.,\n                           name=None):\n  \"\"\"Generator for MobileNet v1 models.\n\n  Args:\n    num_classes: Int number of possible classes for image classification.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    width: Float that scales the number of filters in each layer.\n    prune_last_layer: Whether or not to prune the last layer.\n    data_format: String either \"channels_first\" for `[batch, channels, height,\n      width]` or \"channels_last\" for `[batch, height, width, channels]`.\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies name for model layer.\n\n  Returns:\n    Model `function` that takes in `inputs` and `is_training` and returns the\n    output `Tensor` of the MobileNet v1 model.\n  \"\"\"\n\n  def model(inputs, is_training):\n    \"\"\"Creation of the model graph.\"\"\"\n    with tf.variable_scope(name, 'resnet_model'):\n      inputs = resnet_model.fixed_padding(\n          inputs, kernel_size=3, data_format=data_format)\n      padding = 'VALID'\n\n      kernel_initializer = tf.variance_scaling_initializer()\n      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n\n      inputs = tf.layers.conv2d(\n          inputs=inputs,\n          filters=_make_divisible(32 * width),\n          kernel_size=3,\n          strides=2,\n          padding=padding,\n          use_bias=False,\n          kernel_initializer=kernel_initializer,\n          kernel_regularizer=kernel_regularizer,\n          data_format=data_format,\n          name='initial_conv')\n\n      inputs = tf.identity(inputs, 'initial_conv')\n      inputs = resnet_model.batch_norm_relu(\n          inputs, is_training, data_format=data_format)\n\n      mb_block = functools.partial(\n          mbv1_block_,\n  
        is_training=is_training,\n          width=width,\n          pruning_method=pruning_method,\n          data_format=data_format,\n          weight_decay=weight_decay)\n\n      inputs = mb_block(inputs, filters=64, stride=1, block_id=0)\n\n      inputs = mb_block(inputs, filters=128, stride=2, block_id=1)\n      inputs = mb_block(inputs, filters=128, stride=1, block_id=2)\n\n      inputs = mb_block(inputs, filters=256, stride=2, block_id=3)\n      inputs = mb_block(inputs, filters=256, stride=1, block_id=4)\n\n      inputs = mb_block(inputs, filters=512, stride=2, block_id=5)\n      inputs = mb_block(inputs, filters=512, stride=1, block_id=6)\n      inputs = mb_block(inputs, filters=512, stride=1, block_id=7)\n      inputs = mb_block(inputs, filters=512, stride=1, block_id=8)\n      inputs = mb_block(inputs, filters=512, stride=1, block_id=9)\n      inputs = mb_block(inputs, filters=512, stride=1, block_id=10)\n\n      inputs = mb_block(inputs, filters=1024, stride=2, block_id=11)\n      inputs = mb_block(inputs, filters=1024, stride=1, block_id=12)\n\n      last_block_filters = _make_divisible(int(1024 * width), 8)\n\n      if data_format == 'channels_last':\n        pool_size = (inputs.shape[1], inputs.shape[2])\n      elif data_format == 'channels_first':\n        pool_size = (inputs.shape[2], inputs.shape[3])\n\n      inputs = tf.layers.average_pooling2d(\n          inputs=inputs,\n          pool_size=pool_size,\n          strides=1,\n          padding='VALID',\n          data_format=data_format,\n          name='final_avg_pool')\n      inputs = tf.identity(inputs, 'final_avg_pool')\n      inputs = tf.reshape(inputs, [-1, last_block_filters])\n\n      kernel_initializer = tf.variance_scaling_initializer()\n      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n      if prune_last_layer:\n        inputs = sparse_fully_connected(\n            x=inputs,\n            units=num_classes,\n            sparsity_technique=pruning_method\n           
 if prune_last_layer else 'baseline',\n            kernel_initializer=kernel_initializer,\n            kernel_regularizer=kernel_regularizer,\n            name='final_dense')\n      else:\n        inputs = tf.layers.dense(\n            inputs=inputs,\n            units=num_classes,\n            activation=None,\n            use_bias=True,\n            kernel_initializer=kernel_initializer,\n            kernel_regularizer=kernel_regularizer,\n            name='final_dense')\n\n      inputs = tf.identity(inputs, 'final_dense')\n    return inputs\n\n  model.default_image_size = 224\n  return model\n\n\ndef mobilenet_v1(num_classes,\n                 pruning_method='baseline',\n                 width=1.,\n                 prune_last_layer=True,\n                 data_format='channels_first',\n                 weight_decay=0.):\n  \"\"\"Returns the MobileNet v1 model for a given size and number of output classes.\n\n  Args:\n    num_classes: Int number of possible classes for image classification.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    width: Float multiplier of the number of filters in each layer.\n    prune_last_layer: Whether or not to prune the last layer.\n    data_format: String specifying either \"channels_first\" for `[batch,\n      channels, height, width]` or \"channels_last\" for `[batch, height, width,\n      channels]`.\n    weight_decay: Weight for the l2 regularization loss.\n\n  Returns:\n    The MobileNet v1 model function.\n  \"\"\"\n  return mobilenet_v1_generator(num_classes, pruning_method, width,\n                                prune_last_layer, data_format, weight_decay)\n"
  },
  {
    "path": "rigl/imagenet_resnet/mobilenetv2_model.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Straightforward MobileNet v2 for inputs of size 224x224.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nfrom absl import flags\nfrom rigl.imagenet_resnet import resnet_model\nfrom rigl.imagenet_resnet.pruning_layers import sparse_conv2d\nfrom rigl.imagenet_resnet.pruning_layers import sparse_fully_connected\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib import layers as contrib_layers\n\nFLAGS = flags.FLAGS\n\n\ndef _make_divisible(v, divisor=8, min_value=None):\n  if min_value is None:\n    min_value = divisor\n  new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n  # Make sure that round down does not go down by more than 10%.\n  if new_v < 0.9 * v:\n    new_v += divisor\n  return new_v\n\n\ndef depthwise_conv2d_fixed_padding(inputs,\n                                   kernel_size,\n                                   stride,\n                                   data_format='channels_first',\n                                   name=None):\n  \"\"\"Depthwise Strided 2-D convolution with explicit padding.\n\n  The padding is consistent and is based only on `kernel_size`, not on the\n  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).\n\n  Args:\n    inputs:  Input tensor, float32 or bfloat16 of size 
[batch, channels, height,\n      width].\n    kernel_size: Int designating size of kernel to be used in the convolution.\n    stride: Int specifying the stride. If stride >1, the input is downsampled.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height,width] or \"channels_last\" for [batch, height, width,\n      channels].\n    name: String that specifies name for model layer.\n\n  Returns:\n    The output activation tensor of size [batch, filters, height_out, width_out]\n\n  Raises:\n    ValueError: If the data_format provided is not a valid string.\n  \"\"\"\n  if stride > 1:\n    inputs = resnet_model.fixed_padding(\n        inputs, kernel_size, data_format=data_format)\n  padding = 'SAME' if stride == 1 else 'VALID'\n\n  if data_format == 'channels_last':\n    data_format_channels = 'NHWC'\n  elif data_format == 'channels_first':\n    data_format_channels = 'NCHW'\n  else:\n    raise ValueError('Not a valid channel string:', data_format)\n\n  return contrib_layers.separable_conv2d(\n      inputs=inputs,\n      num_outputs=None,\n      kernel_size=kernel_size,\n      stride=stride,\n      padding=padding,\n      data_format=data_format_channels,\n      activation_fn=None,\n      weights_regularizer=None,\n      biases_initializer=None,\n      biases_regularizer=None,\n      scope=name)\n\n\ndef conv2d_fixed_padding(inputs,\n                         filters,\n                         kernel_size,\n                         strides,\n                         pruning_method='baseline',\n                         data_format='channels_first',\n                         weight_decay=0.,\n                         name=None):\n  \"\"\"Strided 2-D convolution with explicit padding.\n\n  The padding is consistent and is based only on `kernel_size`, not on the\n  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).\n\n  Args:\n    inputs:  Input tensor, float32 or bfloat16 of size [batch, channels, 
height,\n      width].\n    filters: Int specifying number of filters for the first two convolutions.\n    kernel_size: Int designating size of kernel to be used in the convolution.\n    strides: Int specifying the stride. If stride >1, the input is downsampled.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height,width] or \"channels_last\" for [batch, height, width,\n      channels].\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies name for model layer.\n\n  Returns:\n    The output activation tensor of size [batch, filters, height_out, width_out]\n\n  Raises:\n    ValueError: If the data_format provided is not a valid string.\n  \"\"\"\n  if strides > 1:\n    inputs = resnet_model.fixed_padding(\n        inputs, kernel_size, data_format=data_format)\n    padding = 'VALID'\n  else:\n    padding = 'SAME'\n\n  kernel_initializer = tf.variance_scaling_initializer()\n\n  kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n  return sparse_conv2d(\n      x=inputs,\n      units=filters,\n      activation=None,\n      kernel_size=[kernel_size, kernel_size],\n      use_bias=False,\n      kernel_initializer=kernel_initializer,\n      kernel_regularizer=kernel_regularizer,\n      bias_initializer=None,\n      biases_regularizer=None,\n      sparsity_technique=pruning_method,\n      normalizer_fn=None,\n      strides=[strides, strides],\n      padding=padding,\n      data_format=data_format,\n      name=name)\n\n\ndef inverted_res_block_(inputs,\n                        filters,\n                        is_training,\n                        stride,\n                        width=1.,\n                        expansion_factor=6.,\n                        block_id=0,\n                        pruning_method='baseline',\n                        
data_format='channels_first',\n                        weight_decay=0.,):\n  \"\"\"Standard building block for mobilenetv2 networks.\n\n  Args:\n    inputs:  Input tensor, float32 or bfloat16 of size [batch, channels, height,\n      width].\n    filters: Int specifying number of filters for the first two convolutions.\n    is_training: Boolean specifying whether the model is training.\n    stride: Int specifying the stride. If stride >1, the input is downsampled.\n    width: multiplier for channel dimensions\n    expansion_factor: How much to increase the filters before the depthwise\n      conv.\n    block_id: which block this is\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height,width] or \"channels_last\" for [batch, height, width,\n      channels].\n    weight_decay: Weight for the l2 regularization loss.\n\n  Returns:\n    The output activation tensor.\n  \"\"\"\n\n  # 1x1 expanded conv, followed by separable_conv_2d followed by\n  # contracting 1x1 conv.\n\n  shortcut = inputs\n\n  if data_format == 'channels_first':\n    prev_depth = inputs.get_shape().as_list()[1]\n  elif data_format == 'channels_last':\n    prev_depth = inputs.get_shape().as_list()[3]\n  else:\n    raise ValueError('Unknown data_format ' + data_format)\n\n  # Expand\n  multiplier = expansion_factor if block_id > 0 else 1\n  # skip the expansion if this is the first block\n  if block_id:\n    end_point = 'expand_1x1_%s' % block_id\n    inputs = conv2d_fixed_padding(\n        inputs=inputs,\n        filters=int(multiplier * prev_depth),\n        kernel_size=1,\n        strides=1,\n        pruning_method=pruning_method,\n        data_format=data_format,\n        weight_decay=weight_decay,\n        name=end_point)\n    inputs = resnet_model.batch_norm_relu(\n        inputs, is_training, relu=True, data_format=data_format)\n\n  
end_point = 'depthwise_nxn_%s' % block_id\n  # Depthwise\n  depthwise_out = depthwise_conv2d_fixed_padding(\n      inputs=inputs,\n      kernel_size=3,\n      stride=stride,\n      data_format=data_format,\n      name=end_point)\n\n  depthwise_out = resnet_model.batch_norm_relu(\n      depthwise_out, is_training, relu=True, data_format=data_format)\n\n  # Contraction\n  end_point = 'contraction_1x1_%s' % block_id\n  divisible_by = 8\n  if block_id == 0:\n    divisible_by = 1\n  out_filters = _make_divisible(int(width * filters), divisor=divisible_by)\n\n  contraction_out = conv2d_fixed_padding(\n      inputs=depthwise_out,\n      filters=out_filters,\n      kernel_size=1,\n      strides=1,\n      pruning_method=pruning_method,\n      data_format=data_format,\n      weight_decay=weight_decay,\n      name=end_point)\n  contraction_out = resnet_model.batch_norm_relu(\n      contraction_out, is_training, relu=False, data_format=data_format)\n\n  output = contraction_out\n  if prev_depth == out_filters and stride == 1:\n    output += shortcut\n  return output\n\n\ndef mobilenet_v2_generator(num_classes=1000,\n                           pruning_method='baseline',\n                           width=1.,\n                           expansion_factor=6.,\n                           prune_last_layer=False,\n                           data_format='channels_first',\n                           weight_decay=0.,\n                           name=None):\n  \"\"\"Generator for mobilenet v2 models.\n\n  Args:\n    num_classes: Int number of possible classes for image classification.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    width: Float that scales the number of filters in each layer.\n    expansion_factor: How much to expand the input filters for the depthwise\n      conv.\n    prune_last_layer: Whether or not to prune the last layer.\n    data_format: String either \"channels_first\" for `[batch, channels, 
height,\n      width]` or \"channels_last\" for `[batch, height, width, channels]`.\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies name for model layer.\n\n  Returns:\n    Model `function` that takes in `inputs` and `is_training` and returns the\n    output `Tensor` of the MobileNet v2 model.\n  \"\"\"\n\n  def model(inputs, is_training):\n    \"\"\"Creation of the model graph.\"\"\"\n    with tf.variable_scope(name, 'resnet_model'):\n      inputs = resnet_model.fixed_padding(\n          inputs, kernel_size=3, data_format=data_format)\n      padding = 'VALID'\n\n      kernel_initializer = tf.variance_scaling_initializer()\n      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n\n      inputs = tf.layers.conv2d(\n          inputs=inputs,\n          filters=_make_divisible(32 * width),\n          kernel_size=3,\n          strides=2,\n          padding=padding,\n          use_bias=False,\n          kernel_initializer=kernel_initializer,\n          kernel_regularizer=kernel_regularizer,\n          data_format=data_format,\n          name='initial_conv')\n\n      inputs = tf.identity(inputs, 'initial_conv')\n      inputs = resnet_model.batch_norm_relu(\n          inputs, is_training, data_format=data_format)\n\n      inverted_res_block = functools.partial(\n          inverted_res_block_,\n          is_training=is_training,\n          width=width,\n          expansion_factor=expansion_factor,\n          pruning_method=pruning_method,\n          data_format=data_format,\n          weight_decay=weight_decay)\n\n      inputs = inverted_res_block(inputs, filters=16, stride=1, block_id=0)\n\n      inputs = inverted_res_block(inputs, filters=24, stride=2, block_id=1)\n      inputs = inverted_res_block(inputs, filters=24, stride=1, block_id=2)\n\n      inputs = inverted_res_block(inputs, filters=32, stride=2, block_id=3)\n      inputs = inverted_res_block(inputs, filters=32, stride=1, block_id=4)\n      inputs = 
inverted_res_block(inputs, filters=32, stride=1, block_id=5)\n\n      inputs = inverted_res_block(inputs, filters=64, stride=2, block_id=6)\n      inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=7)\n      inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=8)\n      inputs = inverted_res_block(inputs, filters=64, stride=1, block_id=9)\n\n      inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=10)\n      inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=11)\n      inputs = inverted_res_block(inputs, filters=96, stride=1, block_id=12)\n\n      inputs = inverted_res_block(inputs, filters=160, stride=2, block_id=13)\n      inputs = inverted_res_block(inputs, filters=160, stride=1, block_id=14)\n      inputs = inverted_res_block(inputs, filters=160, stride=1, block_id=15)\n\n      inputs = inverted_res_block(inputs, filters=320, stride=1, block_id=16)\n\n      last_block_filters = max(1280, _make_divisible(1280 * width, 8))\n\n      inputs = conv2d_fixed_padding(\n          inputs=inputs,\n          filters=last_block_filters,\n          kernel_size=1,\n          strides=1,\n          pruning_method=pruning_method,\n          data_format=data_format,\n          weight_decay=weight_decay,\n          name='final_1x1_conv')\n\n      inputs = resnet_model.batch_norm_relu(\n          inputs, is_training, data_format=data_format)\n\n      if data_format == 'channels_last':\n        pool_size = (inputs.shape[1], inputs.shape[2])\n      elif data_format == 'channels_first':\n        pool_size = (inputs.shape[2], inputs.shape[3])\n\n      inputs = tf.layers.average_pooling2d(\n          inputs=inputs,\n          pool_size=pool_size,\n          strides=1,\n          padding='VALID',\n          data_format=data_format,\n          name='final_avg_pool')\n      inputs = tf.identity(inputs, 'final_avg_pool')\n      inputs = tf.reshape(inputs, [-1, last_block_filters])\n\n      kernel_initializer = 
tf.variance_scaling_initializer()\n\n      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n      if prune_last_layer:\n        inputs = sparse_fully_connected(\n            x=inputs,\n            units=num_classes,\n            sparsity_technique=pruning_method\n            if prune_last_layer else 'baseline',\n            kernel_initializer=kernel_initializer,\n            kernel_regularizer=kernel_regularizer,\n            name='final_dense')\n      else:\n        inputs = tf.layers.dense(\n            inputs=inputs,\n            units=num_classes,\n            activation=None,\n            use_bias=True,\n            kernel_initializer=kernel_initializer,\n            kernel_regularizer=kernel_regularizer,\n            name='final_dense')\n\n      inputs = tf.identity(inputs, 'final_dense')\n    return inputs\n\n  model.default_image_size = 224\n  return model\n\n\ndef mobilenet_v2(num_classes,\n                 pruning_method='baseline',\n                 width=1.,\n                 expansion_factor=6.,\n                 prune_last_layer=True,\n                 data_format='channels_first',\n                 weight_decay=0.,):\n  \"\"\"Returns the MobileNet v2 model for a given size and number of output classes.\n\n  Args:\n    num_classes: Int number of possible classes for image classification.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    width: Float multiplier of the number of filters in each layer.\n    expansion_factor: How much to increase the number of filters before the\n      depthwise conv.\n    prune_last_layer: Whether or not to prune the last layer.\n    data_format: String specifying either \"channels_first\" for `[batch,\n      channels, height, width]` or \"channels_last\" for `[batch, height, width,\n      channels]`.\n    weight_decay: Weight for the l2 regularization loss.\n\n  Returns:\n
    The MobileNet v2 model function.\n  \"\"\"\n  return mobilenet_v2_generator(\n      num_classes, pruning_method, width, expansion_factor, prune_last_layer,\n      data_format, weight_decay)\n"
  },
  {
    "path": "rigl/imagenet_resnet/pruning_layers.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"TensorFlow layers with parameters for implementing pruning.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow.compat.v1 as tf\n\n\nfrom tensorflow.contrib.framework.python.ops import variables\nfrom tensorflow.contrib.model_pruning.python.layers import layers\nfrom tensorflow.python.ops import init_ops\n\n\ndef get_model_variables(getter,\n                        name,\n                        shape=None,\n                        dtype=None,\n                        initializer=None,\n                        regularizer=None,\n                        trainable=True,\n                        collections=None,\n                        caching_device=None,\n                        partitioner=None,\n                        rename=None,\n                        use_resource=None,\n                        **_):\n  \"\"\"This ensures variables are retrieved in a consistent way for core layers.\"\"\"\n  short_name = name.split('/')[-1]\n  if rename and short_name in rename:\n    name_components = name.split('/')\n    name_components[-1] = rename[short_name]\n    name = '/'.join(name_components)\n  return variables.model_variable(\n      name,\n      shape=shape,\n      dtype=dtype,\n      initializer=initializer,\n      regularizer=regularizer,\n      
collections=collections,\n      trainable=trainable,\n      caching_device=caching_device,\n      partitioner=partitioner,\n      custom_getter=getter,\n      use_resource=use_resource)\n\n\ndef variable_getter(rename=None):\n  \"\"\"Ensures scope is respected and consistently used.\"\"\"\n\n  def layer_variable_getter(getter, *args, **kwargs):\n    kwargs['rename'] = rename\n    return get_model_variables(getter, *args, **kwargs)\n\n  return layer_variable_getter\n\n\ndef sparse_conv2d(x,\n                  units,\n                  kernel_size,\n                  activation=None,\n                  use_bias=False,\n                  kernel_initializer=None,\n                  kernel_regularizer=None,\n                  bias_initializer=None,\n                  biases_regularizer=None,\n                  sparsity_technique='baseline',\n                  normalizer_fn=None,\n                  strides=(1, 1),\n                  padding='SAME',\n                  data_format='channels_last',\n                  name=None):\n  \"\"\"Function that constructs conv2d with any desired pruning method.\n\n  Args:\n    x: Input, float32 tensor.\n    units: Int representing size of output tensor.\n    kernel_size: The size of the convolutional window, an int or list of ints.\n    activation: If None, a linear activation is used.\n    use_bias: Boolean specifying whether bias vector should be used.\n    kernel_initializer: Initializer for the convolution weights.\n    kernel_regularizer: Regularization method for the convolution weights.\n    bias_initializer: Initializer of the bias vector.\n    biases_regularizer: Optional regularizer for the bias vector.\n    sparsity_technique: Method used to introduce sparsity.\n      ['threshold', 'baseline']\n    normalizer_fn: Function used to transform the output activations.\n    strides: Stride of the convolution, an int or pair of ints.\n    padding: Either 'VALID' or 'SAME'.\n    data_format: Either 'channels_last', 
'channels_first'.\n    name: String specifying name scope of layer in network.\n\n  Returns:\n    Output: activations.\n\n  Raises:\n    ValueError: If the rank of the input is not 4 or `data_format` is not a\n      valid string.\n  \"\"\"\n\n  if data_format == 'channels_last':\n    data_format_channels = 'NHWC'\n  elif data_format == 'channels_first':\n    data_format_channels = 'NCHW'\n  else:\n    raise ValueError('Not a valid channel string:', data_format)\n\n  layer_variable_getter = variable_getter({\n      'bias': 'biases',\n      'kernel': 'weights',\n  })\n  input_rank = x.get_shape().ndims\n  if input_rank != 4:\n    raise ValueError('Rank not supported {}'.format(input_rank))\n\n  with tf.variable_scope(\n      name, 'Conv', [x], custom_getter=layer_variable_getter) as sc:\n\n    input_shape = x.get_shape().as_list()\n    if input_shape[-1] is None:\n      raise ValueError('The last dimension of the inputs to `Convolution` '\n                       'should be defined. Found `None`.')\n\n    pruning_methods = ['threshold']\n\n    if sparsity_technique in pruning_methods:\n      return layers.masked_conv2d(\n          inputs=x,\n          num_outputs=units,\n          kernel_size=kernel_size[0],\n          stride=strides[0],\n          padding=padding,\n          data_format=data_format_channels,\n          rate=1,\n          activation_fn=activation,\n          weights_initializer=kernel_initializer,\n          weights_regularizer=kernel_regularizer,\n          normalizer_fn=normalizer_fn,\n          normalizer_params=None,\n          biases_initializer=bias_initializer,\n          biases_regularizer=biases_regularizer,\n          outputs_collections=None,\n          trainable=True,\n          scope=sc)\n    elif sparsity_technique == 'baseline':\n      return tf.layers.conv2d(\n          inputs=x,\n          filters=units,\n          kernel_size=kernel_size,\n          strides=strides,\n          padding=padding,\n          use_bias=use_bias,\n          
kernel_initializer=kernel_initializer,\n          kernel_regularizer=kernel_regularizer,\n          data_format=data_format,\n          name=name)\n    else:\n      raise ValueError(\n          'Unsupported sparsity technique {}'.format(sparsity_technique))\n\n\ndef sparse_fully_connected(x,\n                           units,\n                           activation=None,\n                           use_bias=True,\n                           kernel_initializer=None,\n                           kernel_regularizer=None,\n                           bias_initializer=init_ops.zeros_initializer(),\n                           biases_regularizer=None,\n                           sparsity_technique='baseline',\n                           name=None):\n  \"\"\"Constructs sparse_fully_connected with any desired pruning method.\n\n  Args:\n    x: Input, float32 tensor.\n    units: Int representing size of output tensor.\n    activation: If None, a linear activation is used.\n    use_bias: Boolean specifying whether bias vector should be used.\n    kernel_initializer: Initializer for the kernel weights.\n    kernel_regularizer: Regularization method for the kernel weights.\n    bias_initializer: Initializer of the bias vector.\n    biases_regularizer: Optional regularizer for the bias vector.\n    sparsity_technique: Method used to introduce sparsity. ['baseline',\n      'threshold']\n    name: String specifying name scope of layer in network.\n\n  Returns:\n    Output: activations.\n\n  Raises:\n    ValueError: If the last dimension of the inputs is not defined.\n  \"\"\"\n\n  layer_variable_getter = variable_getter({\n      'bias': 'biases',\n      'kernel': 'weights',\n  })\n\n  with tf.variable_scope(\n      name, 'Dense', [x], custom_getter=layer_variable_getter) as sc:\n\n    input_shape = x.get_shape().as_list()\n    if input_shape[-1] is None:\n      raise ValueError('The last dimension of the inputs to `Dense` '\n                       'should be defined. 
Found `None`.')\n\n    pruning_methods = ['threshold']\n\n    if sparsity_technique in pruning_methods:\n      return layers.masked_fully_connected(\n          inputs=x,\n          num_outputs=units,\n          activation_fn=activation,\n          weights_initializer=kernel_initializer,\n          weights_regularizer=kernel_regularizer,\n          biases_initializer=bias_initializer,\n          biases_regularizer=biases_regularizer,\n          outputs_collections=None,\n          trainable=True,\n          scope=sc)\n\n    elif sparsity_technique == 'baseline':\n      return tf.layers.dense(\n          inputs=x,\n          units=units,\n          activation=activation,\n          use_bias=use_bias,\n          kernel_initializer=kernel_initializer,\n          kernel_regularizer=kernel_regularizer,\n          bias_initializer=bias_initializer,\n          bias_regularizer=biases_regularizer,\n          name=name)\n    else:\n      raise ValueError(\n          'Unsupported sparsity technique {}'.format(sparsity_technique))\n"
  },
  {
    "path": "rigl/imagenet_resnet/resnet_model.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"ResNet modified to include pruning layers if specified.\n\nResidual networks (ResNets) were proposed in:\n[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun\n    Deep Residual Learning for Image Recognition. arXiv:1512.03385\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nfrom absl import flags\nfrom rigl.imagenet_resnet.pruning_layers import sparse_conv2d\nfrom rigl.imagenet_resnet.pruning_layers import sparse_fully_connected\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib import layers as contrib_layers\n\nfrom tensorflow.python.ops import init_ops\n\nFLAGS = flags.FLAGS\nBATCH_NORM_DECAY = 0.9\nBATCH_NORM_EPSILON = 1e-5\n\n\ndef batch_norm_relu(inputs, is_training, relu=True, init_zero=False,\n                    data_format='channels_first'):\n  \"\"\"Performs a batch normalization followed by a ReLU.\n\n  Args:\n    inputs: `Tensor` of shape `[batch, channels, ...]`.\n    is_training: `bool` for whether the model is training.\n    relu: `bool` if False, omits the ReLU operation.\n    init_zero: `bool` if True, initializes scale parameter of batch\n        normalization with 0 instead of 1 (default).\n    data_format: `str` either \"channels_first\" for `[batch, channels, height,\n        width]` or \"channels_last\" for `[batch, 
height, width, channels]`.\n\n  Returns:\n    A normalized `Tensor` with the same `data_format`.\n  \"\"\"\n  if init_zero:\n    gamma_initializer = tf.zeros_initializer()\n  else:\n    gamma_initializer = tf.ones_initializer()\n\n  if data_format == 'channels_first':\n    axis = 1\n  else:\n    axis = 3\n\n  inputs = tf.layers.batch_normalization(\n      inputs=inputs,\n      axis=axis,\n      momentum=BATCH_NORM_DECAY,\n      epsilon=BATCH_NORM_EPSILON,\n      center=True,\n      scale=True,\n      training=is_training,\n      fused=True,\n      gamma_initializer=gamma_initializer)\n\n  if relu:\n    inputs = tf.nn.relu(inputs)\n  return inputs\n\n\ndef fixed_padding(inputs, kernel_size, data_format='channels_first'):\n  \"\"\"Pads the input along the spatial dimensions independently of input size.\n\n  Args:\n    inputs: `Tensor` of size `[batch, channels, height, width]` or\n        `[batch, height, width, channels]` depending on `data_format`.\n    kernel_size: `int` kernel size to be used for `conv2d` or `max_pool2d`\n        operations. 
Should be a positive integer.\n    data_format: `str` either \"channels_first\" for `[batch, channels, height,\n        width]` or \"channels_last\" for `[batch, height, width, channels]`.\n\n  Returns:\n    A padded `Tensor` of the same `data_format` with size either intact\n    (if `kernel_size == 1`) or padded (if `kernel_size > 1`).\n  \"\"\"\n  pad_total = kernel_size - 1\n  pad_beg = pad_total // 2\n  pad_end = pad_total - pad_beg\n  if data_format == 'channels_first':\n    padded_inputs = tf.pad(inputs, [[0, 0], [0, 0],\n                                    [pad_beg, pad_end], [pad_beg, pad_end]])\n  else:\n    padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],\n                                    [pad_beg, pad_end], [0, 0]])\n\n  return padded_inputs\n\n\nclass RandomSparseInitializer(init_ops.Initializer):\n  \"\"\"An initializer that sets a fraction of values to zero.\"\"\"\n\n  def __init__(self, sparsity, seed=None, dtype=tf.float32):\n    if sparsity < 0. or sparsity > 1.:\n      raise ValueError('sparsity must be in the range [0., 1.].')\n    self.kernel_initializer = tf.variance_scaling_initializer(seed=seed,\n                                                              dtype=dtype)\n    self.seed = seed\n    self.dtype = dtype\n    self.sparsity = float(sparsity)\n\n  def __call__(self, *args, **kwargs):\n    init_tensor = self.kernel_initializer(*args, **kwargs)\n    rand_vals = tf.random_uniform(tf.shape(init_tensor))\n    threshold = tf.constant(self.sparsity)\n    masked_tensor = tf.where(rand_vals < threshold,\n                             tf.zeros_like(rand_vals), init_tensor)\n    return masked_tensor\n\n  def get_config(self):\n    return {\n        'seed': self.seed,\n        'dtype': self.dtype.name,\n        'sparsity': self.sparsity\n    }\n\n\nclass SparseConvVarianceScalingInitializer(init_ops.Initializer):\n  \"\"\"Define an initializer for an already sparse layer.\"\"\"\n\n  def __init__(self, sparsity, seed=None, 
dtype=tf.float32):\n    if sparsity < 0. or sparsity >= 1.:\n      raise ValueError('sparsity must be in the range [0., 1.).')\n\n    self.sparsity = sparsity\n    self.seed = seed\n    self.dtype = dtype\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    if partition_info is not None:\n      raise ValueError('partition_info not supported.')\n    if dtype is None:\n      dtype = self.dtype\n\n    # Calculate number of non-zero weights\n    nnz = 1.\n    for d in shape:\n      nnz *= d\n    nnz *= (1. - self.sparsity)\n\n    input_channels = shape[-2]\n    n = nnz / input_channels\n\n    # He-style standard deviation computed from the non-zero fan-in.\n    stddev = (2. / n)**.5\n\n    return tf.random_normal(shape, 0, stddev, dtype, seed=self.seed)\n\n  def get_config(self):\n    return {\n        'seed': self.seed,\n        'dtype': self.dtype.name,\n    }\n\n\nclass SparseFCVarianceScalingInitializer(init_ops.Initializer):\n  \"\"\"Define an initializer for an already sparse layer.\"\"\"\n\n  def __init__(self, sparsity, seed=None, dtype=tf.float32):\n    if sparsity < 0. or sparsity >= 1.:\n      raise ValueError('sparsity must be in the range [0., 1.).')\n\n    self.sparsity = sparsity\n    self.seed = seed\n    self.dtype = dtype\n\n  def __call__(self, shape, dtype=None, partition_info=None):\n    if partition_info is not None:\n      raise ValueError('partition_info not supported.')\n    if dtype is None:\n      dtype = self.dtype\n\n    if len(shape) != 2:\n      raise ValueError('Weights must be 2-dimensional.')\n\n    fan_in = shape[0]\n    fan_out = shape[1]\n\n    # Calculate number of non-zero weights\n    nnz = 1.\n    for d in shape:\n      nnz *= d\n    nnz *= (1. - self.sparsity)\n\n    limit = math.sqrt(6. 
/ (nnz / fan_out + nnz / fan_in))\n\n    return tf.random_uniform(shape, -limit, limit, dtype, seed=self.seed)\n\n  def get_config(self):\n    return {\n        'seed': self.seed,\n        'dtype': self.dtype.name,\n    }\n\n\ndef _pick_initializer(kernel_initializer, init_method, pruning_method,\n                      end_sparsity):\n  \"\"\"Updates the initializer selected, if necessary.\"\"\"\n  if init_method == 'sparse':\n    if pruning_method != 'threshold':\n      raise ValueError(\n          'Unsupported combination of flags, pruning_method must be threshold'\n          ' if init_method is `sparse`.')\n    else:\n      kernel_initializer = SparseFCVarianceScalingInitializer(end_sparsity)\n  elif init_method == 'random_zeros':\n    if pruning_method != 'baseline':\n      raise ValueError(\n          'Unsupported combination of flags, pruning_method must be '\n          'baseline if init_method is `random_zeros`.')\n    else:\n      kernel_initializer = RandomSparseInitializer(end_sparsity)\n  return kernel_initializer\n\n\ndef conv2d_fixed_padding(inputs,\n                         filters,\n                         kernel_size,\n                         strides,\n                         pruning_method='baseline',\n                         init_method='baseline',\n                         data_format='channels_first',\n                         end_sparsity=0.,\n                         weight_decay=0.,\n                         init_scale=1.0,\n                         name=None):\n  \"\"\"Strided 2-D convolution with explicit padding.\n\n  The padding is consistent and is based only on `kernel_size`, not on the\n  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).\n\n  Args:\n    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,\n      width].\n    filters: Int specifying the number of output filters of the convolution.\n    kernel_size: Int designating size of kernel to be used in the convolution.\n    
strides: Int specifying the stride. If stride >1, the input is downsampled.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'threshold'. 'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height, width] or \"channels_last\" for [batch, height, width,\n      channels].\n    end_sparsity: Desired sparsity at the end of training. Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n    init_scale: float, passed to the VarianceScalingInitializer.\n    name: String that specifies name for model layer.\n\n  Returns:\n    The output activation tensor of size [batch, filters, height_out, width_out]\n\n  Raises:\n    ValueError: If the data_format provided is not a valid string.\n  \"\"\"\n  if strides > 1:\n    inputs = fixed_padding(\n        inputs, kernel_size, data_format=data_format)\n  padding = 'SAME' if strides == 1 else 'VALID'\n\n  kernel_initializer = tf.variance_scaling_initializer(scale=init_scale)\n  kernel_initializer = _pick_initializer(kernel_initializer, init_method,\n                                         pruning_method, end_sparsity)\n\n  kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n  return sparse_conv2d(\n      x=inputs,\n      units=filters,\n      activation=None,\n      kernel_size=[kernel_size, kernel_size],\n      use_bias=False,\n      kernel_initializer=kernel_initializer,\n      kernel_regularizer=kernel_regularizer,\n      bias_initializer=None,\n      biases_regularizer=None,\n      
sparsity_technique=pruning_method,\n      normalizer_fn=None,\n      strides=[strides, strides],\n      padding=padding,\n      data_format=data_format,\n      name=name)\n\n\ndef residual_block_(inputs,\n                    filters,\n                    is_training,\n                    strides,\n                    use_projection=False,\n                    pruning_method='baseline',\n                    init_method='baseline',\n                    data_format='channels_first',\n                    end_sparsity=0.,\n                    weight_decay=0.,\n                    name=''):\n  \"\"\"Standard building block for residual networks with BN after convolutions.\n\n  Args:\n    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,\n      width].\n    filters: Int specifying number of filters for the first two convolutions.\n    is_training: Boolean specifying whether the model is training.\n    strides: Int specifying the stride. If stride >1, the input is downsampled.\n    use_projection: Boolean for whether the layer should use a projection\n      shortcut. Often, use_projection=True for the first block of a block group.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'threshold'. 'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height, width] or \"channels_last\" for [batch, height, width,\n      channels].\n    end_sparsity: Desired sparsity at the end of training. 
Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies name for model layer.\n\n  Returns:\n    The output activation tensor.\n  \"\"\"\n  shortcut = inputs\n  if use_projection:\n    # Projection shortcut in first layer to match filters and strides\n    end_point = 'residual_projection_%s' % name\n    shortcut = conv2d_fixed_padding(\n        inputs=inputs,\n        filters=filters,\n        kernel_size=1,\n        strides=strides,\n        pruning_method=pruning_method,\n        init_method=init_method,\n        data_format=data_format,\n        end_sparsity=end_sparsity,\n        weight_decay=weight_decay,\n        name=end_point)\n    shortcut = batch_norm_relu(\n        shortcut, is_training, relu=False, data_format=data_format)\n\n  end_point = 'residual_1_%s' % name\n  inputs = conv2d_fixed_padding(\n      inputs=inputs,\n      filters=filters,\n      kernel_size=3,\n      strides=strides,\n      pruning_method=pruning_method,\n      init_method=init_method,\n      data_format=data_format,\n      end_sparsity=end_sparsity,\n      weight_decay=weight_decay,\n      name=end_point)\n  inputs = batch_norm_relu(\n      inputs, is_training, data_format=data_format)\n\n  end_point = 'residual_2_%s' % name\n  inputs = conv2d_fixed_padding(\n      inputs=inputs,\n      filters=filters,\n      kernel_size=3,\n      strides=1,\n      pruning_method=pruning_method,\n      init_method=init_method,\n      data_format=data_format,\n      end_sparsity=end_sparsity,\n      weight_decay=weight_decay,\n      name=end_point)\n  inputs = batch_norm_relu(\n      inputs, is_training, relu=False, init_zero=True, data_format=data_format)\n\n  return tf.nn.relu(inputs + shortcut)\n\n\ndef bottleneck_block_(inputs,\n                      filters,\n                      is_training,\n                      strides,\n                      use_projection=False,\n                      
pruning_method='baseline',\n                      init_method='baseline',\n                      data_format='channels_first',\n                      end_sparsity=0.,\n                      weight_decay=0.,\n                      name=None):\n  \"\"\"Bottleneck block variant for residual networks with BN after convolutions.\n\n  Args:\n    inputs: Input tensor, float32 or bfloat16 of size [batch, channels, height,\n      width].\n    filters: Int specifying number of filters for the first two convolutions.\n    is_training: Boolean specifying whether the model is training.\n    strides: Int specifying the stride. If stride >1, the input is downsampled.\n    use_projection: Boolean for whether the layer should use a projection\n      shortcut. Often, use_projection=True for the first block of a block group.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'threshold'. 'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    data_format: String that specifies either \"channels_first\" for [batch,\n      channels, height, width] or \"channels_last\" for [batch, height, width,\n      channels].\n    end_sparsity: Desired sparsity at the end of training. Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies name for model layer.\n\n  Returns:\n    The output activation tensor.\n  \"\"\"\n  shortcut = inputs\n\n  if use_projection:\n    # Projection shortcut only in first block within a group. 
Bottleneck blocks\n    # end with 4 times the number of filters.\n    filters_out = 4 * filters\n    end_point = 'bottleneck_projection_%s' % name\n    shortcut = conv2d_fixed_padding(\n        inputs=inputs,\n        filters=filters_out,\n        kernel_size=1,\n        strides=strides,\n        pruning_method=pruning_method,\n        init_method=init_method,\n        data_format=data_format,\n        end_sparsity=end_sparsity,\n        weight_decay=weight_decay,\n        name=end_point)\n    shortcut = batch_norm_relu(\n        shortcut, is_training, relu=False, data_format=data_format)\n\n  end_point = 'bottleneck_1_%s' % name\n  inputs = conv2d_fixed_padding(\n      inputs=inputs,\n      filters=filters,\n      kernel_size=1,\n      strides=1,\n      pruning_method=pruning_method,\n      init_method=init_method,\n      data_format=data_format,\n      end_sparsity=end_sparsity,\n      weight_decay=weight_decay,\n      name=end_point)\n  inputs = batch_norm_relu(\n      inputs, is_training, data_format=data_format)\n\n  end_point = 'bottleneck_2_%s' % name\n  inputs = conv2d_fixed_padding(\n      inputs=inputs,\n      filters=filters,\n      kernel_size=3,\n      strides=strides,\n      pruning_method=pruning_method,\n      init_method=init_method,\n      data_format=data_format,\n      end_sparsity=end_sparsity,\n      weight_decay=weight_decay,\n      name=end_point)\n  inputs = batch_norm_relu(\n      inputs, is_training, data_format=data_format)\n\n  end_point = 'bottleneck_3_%s' % name\n  inputs = conv2d_fixed_padding(\n      inputs=inputs,\n      filters=4 * filters,\n      kernel_size=1,\n      strides=1,\n      pruning_method=pruning_method,\n      init_method=init_method,\n      data_format=data_format,\n      end_sparsity=end_sparsity,\n      weight_decay=weight_decay,\n      name=end_point)\n  inputs = batch_norm_relu(\n      inputs, is_training, relu=False, init_zero=True, data_format=data_format)\n\n  return tf.nn.relu(inputs + shortcut)\n\n\ndef 
block_group(inputs,\n                filters,\n                block_fn,\n                blocks,\n                strides,\n                is_training,\n                name,\n                pruning_method='baseline',\n                init_method='baseline',\n                data_format='channels_first',\n                end_sparsity=0.,\n                weight_decay=0.):\n  \"\"\"Creates one group of blocks for the ResNet model.\n\n  Args:\n    inputs: `Tensor` of size `[batch, channels, height, width]`.\n    filters: `int` number of filters for the first convolution of the layer.\n    block_fn: `function` for the block to use within the model.\n    blocks: `int` number of blocks contained in the layer.\n    strides: `int` stride to use for the first convolution of the layer. If\n      greater than 1, this layer will downsample the input.\n    is_training: `bool` for whether the model is training.\n    name: String specifying the Tensor output of the block layer.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'threshold'. 'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    data_format: `str` either \"channels_first\" for `[batch, channels, height,\n      width]` or \"channels_last\" for `[batch, height, width, channels]`.\n    end_sparsity: Desired sparsity at the end of training. 
Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n\n  Returns:\n    The output `Tensor` of the block layer.\n  \"\"\"\n  with tf.name_scope(name):\n    end_point = 'block_group_projection_%s' % name\n    # Only the first block per block_group uses projection shortcut and strides.\n    inputs = block_fn(\n        inputs,\n        filters,\n        is_training,\n        strides,\n        use_projection=True,\n        pruning_method=pruning_method,\n        init_method=init_method,\n        data_format=data_format,\n        end_sparsity=end_sparsity,\n        weight_decay=weight_decay,\n        name=end_point)\n\n    for n in range(1, blocks):\n      with tf.name_scope('block_group_%d' % n):\n        end_point = '%s_%d_1' % (name, n)\n        inputs = block_fn(\n            inputs,\n            filters,\n            is_training,\n            1,\n            pruning_method=pruning_method,\n            init_method=init_method,\n            data_format=data_format,\n            end_sparsity=end_sparsity,\n            weight_decay=weight_decay,\n            name=end_point)\n\n  return tf.identity(inputs, name)\n\n\ndef resnet_v1_generator(block_fn,\n                        num_blocks,\n                        num_classes,\n                        pruning_method='baseline',\n                        init_method='baseline',\n                        width=1.,\n                        prune_first_layer=True,\n                        prune_last_layer=True,\n                        data_format='channels_first',\n                        end_sparsity=0.,\n                        weight_decay=0.,\n                        name=None):\n  \"\"\"Generator for ResNet v1 models.\n\n  Args:\n    block_fn: `function` for the block to use within the model, either\n      `residual_block_` or `bottleneck_block_`.\n    num_blocks: list of Ints that denotes number of blocks to include in each\n      block group. 
Each group consists of blocks that take inputs of the same\n      resolution.\n    num_classes: Int number of possible classes for image classification.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'threshold'. 'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    width: Float that scales the number of filters in each layer.\n    prune_first_layer: Whether or not to prune the first layer.\n    prune_last_layer: Whether or not to prune the last layer.\n    data_format: String either \"channels_first\" for `[batch, channels, height,\n      width]` or \"channels_last\" for `[batch, height, width, channels]`.\n    end_sparsity: Desired sparsity at the end of training. 
Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies name for model layer.\n\n  Returns:\n    Model `function` that takes in `inputs` and `is_training` and returns the\n    output `Tensor` of the ResNet model.\n  \"\"\"\n\n  def model(inputs, is_training):\n    \"\"\"Creation of the model graph.\"\"\"\n    with tf.variable_scope(name, 'resnet_model'):\n      inputs = conv2d_fixed_padding(\n          inputs=inputs,\n          filters=int(64 * width),\n          kernel_size=7,\n          strides=2,\n          pruning_method=pruning_method if prune_first_layer else 'baseline',\n          init_method=init_method if prune_first_layer else 'baseline',\n          data_format=data_format,\n          end_sparsity=end_sparsity,\n          weight_decay=weight_decay,\n          name='initial_conv')\n\n      inputs = tf.identity(inputs, 'initial_conv')\n      inputs = batch_norm_relu(\n          inputs, is_training, data_format=data_format)\n\n      inputs = tf.layers.max_pooling2d(\n          inputs=inputs,\n          pool_size=3,\n          strides=2,\n          padding='SAME',\n          data_format=data_format,\n          name='initial_max_pool')\n      inputs = tf.identity(inputs, 'initial_max_pool')\n\n      inputs = block_group(\n          inputs=inputs,\n          filters=int(64 * width),\n          block_fn=block_fn,\n          blocks=num_blocks[0],\n          strides=1,\n          is_training=is_training,\n          name='block_group1',\n          pruning_method=pruning_method,\n          init_method=init_method,\n          data_format=data_format,\n          end_sparsity=end_sparsity,\n          weight_decay=weight_decay)\n      inputs = block_group(\n          inputs=inputs,\n          filters=int(128 * width),\n          block_fn=block_fn,\n          blocks=num_blocks[1],\n          strides=2,\n          is_training=is_training,\n          name='block_group2',\n          
pruning_method=pruning_method,\n          init_method=init_method,\n          data_format=data_format,\n          end_sparsity=end_sparsity,\n          weight_decay=weight_decay)\n      inputs = block_group(\n          inputs=inputs,\n          filters=int(256 * width),\n          block_fn=block_fn,\n          blocks=num_blocks[2],\n          strides=2,\n          is_training=is_training,\n          name='block_group3',\n          pruning_method=pruning_method,\n          init_method=init_method,\n          data_format=data_format,\n          end_sparsity=end_sparsity,\n          weight_decay=weight_decay)\n      inputs = block_group(\n          inputs=inputs,\n          filters=int(512 * width),\n          block_fn=block_fn,\n          blocks=num_blocks[3],\n          strides=2,\n          is_training=is_training,\n          name='block_group4',\n          pruning_method=pruning_method,\n          init_method=init_method,\n          data_format=data_format,\n          end_sparsity=end_sparsity,\n          weight_decay=weight_decay)\n\n      pool_size = (inputs.shape[1], inputs.shape[2])\n      inputs = tf.layers.average_pooling2d(\n          inputs=inputs,\n          pool_size=pool_size,\n          strides=1,\n          padding='VALID',\n          data_format=data_format,\n          name='final_avg_pool')\n      inputs = tf.identity(inputs, 'final_avg_pool')\n      multiplier = 4 if block_fn is bottleneck_block_ else 1\n      fc_units = multiplier * int(512 * width)\n      inputs = tf.reshape(inputs, [-1, fc_units])\n      kernel_initializer = tf.random_normal_initializer(stddev=.01)\n      # If init_method==sparse and not pruning, skip.\n      if init_method != 'sparse' or prune_last_layer:\n        kernel_initializer = _pick_initializer(kernel_initializer, init_method,\n                                               pruning_method, end_sparsity)\n      kernel_regularizer = contrib_layers.l2_regularizer(weight_decay)\n      inputs = sparse_fully_connected(\n      
    x=inputs,\n          units=num_classes,\n          sparsity_technique=pruning_method if prune_last_layer else 'baseline',\n          kernel_initializer=kernel_initializer,\n          kernel_regularizer=kernel_regularizer,\n          name='final_dense')\n\n      inputs = tf.identity(inputs, 'final_dense')\n    return inputs\n\n  model.default_image_size = 224\n  return model\n\n\ndef resnet_v1_(resnet_depth,\n               num_classes,\n               pruning_method='baseline',\n               init_method='baseline',\n               width=1.,\n               prune_first_layer=True,\n               prune_last_layer=True,\n               data_format='channels_first',\n               end_sparsity=0.,\n               weight_decay=0.,\n               name=None):\n  \"\"\"Returns the ResNet model for a given size and number of output classes.\n\n  Args:\n    resnet_depth: Int depth of the ResNet architecture; one of 18, 34, 50,\n      101, 152 or 200.\n    num_classes: Int number of possible classes for image classification.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    width: Float multiplier of the number of filters in each layer.\n    prune_first_layer: Whether or not to prune the first layer.\n    prune_last_layer: Whether or not to prune the last layer.\n    data_format: String specifying either \"channels_first\" for `[batch,\n      channels, height, width]` or \"channels_last\" for `[batch, height, width,\n      channels]`.\n    end_sparsity: Desired sparsity at the end of training. 
Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n    name: String that specifies the prefix for the scope.\n\n  Raises:\n    ValueError: If the resnet_depth int is not in the model_params dictionary.\n  \"\"\"\n  model_params = {\n      18: {\n          'block': residual_block_,\n          'layers': [2, 2, 2, 2]\n      },\n      34: {\n          'block': residual_block_,\n          'layers': [3, 4, 6, 3]\n      },\n      50: {\n          'block': bottleneck_block_,\n          'layers': [3, 4, 6, 3]\n      },\n      101: {\n          'block': bottleneck_block_,\n          'layers': [3, 4, 23, 3]\n      },\n      152: {\n          'block': bottleneck_block_,\n          'layers': [3, 8, 36, 3]\n      },\n      200: {\n          'block': bottleneck_block_,\n          'layers': [3, 24, 36, 3]\n      }\n  }\n\n  if resnet_depth not in model_params:\n    raise ValueError('Not a valid resnet_depth:', resnet_depth)\n\n  params = model_params[resnet_depth]\n  return resnet_v1_generator(\n      params['block'], params['layers'], num_classes, pruning_method,\n      init_method, width, prune_first_layer, prune_last_layer, data_format,\n      end_sparsity, weight_decay, name)\n"
  },
  {
    "path": "rigl/imagenet_resnet/train_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Tests for the data_helper input pipeline and the training process.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import flags\nimport absl.testing.parameterized as parameterized\n\nfrom rigl.imagenet_resnet.imagenet_train_eval import resnet_model_fn_w_pruning\nfrom rigl.imagenet_resnet.imagenet_train_eval import set_lr_schedule\nimport tensorflow.compat.v1 as tf  # tf\nfrom official.resnet import imagenet_input\nfrom tensorflow.contrib.tpu.python.tpu import tpu_config\nfrom tensorflow.contrib.tpu.python.tpu import tpu_estimator\n\nFLAGS = flags.FLAGS\n\n\nclass DataInputTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _retrieve_data(self, is_training, data_dir):\n\n    dataset = imagenet_input.ImageNetInput(\n        is_training=is_training,\n        data_dir=data_dir,\n        transpose_input=False,\n        num_parallel_calls=8,\n        use_bfloat16=False)\n\n    return dataset\n\n  @parameterized.parameters('snip', 'set', 'rigl', 'scratch')\n  def testTrainingPipeline(self, training_method):\n    output_directory = '/tmp/'\n\n    g = tf.Graph()\n    with g.as_default():\n\n      dataset = self._retrieve_data(is_training=False, data_dir=False)\n\n      FLAGS.transpose_input = False\n      FLAGS.use_tpu = False\n      FLAGS.mode = 'train'\n      
FLAGS.mask_init_method = 'random'\n      FLAGS.precision = 'float32'\n      FLAGS.train_steps = 1\n      FLAGS.train_batch_size = 1\n      FLAGS.eval_batch_size = 1\n      FLAGS.steps_per_eval = 1\n      FLAGS.model_architecture = 'resnet'\n\n      params = {}\n      params['output_dir'] = output_directory\n      params['training_method'] = training_method\n      params['use_tpu'] = False\n      set_lr_schedule()\n\n      run_config = tpu_config.RunConfig(\n          master=None,\n          model_dir=None,\n          save_checkpoints_steps=1,\n          tpu_config=tpu_config.TPUConfig(iterations_per_loop=1, num_shards=1))\n\n      classifier = tpu_estimator.TPUEstimator(\n          use_tpu=False,\n          model_fn=resnet_model_fn_w_pruning,\n          params=params,\n          config=run_config,\n          train_batch_size=1,\n          eval_batch_size=1)\n\n      classifier.train(input_fn=dataset.input_fn, max_steps=1)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "rigl/imagenet_resnet/utils.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Helper functions for formatting TPU summaries and loading checkpoints.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.compat.v2 import summary\n\nIMG_SUMMARY_PREFIX = '_img_'\n\n\ndef format_tensors(*dicts):\n  \"\"\"Format metrics to be callable as tf.summary scalars on TPUs.\n\n  Args:\n    *dicts: A set of metric dictionaries, containing metric name + value tensor.\n\n  Returns:\n    A single formatted dictionary that holds all tensors.\n\n  Raises:\n    ValueError: if any tensor is not a scalar.\n  \"\"\"\n  merged_summaries = {}\n  for d in dicts:\n    for metric_name, value in d.items():\n      shape = value.shape.as_list()\n      if metric_name.startswith(IMG_SUMMARY_PREFIX):\n        # If image, shape it into 2d.\n        merged_summaries[metric_name] = tf.reshape(value,\n                                                   (1, -1, value.shape[-1], 1))\n      elif not shape:\n        merged_summaries[metric_name] = tf.expand_dims(value, axis=0)\n      elif shape == [1]:\n        merged_summaries[metric_name] = value\n      else:\n        raise ValueError(\n            'Metric {} has value {} that is not reconcilable'.format(\n                metric_name, value))\n  return merged_summaries\n\n\ndef 
host_call_fn(model_dir, **kwargs):\n  \"\"\"host_call function used for creating training summaries when using TPU.\n\n  Args:\n    model_dir: String indicating the output_dir to save summaries in.\n    **kwargs: Set of metric names and tensor values for all desired summaries.\n\n  Returns:\n    Summary op to be passed to the host_call arg of the estimator function.\n  \"\"\"\n  gs = kwargs.pop('global_step')[0]\n  with summary.create_file_writer(model_dir).as_default():\n    # Always record summaries.\n    with summary.record_if(True):\n      for name, tensor in kwargs.items():\n        if name.startswith(IMG_SUMMARY_PREFIX):\n          summary.image(name.replace(IMG_SUMMARY_PREFIX, ''), tensor,\n                        step=gs, max_outputs=1)\n        else:\n          summary.scalar(name, tensor[0], step=gs)\n      # The following function is only available under tf1.x, so we use it.\n      return tf.summary.all_v2_summary_ops()\n\n\ndef mask_summaries(masks, with_img=False):\n  \"\"\"Creates scalar (and optionally image) summaries for the given masks.\"\"\"\n  metrics = {}\n  for mask in masks:\n    metrics['pruning/{}/sparsity'.format(\n        mask.op.name)] = tf.nn.zero_fraction(mask)\n    if with_img:\n      metrics[IMG_SUMMARY_PREFIX + 'mask/' + mask.op.name] = mask\n  return metrics\n\n\ndef initialize_parameters_from_ckpt(ckpt_path, model_dir, param_suffixes):\n  \"\"\"Load parameters from an existing checkpoint.\n\n  Args:\n    ckpt_path: str, loads the mask variables from this checkpoint.\n    model_dir: str, if a checkpoint already exists in this folder, no-op.\n    param_suffixes: tuple or str, suffixes of parameters to be loaded from\n      the checkpoint.\n  \"\"\"\n  already_has_ckpt = model_dir and tf.train.latest_checkpoint(\n      model_dir) is not None\n  if already_has_ckpt:\n    tf.logging.info(\n        'Training already started on this model, not loading masks from '\n        'previously trained model.')\n    return\n\n  reader = tf.train.NewCheckpointReader(ckpt_path)\n  param_names = reader.get_variable_to_shape_map().keys()\n  param_names = [x for x in param_names if 
x.endswith(param_suffixes)]\n\n  variable_map = {}\n  for var in tf.global_variables():\n    var_name = var.name.split(':')[0]\n    if var_name in param_names:\n      tf.logging.info('Loading parameter variable from checkpoint: %s',\n                      var_name)\n      variable_map[var_name] = var\n    elif var_name.endswith(param_suffixes):\n      tf.logging.info(\n          'Cannot find parameter variable in checkpoint, skipping: %s',\n          var_name)\n  tf.train.init_from_checkpoint(ckpt_path, variable_map)\n"
  },
  {
    "path": "rigl/imagenet_resnet/vgg.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n# http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Contains model definitions for versions of the Oxford VGG network.\n\nThese model definitions were introduced in the following technical report:\n\n  Very Deep Convolutional Networks For Large-Scale Image Recognition\n  Karen Simonyan and Andrew Zisserman\n  arXiv technical report, 2015\n  PDF: http://arxiv.org/pdf/1409.1556.pdf\n  ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf\n  CC-BY-4.0\n\nMore information can be obtained from the VGG website:\nwww.robots.ox.ac.uk/~vgg/research/very_deep/\n\nUsage:\n  with arg_scope(vgg.vgg_arg_scope()):\n    outputs, end_points = vgg.vgg_net(inputs,scope='vgg_19')\n\"\"\"\nfrom __future__ 
import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nimport functools\n\nfrom rigl.imagenet_resnet import resnet_model\nimport tensorflow.compat.v1 as tf\nfrom tensorflow.contrib import layers\n\nnetwork_cfg = {\n    'vgg_a': [1, 1, 2, 2, 2],\n    'vgg_16': [2, 2, 3, 3, 3],\n    'vgg_19': [2, 2, 4, 4, 4],\n}\n\n\ndef vgg_net(inputs,\n            num_classes=1000,\n            spatial_squeeze=True,\n            name='vgg_a',\n            global_pool=True,\n            pruning_method='baseline',\n            init_method='baseline',\n            data_format='channels_last',\n            width=1.,\n            prune_last_layer=True,\n            end_sparsity=0.,\n            weight_decay=0.):\n  \"\"\"Oxford Net VGG.\n\n  Note: All the fully_connected layers have been transformed to conv2d layers.\n        To use in classification mode, resize input to 224x224.\n\n  Args:\n    inputs: a tensor of size [batch_size, height, width, channels].\n    num_classes: number of predicted classes. If 0 or None, the logits layer is\n      omitted and the input features to the logits layer are returned instead.\n    spatial_squeeze: whether or not to squeeze the spatial dimensions of the\n      outputs. Useful to remove unnecessary dimensions for classification.\n    name: Optional scope for the variables.\n    global_pool: Optional boolean flag. If True, the input to the classification\n      layer is avgpooled to size 1x1, for any input size. (This is not part\n      of the original VGG architecture.)\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'scratch'. 
'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    data_format: String specifying either \"channels_first\" for `[batch,\n      channels, height, width]` or \"channels_last\" for `[batch, height, width,\n      channels]`.\n    width: Float multiplier of the number of filters in each layer.\n    prune_last_layer: Whether or not to prune the last layer.\n    end_sparsity: Desired sparsity at the end of training. Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n\n  Returns:\n    net: the output of the logits layer (if num_classes is a non-zero integer),\n      or the non-dropped-out input to the logits layer (if num_classes is 0 or\n      None).\n  \"\"\"\n  net_cfg = network_cfg[name]\n  sparse_conv2d = functools.partial(\n      resnet_model.conv2d_fixed_padding,\n      pruning_method=pruning_method,\n      init_method=init_method,\n      data_format=data_format,\n      init_scale=2.0,  # He init\n      end_sparsity=end_sparsity,\n      weight_decay=weight_decay)\n\n  def new_sparse_conv2d(*args, **kwargs):\n    kwargs['name'] = kwargs['scope']\n    del kwargs['scope']\n    activation_fn = 'relu'\n    if 'activation_fn' in kwargs:\n      activation_fn = kwargs['activation_fn']\n      del kwargs['activation_fn']\n    out = sparse_conv2d(*args, **kwargs)\n    if activation_fn == 'relu':\n      out = tf.nn.relu(out)\n    return out\n\n  with tf.variable_scope(name, name, values=[inputs]):\n    net = layers.repeat(\n        inputs,\n        net_cfg[0],\n        new_sparse_conv2d,\n        int(64 * width),\n        3,\n        strides=1,\n        scope='conv1')\n    net = layers.max_pool2d(net, [2, 2], scope='pool1')\n    net = layers.repeat(\n        net,\n        net_cfg[1],\n        
new_sparse_conv2d,\n        int(128 * width),\n        3,\n        strides=1,\n        scope='conv2')\n    net = layers.max_pool2d(net, [2, 2], scope='pool2')\n    net = layers.repeat(\n        net,\n        net_cfg[2],\n        new_sparse_conv2d,\n        int(256 * width),\n        3,\n        strides=1,\n        scope='conv3')\n    net = layers.max_pool2d(net, [2, 2], scope='pool3')\n    net = layers.repeat(\n        net,\n        net_cfg[3],\n        new_sparse_conv2d,\n        int(512 * width),\n        3,\n        strides=1,\n        scope='conv4')\n    net = layers.max_pool2d(net, [2, 2], scope='pool4')\n    net = layers.repeat(\n        net,\n        net_cfg[4],\n        new_sparse_conv2d,\n        int(512 * width),\n        3,\n        strides=1,\n        scope='conv5')\n\n    # # Use conv2d instead of fully_connected layers.\n    # net = new_sparse_conv2d(net, 512, [7, 7], strides=1, scope='fc6')\n    # # net = layers.dropout(net, dropout_keep_prob, is_training=is_training,\n    # #                      scope='dropout6')\n    # net = new_sparse_conv2d(net, 512, [1, 1], strides=1, scope='fc7')\n    if global_pool:\n      net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')\n    if num_classes:\n      # net = layers.dropout(net, dropout_keep_prob, is_training=is_training,\n      #                      scope='dropout7')\n      if prune_last_layer:\n        net = new_sparse_conv2d(\n            net, num_classes, 1, activation_fn=None, strides=1, scope='fc8')\n      else:\n        net = layers.conv2d(\n            net, num_classes, [1, 1], activation_fn=None, scope='fc8')\n    if spatial_squeeze:\n      net = tf.squeeze(net, [1, 2], name='fc8/squeezed')\n    return net\n\n\ndef vgg(vgg_type,\n        num_classes,\n        pruning_method='baseline',\n        init_method='baseline',\n        width=1.,\n        prune_last_layer=True,\n        data_format='channels_last',\n        end_sparsity=0.,\n        weight_decay=0.):\n  \"\"\"Returns the 
VGG model for a given size and number of output classes.\n\n  Args:\n    vgg_type: String name of the architecture; one of 'vgg_a', 'vgg_16' or\n      'vgg_19'.\n    num_classes: Int number of possible classes for image classification.\n    pruning_method: String that specifies the pruning method used to identify\n      which weights to remove.\n    init_method: ('baseline', 'sparse', 'random_zeros') Whether to use standard\n      initialization or initialization that takes into account the existing\n      sparsity of the layer. 'sparse' only makes sense when combined with\n      pruning_method == 'scratch'. 'random_zeros' sets random weights to zero\n      using the end_sparsity parameter and is used with the 'baseline' method.\n    width: Float multiplier of the number of filters in each layer.\n    prune_last_layer: Whether or not to prune the last layer.\n    data_format: String specifying either \"channels_first\" for `[batch,\n      channels, height, width]` or \"channels_last\" for `[batch, height, width,\n      channels]`.\n    end_sparsity: Desired sparsity at the end of training. Necessary to\n      initialize an already sparse network.\n    weight_decay: Weight for the l2 regularization loss.\n\n  Returns:\n    Model `function` that takes in `inputs` and `is_training` and returns the\n    output `Tensor` of the VGG model.\n\n  Raises:\n    KeyError: If vgg_type is not in the network_cfg dictionary.\n  \"\"\"\n\n  def model_fn(inputs, is_training):\n    del is_training\n    return vgg_net(\n        inputs,\n        num_classes,\n        name=vgg_type,\n        pruning_method=pruning_method,\n        init_method=init_method,\n        data_format=data_format,\n        width=width,\n        prune_last_layer=prune_last_layer,\n        end_sparsity=end_sparsity,\n        weight_decay=weight_decay)\n\n  return model_fn\n"
  },
  {
    "path": "rigl/mnist/mnist_train_eval.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"A configurable, multi-layer fully connected network trained on MNIST.\n\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport time\nfrom absl import flags\n\nimport numpy as np\nfrom rigl import sparse_optimizers\nfrom rigl import sparse_utils\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.contrib import layers as contrib_layers\nfrom tensorflow.contrib.model_pruning.python import pruning\nfrom tensorflow.contrib.model_pruning.python.layers import layers\nfrom tensorflow.examples.tutorials.mnist import input_data\n\n\nflags.DEFINE_string('mnist', '/tmp/data', 'Location of the MNIST ' 'dataset.')\n\n## optimizer hyperparameters\nflags.DEFINE_integer('batch_size', 100, 'The number of samples in each batch')\nflags.DEFINE_float('learning_rate', .2, 'Initial learning rate.')\nflags.DEFINE_float('momentum', .9, 'Momentum.')\nflags.DEFINE_boolean('use_nesterov', True, 'Use nesterov momentum.')\nflags.DEFINE_integer('num_epochs', 200, 'Number of epochs to run.')\nflags.DEFINE_integer('lr_drop_epoch', 75, 'The epoch to start dropping lr.')\nflags.DEFINE_string('optimizer', 'momentum',\n                    'Optimizer to use. 
One of: sgd, momentum or adam.')\nflags.DEFINE_float('l2_scale', 1e-4, 'l2 loss scale')\nflags.DEFINE_string('network_type', 'fc',\n                    'Type of the network. See below for available options.')\nflags.DEFINE_enum(\n    'training_method', 'baseline',\n    ('scratch', 'set', 'baseline', 'momentum', 'rigl', 'static', 'snip',\n     'prune'),\n    'Method used for training sparse network. `scratch` means initial mask is '\n    'kept during training. `set` is for sparse evolutionary training and '\n    '`baseline` is for dense baseline.')\nflags.DEFINE_float('drop_fraction', 0.3,\n                   'When changing the mask dynamically, this fraction of the '\n                   'connections is dropped and regrown.')\nflags.DEFINE_string('drop_fraction_anneal', 'cosine',\n                    'If not empty the drop fraction is annealed during sparse'\n                    ' training. One of the following: `constant`, `cosine` or '\n                    '`exponential_(\\\\d*\\\\.?\\\\d*)$`. For example: '\n                    '`exponential_3`, `exponential_.3`, `exponential_0.3`. '\n                    'The number after `exponential` defines the exponent.')\nflags.DEFINE_string('grow_init', 'zeros',\n                    'Passed to the SparseInitializer, one of: zeros, '\n                    'initial_value, random_normal, random_uniform.')\nflags.DEFINE_float('s_momentum', 0.9,\n                   'Momentum value for the exponential moving average of '\n                   'gradients. Used when training_method=\"momentum\".')\nflags.DEFINE_string(\n    'input_mask_path', '',\n    'If given, uses the first mask of the checkpoint to mask '\n    'the input. 
If all the outgoing connections are masked '\n    'in the mask, we mask that dimension of the input.')\nflags.DEFINE_float('sparsity_scale', 0.9, 'Relative sparsity of second layer.')\nflags.DEFINE_float('rigl_acc_scale', 0.,\n                   'Used to scale initial accumulated gradients for new '\n                   'connections.')\nflags.DEFINE_integer('maskupdate_begin_step', 0, 'Step to begin mask updates.')\nflags.DEFINE_integer('maskupdate_end_step', 50000, 'Step to end mask updates.')\nflags.DEFINE_integer('maskupdate_frequency', 100,\n                     'Step interval between mask updates.')\nflags.DEFINE_integer('mask_record_frequency', 0,\n                     'Step interval between mask logging.')\nflags.DEFINE_string(\n    'mask_init_method',\n    default='random',\n    help='If not empty string and mask is not loaded from a checkpoint, '\n    'indicates the method used for mask initialization. One of the following: '\n    '`random`, `erdos_renyi`.')\nflags.DEFINE_integer('prune_begin_step', 2000, 'step to begin pruning')\nflags.DEFINE_integer('prune_end_step', 30000, 'step to end pruning')\nflags.DEFINE_float('end_sparsity', .98, 'desired sparsity of final model.')\nflags.DEFINE_integer('pruning_frequency', 500, 'how often to prune.')\nflags.DEFINE_float('threshold_decay', 0, 'threshold_decay for pruning.')\nflags.DEFINE_string('save_path', '', 'Where to save the model.')\nflags.DEFINE_boolean('save_model', True, 'Whether to save model or not.')\nflags.DEFINE_integer('seed', default=0, help=('Sets the random seed.'))\n\nFLAGS = flags.FLAGS\n\n\n# momentum = 0.9\n# lr = 0.2\n# batch = 100\n# decay = 1e-4\ndef mnist_network_fc(input_batch, reuse=False, model_pruning=False):\n  \"\"\"Define a basic FC network.\"\"\"\n  regularizer = contrib_layers.l2_regularizer(scale=FLAGS.l2_scale)\n  if model_pruning:\n    y = layers.masked_fully_connected(\n        inputs=input_batch[0],\n        num_outputs=300,\n        activation_fn=tf.nn.relu,\n        
weights_regularizer=regularizer,\n        reuse=reuse,\n        scope='layer1')\n    y1 = layers.masked_fully_connected(\n        inputs=y,\n        num_outputs=100,\n        activation_fn=tf.nn.relu,\n        weights_regularizer=regularizer,\n        reuse=reuse,\n        scope='layer2')\n    logits = layers.masked_fully_connected(\n        inputs=y1, num_outputs=10, reuse=reuse, activation_fn=None,\n        weights_regularizer=regularizer, scope='layer3')\n  else:\n    y = tf.layers.dense(\n        inputs=input_batch[0],\n        units=300,\n        activation=tf.nn.relu,\n        kernel_regularizer=regularizer,\n        reuse=reuse,\n        name='layer1')\n    y1 = tf.layers.dense(\n        inputs=y,\n        units=100,\n        activation=tf.nn.relu,\n        kernel_regularizer=regularizer,\n        reuse=reuse,\n        name='layer2')\n    logits = tf.layers.dense(inputs=y1, units=10, reuse=reuse,\n                             kernel_regularizer=regularizer, name='layer3')\n\n  cross_entropy = tf.losses.sparse_softmax_cross_entropy(\n      labels=input_batch[1], logits=logits)\n\n  cross_entropy += tf.losses.get_regularization_loss()\n\n  predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)\n  accuracy = tf.reduce_mean(\n      tf.cast(tf.equal(input_batch[1], predictions), tf.float32))\n\n  return cross_entropy, accuracy\n\n\ndef get_compressed_fc(masks):\n  \"\"\"Given the masks of a sparse network, returns the compact network.\"\"\"\n  # Drop input pixels that have no outgoing connections.\n  inds = np.sum(masks[0], axis=1) != 0\n  masks[0] = masks[0][inds]\n  compressed_masks = []\n  for i in range(len(masks)):\n    w = masks[i]\n    # Find neurons that don't have any incoming edges.\n    do_w = np.sum(w, axis=0) != 0\n    if i < (len(masks) - 1):\n      # Find neurons that don't have any outgoing edges.\n      di_wnext = np.sum(masks[i+1], axis=1) != 0\n      # Kept neurons should have at least one incoming and one outgoing edge.\n      do_w = np.logical_and(do_w, 
di_wnext)\n    compressed_w = w[:, do_w]\n    compressed_masks.append(compressed_w)\n    if i < (len(masks) - 1):\n      # Remove incoming edges from removed neurons.\n      masks[i+1] = masks[i+1][do_w]\n  sparsities = [np.sum(m == 0) / float(np.size(m)) for m in compressed_masks]\n  sizes = [compressed_masks[0].shape[0]]\n  for m in compressed_masks:\n    sizes.append(m.shape[1])\n  return sparsities, sizes\n\n\ndef main(unused_args):\n  tf.set_random_seed(FLAGS.seed)\n  tf.get_variable_scope().set_use_resource(True)\n  np.random.seed(FLAGS.seed)\n\n  # Load the MNIST data and set up an iterator.\n  mnist_data = input_data.read_data_sets(\n      FLAGS.mnist, one_hot=False, validation_size=0)\n  train_images = mnist_data.train.images\n  test_images = mnist_data.test.images\n  if FLAGS.input_mask_path:\n    reader = tf.train.load_checkpoint(FLAGS.input_mask_path)\n    input_mask = reader.get_tensor('layer1/mask')\n    indices = np.sum(input_mask, axis=1) != 0\n    train_images = train_images[:, indices]\n    test_images = test_images[:, indices]\n  dataset = tf.data.Dataset.from_tensor_slices(\n      (train_images, mnist_data.train.labels.astype(np.int32)))\n  num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size\n  dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])\n  batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)\n  iterator = batched_dataset.make_one_shot_iterator()\n\n  test_dataset = tf.data.Dataset.from_tensor_slices(\n      (test_images, mnist_data.test.labels.astype(np.int32)))\n  num_test_images = mnist_data.test.images.shape[0]\n  test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)\n  test_iterator = test_dataset.make_one_shot_iterator()\n\n  # Set up loss function.\n  use_model_pruning = FLAGS.training_method != 'baseline'\n\n  if FLAGS.network_type == 'fc':\n    cross_entropy_train, _ = mnist_network_fc(\n        iterator.get_next(), model_pruning=use_model_pruning)\n  
  cross_entropy_test, accuracy_test = mnist_network_fc(\n        test_iterator.get_next(), reuse=True, model_pruning=use_model_pruning)\n  else:\n    raise RuntimeError(FLAGS.network_type + ' is an unknown network type.')\n\n  # Remove duplicated entries: the current implementation adds each variable\n  # to these collections twice, so we keep only the first half.\n  # TODO test the following with the convnet or any other network.\n  if use_model_pruning:\n    for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):\n      collection = tf.get_collection_ref(k)\n      del collection[len(collection)//2:]\n      print(tf.get_collection_ref(k))\n\n  # Set up optimizer and update ops.\n  global_step = tf.train.get_or_create_global_step()\n  batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size\n\n  if FLAGS.optimizer != 'adam':\n    if not use_model_pruning:\n      boundaries = [int(round(s * batch_per_epoch)) for s in [60, 70, 80]]\n    else:\n      boundaries = [int(round(s * batch_per_epoch)) for s\n                    in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]]\n    learning_rate = tf.train.piecewise_constant(\n        global_step, boundaries,\n        values=[FLAGS.learning_rate / (3. 
** i)\n                for i in range(len(boundaries) + 1)])\n  else:\n    learning_rate = FLAGS.learning_rate\n\n  if FLAGS.optimizer == 'adam':\n    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)\n  elif FLAGS.optimizer == 'momentum':\n    opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum,\n                                     use_nesterov=FLAGS.use_nesterov)\n  elif FLAGS.optimizer == 'sgd':\n    opt = tf.train.GradientDescentOptimizer(learning_rate)\n  else:\n    raise RuntimeError(FLAGS.optimizer + ' is unknown optimizer type')\n  custom_sparsities = {\n      'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,\n      'layer3': FLAGS.end_sparsity * 0\n  }\n\n  if FLAGS.training_method == 'set':\n    # We override the train op to also update the mask.\n    opt = sparse_optimizers.SparseSETOptimizer(\n        opt, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal)\n  elif FLAGS.training_method == 'static':\n    # We override the train op to also update the mask.\n    opt = sparse_optimizers.SparseStaticOptimizer(\n        opt, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal)\n  elif FLAGS.training_method == 'momentum':\n    # We override the train op to also update the mask.\n    opt = sparse_optimizers.SparseMomentumOptimizer(\n        opt, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,\n        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,\n        grow_init=FLAGS.grow_init,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)\n  elif 
FLAGS.training_method == 'rigl':\n    # We override the train op to also update the mask.\n    opt = sparse_optimizers.SparseRigLOptimizer(\n        opt, begin_step=FLAGS.maskupdate_begin_step,\n        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,\n        frequency=FLAGS.maskupdate_frequency,\n        drop_fraction=FLAGS.drop_fraction,\n        drop_fraction_anneal=FLAGS.drop_fraction_anneal,\n        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)\n  elif FLAGS.training_method == 'snip':\n    opt = sparse_optimizers.SparseSnipOptimizer(\n        opt,\n        mask_init_method=FLAGS.mask_init_method,\n        default_sparsity=FLAGS.end_sparsity,\n        custom_sparsity_map=custom_sparsities,\n        use_tpu=False)\n  elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):\n    pass\n  else:\n    raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method)\n\n  train_op = opt.minimize(cross_entropy_train, global_step=global_step)\n\n\n  if FLAGS.training_method == 'prune':\n    hparams_string = ('begin_pruning_step={0},sparsity_function_begin_step={0},'\n                      'end_pruning_step={1},sparsity_function_end_step={1},'\n                      'target_sparsity={2},pruning_frequency={3},'\n                      'threshold_decay={4}'.format(\n                          FLAGS.prune_begin_step, FLAGS.prune_end_step,\n                          FLAGS.end_sparsity, FLAGS.pruning_frequency,\n                          FLAGS.threshold_decay))\n    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)\n    pruning_hparams.set_hparam('weight_sparsity_map',\n                               ['{0}:{1}'.format(k, v) for k, v\n                                in custom_sparsities.items()])\n    print(pruning_hparams)\n    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)\n    with tf.control_dependencies([train_op]):\n      train_op = pruning_obj.conditional_mask_update_op()\n  
weight_sparsity_levels = pruning.get_weight_sparsity()\n  global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())\n  tf.summary.scalar('test_accuracy', accuracy_test)\n  tf.summary.scalar('global_sparsity', global_sparsity)\n  for k, v in zip(pruning.get_masks(), weight_sparsity_levels):\n    tf.summary.scalar('sparsity/%s' % k.name, v)\n  if FLAGS.training_method in ('prune', 'snip', 'baseline'):\n    mask_init_op = tf.no_op()\n    tf.logging.info('No mask is set, starting dense.')\n  else:\n    all_masks = pruning.get_masks()\n    mask_init_op = sparse_utils.get_mask_init_fn(\n        all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity,\n        custom_sparsities)\n\n  if FLAGS.save_model:\n    saver = tf.train.Saver()\n  init_op = tf.global_variables_initializer()\n  hyper_params_string = '_'.join([FLAGS.network_type, str(FLAGS.batch_size),\n                                  str(FLAGS.learning_rate),\n                                  str(FLAGS.momentum),\n                                  FLAGS.optimizer,\n                                  str(FLAGS.l2_scale),\n                                  FLAGS.training_method,\n                                  str(FLAGS.prune_begin_step),\n                                  str(FLAGS.prune_end_step),\n                                  str(FLAGS.end_sparsity),\n                                  str(FLAGS.pruning_frequency),\n                                  str(FLAGS.seed)])\n  tf.io.gfile.makedirs(FLAGS.save_path)\n  filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')\n  merged_summary_op = tf.summary.merge_all()\n\n  # Run session.\n  if not use_model_pruning:\n    with tf.Session() as sess:\n      summary_writer = tf.summary.FileWriter(FLAGS.save_path,\n                                             graph=tf.get_default_graph())\n      print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')\n      sess.run([init_op])\n      tic = time.time()\n      with tf.io.gfile.GFile(filename, 
'w') as outputfile:\n        for i in range(FLAGS.num_epochs * num_batches):\n          sess.run([train_op])\n\n          if (i % num_batches) == (-1 % num_batches):\n            epoch_time = time.time() - tic\n            loss, accuracy, summary = sess.run([cross_entropy_test,\n                                                accuracy_test,\n                                                merged_summary_op])\n            # Write logs at every test iteration.\n            summary_writer.add_summary(summary, i)\n            log_str = '%d, %.4f, %.4f, %.4f' % (\n                i // num_batches, epoch_time, loss, accuracy)\n            print(log_str)\n            print(log_str, file=outputfile)\n            tic = time.time()\n      if FLAGS.save_model:\n        saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))\n  else:\n    with tf.Session() as sess:\n      summary_writer = tf.summary.FileWriter(FLAGS.save_path,\n                                             graph=tf.get_default_graph())\n      log_str = ','.join([\n          'Epoch', 'Iteration', 'Test loss', 'Test accuracy', 'G_Sparsity',\n          'Sparsity Layer 0', 'Sparsity Layer 1'\n      ])\n      sess.run(init_op)\n      sess.run(mask_init_op)\n      tic = time.time()\n      mask_records = {}\n      with tf.io.gfile.GFile(filename, 'w') as outputfile:\n        print(log_str)\n        print(log_str, file=outputfile)\n        for i in range(FLAGS.num_epochs * num_batches):\n          if (FLAGS.mask_record_frequency > 0 and\n              i % FLAGS.mask_record_frequency == 0):\n            mask_vals = sess.run(pruning.get_masks())\n            # Cast into bool to save space.\n            mask_records[i] = [a.astype(bool) for a in mask_vals]\n          sess.run([train_op])\n          weight_sparsity, global_sparsity_val = sess.run(\n              [weight_sparsity_levels, global_sparsity])\n          if (i % num_batches) == (-1 % num_batches):\n            epoch_time = time.time() - tic\n            
loss, accuracy, summary = sess.run([cross_entropy_test,\n                                                accuracy_test,\n                                                merged_summary_op])\n            # Write logs at every test iteration.\n            summary_writer.add_summary(summary, i)\n            log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (\n                i // num_batches, i, loss, accuracy, global_sparsity_val,\n                weight_sparsity[0], weight_sparsity[1])\n            print(log_str)\n            print(log_str, file=outputfile)\n            mask_vals = sess.run(pruning.get_masks())\n            if FLAGS.network_type == 'fc':\n              sparsities, sizes = get_compressed_fc(mask_vals)\n              print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities,\n                                                              sizes))\n              print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities,\n                                                              sizes),\n                    file=outputfile)\n            tic = time.time()\n      if FLAGS.save_model:\n        saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))\n      if mask_records:\n        np.save(os.path.join(FLAGS.save_path, 'mask_records'), mask_records)\n\n\nif __name__ == '__main__':\n  tf.app.run()\n"
  },
  {
    "path": "rigl/mnist/visualize_mask_records.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Visualizes the dumped masks using matplotlib.\n\nWe count the number of outgoing edges from the input dimensions. For the first\nlayer input dimensions correspond to the input pixels and we can visualize it\nnicely. You can control which layer is visualized by changing `layer_id` and\n`new_shape`. Default is the first layer and we visualize the number of outgoing\nconnections from individual pixels.\n\npython visualize_mask_records.py --records_path=/tmp/mnist/mask_records.npy\n\nTo save the results as gif:\npython visualize_mask_records.py --records_path=/path/to/mask_records.npy \\\n--save_path=/path/to/mask.gif\n\nModified from:\nhttps://eli.thegreenplace.net/2016/drawing-animated-gifs-with-matplotlib/\n\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import flags\n\nfrom matplotlib.animation import FuncAnimation\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nflags.DEFINE_string('records_path', '/tmp/mnist/mask_records.npy',\n                    'Path to load masks records.')\nflags.DEFINE_string('save_path', '', 'Path to save the animation.')\nflags.DEFINE_list('new_shape', '28,28', 'Path for reshaping the units.')\nflags.DEFINE_integer('interval', 100, 'Miliseconds between plot 
updates.')\nflags.DEFINE_integer('layer_id', 0, 'Index of the layer whose statistics are '\n                     'plotted during training.')\nflags.DEFINE_integer('skip_mask', 10, 'Number of checkpoints to skip for '\n                     'each frame.')\nflags.DEFINE_integer(\n    'slow_until', 50, 'Number of masks to show with slower '\n    'speed. After this number of frames, we start skipping '\n    'frames to make the video shorter.')\nFLAGS = flags.FLAGS\n\n\ndef main(unused_args):\n  fig, ax = plt.subplots()\n  fig.set_tight_layout(True)\n\n  # Query the figure's on-screen size and DPI. Note that when saving the figure\n  # to a file, we need to provide a DPI for that separately.\n  print('fig size: {0} DPI, size in inches {1}'.format(fig.get_dpi(),\n                                                       fig.get_size_inches()))\n\n  # Load the mask records and draw the initial image.\n  mask_records = np.load(FLAGS.records_path, allow_pickle=True).item()\n  sorted_keys = sorted(mask_records.keys())\n  new_shape = [int(a) for a in FLAGS.new_shape]\n  reshape_fn = lambda mask: np.reshape(np.sum(mask, axis=1), new_shape)\n  c_mask = mask_records[sorted_keys[0]][FLAGS.layer_id]\n  im = plt.imshow(reshape_fn(c_mask), interpolation='none', vmin=0, vmax=30)\n  fig.colorbar(im, ax=ax)\n\n  def update(i):\n    \"\"\"Updates the plot.\"\"\"\n    save_iter = sorted_keys[i]\n    label = 'timestep {0}'.format(save_iter)\n\n    print(label)\n    # Update the image and the axes (with a new xlabel). 
Return a tuple of\n    # \"artists\" that have to be redrawn for this frame.\n    c_data = reshape_fn(mask_records[save_iter][FLAGS.layer_id])\n    im.set_data(c_data)\n    ax.set_xlabel(label)\n    return [im, ax]\n\n  # FuncAnimation calls the 'update' function for each frame: the first\n  # `slow_until` masks are shown one by one; after that, every `skip_mask`-th\n  # mask is shown, with `interval` milliseconds between frames.\n  iteration = FLAGS.slow_until\n  frames = (\n      list(np.arange(0, iteration, 1)) +\n      list(np.arange(iteration, len(sorted_keys), FLAGS.skip_mask)))\n\n  anim = FuncAnimation(fig, update, frames=frames, interval=FLAGS.interval)\n  if FLAGS.save_path:\n    anim.save(FLAGS.save_path, dpi=80, writer='imagemagick')\n  else:\n    # plt.show() will just loop the animation forever.\n    plt.show()\n\n\nif __name__ == '__main__':\n  tf.app.run(main)\n"
  },
  {
    "path": "rigl/requirements.txt",
    "content": "absl-py>=0.6.0\ngin-config\nnumpy>=1.15.4\nsix>=1.12.0\ntensorflow>=1.12.0,<2.0  # change to 'tensorflow-gpu' for gpu support \ntensorflow-datasets==2.1\ntensorflow-model-optimization"
  },
  {
    "path": "rigl/rigl_tf2/README.md",
    "content": "# Gradient Flow in Sparse Neural Networks and How Lottery Tickets Win\n<img src=\"https://github.com/google-research/rigl/blob/master/imgs/lottery_init.jpg\" alt=\"Lottery Tickets explained\" width=\"80%\" align=\"middle\">\n**Paper**: [https://arxiv.org/abs/2010.03533](https://arxiv.org/abs/2010.03533)\n\nThis code includes a TF-2 implementation of RigL and some other popular sparse training methods along with pruning, scratch and lottery ticket experiments in a unified codebase.\n\n\nRun pruning experiments.\n\n```\npython train.py --gin_config=configs/prune.gin\n```\n\nRuns lottery training.\n\n```\nLottery experiments:\npython train.py logdir=/tmp/sparse_spectrum/lottery --seed=8 \\\n--gin_config=configs/lottery.gin\n```\n\nRuns scratch training.\n\n```\npython train.py --logdir=/tmp/sparse_spectrum/scratch --seed=8 \\\n--gin_config=configs/scratch.gin\n```\n\nFor assigning different gin flags use gin_bindings. i.e.\n\n```\n`--gin_bindings='network.weight_init_method=\"unit_scaled\"'\n--gin_bindings='unit_scaled_init.init_method=\"faninout_uniform\"'\n```\n\nCalculating eigenvalues of hessian. Use logdir to point different checkpoints.\n\n```\npython train.py --mode=hessian \\\n--gin_config=configs/hessian.gin\n```\n\nPoint `mlp_configs` to run MLP experiments.\n\n```\npython train.py  --gin_config=mlp_configs/prune.gin\n```\n\nRunning interpolation experiments is done as the following:\n\n```\npython interpolate.py --logdir=/tmp/sparse_spectrum/scratch \\\n--gin_config=configs/interpolate.gin \\\n--ckpt_start=/path_to_lottery_logdir/cp-11719.ckpt \\\n--ckpt_end=/path_to_prune_logdir/cp-11719.ckpt \\\n--operative_gin=/path_to_logdir/operative_config.gin \\\n--logdir=/path_to_prune_logdir/ltsolution2prune/\n```\n\n## a journey with train.py.\n\n1) check `main()`.\n\n-   Load preload_gin_config. This is useful for scratch experiments to use same\n    hyper_parameters as the pruning experiments. 
We can override these with\n    regular `gin_configs/bindings` flags.\n-   Load data and create the network. The network might load its values from a\n    checkpoint. These arguments are set through gin. See utils.get_network for\n    details.\n-   Then the code either trains the network (`mode=train_eval`) or calculates\n    the hessian (`mode=hessian`).\n\n2) Check `train_model()`.\n\n-   Create the optimizer and sample a validation set from the training set.\n    The validation set is a subset of the training set and is used to get\n    better estimates of certain metrics.\n-   Create the `mask_updater` object. If the returned value is `None`, the\n    masks are not updated.\n-   Perform pre-training updates to the network, e.g. meta_initialization.\n-   Set up checkpointing so that if a checkpoint exists, training continues\n    from where it left off.\n-   Define the gradient function. This function is used during training and for\n    certain other metrics. Note that we have to mask the gradients manually\n    since they are dense.\n-   Define the logging function for writing tensorboard event summaries.\n-   Main training loop: save, log, gradient step, mask update.\n"
  },
  {
    "path": "rigl/rigl_tf2/colabs/MnistProp.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"e5O1UdsY202_\"\n      },\n      \"source\": [\n        \"##### Copyright 2020 Google LLC.\\n\",\n        \"\\n\",\n        \"Licensed under the Apache License, Version 2.0 (the \\\"License\\\");\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"jUW1g2_jWmBk\"\n      },\n      \"source\": [\n        \"## Measuring Signal Properties of Various Initializations\\n\",\n        \"For a random signal x ~ normal(0, 1), and a neural network denoted with f(x)=y; ensuring std(y)=1 at initialization is a common goal for popular NN initialization schemes. Here we measure signal propagation for different sparse initializations.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"cellView\": \"form\",\n        \"id\": \"4rvDSX8FFYTI\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"#@title Imports and Definitions\\n\",\n        \"import numpy as np\\n\",\n        \"import os\\n\",\n        \"import tensorflow.compat.v2 as tf\\n\",\n        \"tf.enable_v2_behavior()\\n\",\n        \"\\n\",\n        \"import gin\\n\",\n        \"from rigl import sparse_utils\\n\",\n        \"from rigl.rigl_tf2 import init_utils\\n\",\n        \"from rigl.rigl_tf2 import utils\\n\",\n        \"from rigl.rigl_tf2 import train\\n\",\n        \"from rigl.rigl_tf2 import networks\\n\",\n        \"from rigl.rigl_tf2 import mask_updaters\\n\",\n        \"\\n\",\n        \"import functools\\n\",\n        \"\\n\",\n        \"pruning_params = utils.get_pruning_params(mode='constant', final_sparsity = 0., begin_step=int(1e10))\\n\",\n        \"INPUT_SHAPE = (28, 28, 3)\\n\",\n        \"class Lenet5(tf.keras.Model):\\n\",\n        \"\\n\",\n        \"  def __init__(self,\\n\",\n        \"               input_shape,\\n\",\n        \"               
num_classes,\\n\",\n        \"               activation: str,\\n\",\n        \"               hidden_sizes = (6, 16, 120, 84)):\\n\",\n        \"    super(Lenet5, self).__init__()\\n\",\n        \"    l = tf.keras.layers\\n\",\n        \"    kwargs = {'activation': activation}\\n\",\n        \"    filter_fn = lambda _: True\\n\",\n        \"    wrap_fn = functools.partial(utils.maybe_prune_layer, params=pruning_params, filter_fn=filter_fn)\\n\",\n        \"    self.conv1 =  wrap_fn(l.Conv2D(hidden_sizes[0], 5, input_shape=input_shape, **kwargs))\\n\",\n        \"    self.pool1 = l.MaxPool2D(pool_size=(2, 2))\\n\",\n        \"    self.conv2 =  wrap_fn(l.Conv2D(hidden_sizes[1], 5, input_shape=input_shape, **kwargs))\\n\",\n        \"    self.pool2 = l.MaxPool2D(pool_size=(2, 2))\\n\",\n        \"    self.flatten = l.Flatten()\\n\",\n        \"    self.dense1 = wrap_fn(l.Dense(hidden_sizes[2], **kwargs))\\n\",\n        \"    self.dense2 = wrap_fn(l.Dense(hidden_sizes[3], **kwargs))\\n\",\n        \"    self.dense3 = wrap_fn(l.Dense(num_classes, **kwargs))\\n\",\n        \"    self.build((1,)+input_shape)\\n\",\n        \"\\n\",\n        \"  def call(self, inputs):\\n\",\n        \"    x = inputs\\n\",\n        \"    results = {}\\n\",\n        \"    for l_name in ['conv1', 'pool1', 'conv2', 'pool2', 'flatten', 'dense1', 'dense2', 'dense3']:\\n\",\n        \"      x = getattr(self, l_name)(x)\\n\",\n        \"      results[l_name] = x \\n\",\n        \"    return results\\n\",\n        \"\\n\",\n        \"def get_mask_random_numpy(mask_shape, sparsity):\\n\",\n        \"  \\\"\\\"\\\"Creates a random sparse mask with deterministic sparsity.\\n\",\n        \"\\n\",\n        \"  Args:\\n\",\n        \"    mask_shape: list, used to obtain shape of the random mask.\\n\",\n        \"    sparsity: float, between 0 and 1.\\n\",\n        \"\\n\",\n        \"  Returns:\\n\",\n        \"    numpy.ndarray\\n\",\n        \"  \\\"\\\"\\\"\\n\",\n        \"  all_ones = 
np.abs(np.ones(mask_shape))\\n\",\n        \"  n_zeros = int(np.floor(sparsity * all_ones.size))\\n\",\n        \"  rand_vals = np.random.uniform(size=mask_shape, high=range(1,mask_shape[-1]+1))\\n\",\n        \"  randflat=rand_vals.flatten()\\n\",\n        \"  randflat.sort()\\n\",\n        \"  t = randflat[n_zeros]\\n\",\n        \"  all_ones[rand_vals\\u003c=t] = 0\\n\",\n        \"  return all_ones\\n\",\n        \"\\n\",\n        \"def create_convnet(sparsity=0, weight_init_method = None, scale=2, method='fanin_normal'):\\n\",\n        \"  model = Lenet5(INPUT_SHAPE, num_classes, 'relu')\\n\",\n        \"  if sparsity \\u003e 0:\\n\",\n        \"    all_masks = [layer.pruning_vars[0][1] for layer in model.layers if isinstance(layer, utils.PRUNING_WRAPPER)]\\n\",\n        \"    for mask in all_masks:\\n\",\n        \"      new_mask = tf.cast(get_mask_random_numpy(mask.shape, sparsity), dtype=mask.dtype)\\n\",\n        \"      mask.assign(new_mask)\\n\",\n        \"    if weight_init_method:\\n\",\n        \"      all_weights = [layer.pruning_vars[0][0] for layer in model.layers if isinstance(layer, utils.PRUNING_WRAPPER)]\\n\",\n        \"      for mask, param in zip(all_masks, all_weights):\\n\",\n        \"        if weight_init_method == 'unit':\\n\",\n        \"          new_init = init_utils.unit_scaled_init(mask, method=method, scale=scale)\\n\",\n        \"        elif weight_init_method == 'layer':\\n\",\n        \"          new_init = init_utils.layer_scaled_init(mask, method=method, scale=scale)\\n\",\n        \"        else:\\n\",\n        \"          raise ValueError\\n\",\n        \"        param.assign(new_init)\\n\",\n        \"  return model\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"fkZ_GNjyYYqZ\"\n      },\n      \"source\": [\n        \"Here we demonstrate how we can calculate the standard deviation of random noise at initialization for `layer-wise` scaled initialization of Liu et. 
al.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"NsmPRCuZnxDA\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"# Let's create a 95% sparse Lenet-5.\\n\",\n        \"model = create_convnet(sparsity=0.95, weight_init_method='layer', scale=2, method='fanin_normal')\\n\",\n        \"# Random input signal\\n\",\n        \"random_input = tf.random.normal((1000,) + INPUT_SHAPE)\\n\",\n        \"output_dict = model(random_input)\\n\",\n        \"all_stds = []\\n\",\n        \"for k in ['dense1', 'dense2', 'dense3']:\\n\",\n        \"  out_dim = output_dict[k].shape[-1]\\n\",\n        \"  stds = np.std(np.reshape(output_dict[k], (-1,out_dim)),axis=0)\\n\",\n        \"  all_stds.append(stds)\\n\",\n        \"print('Mean deviation per neuron', np.mean(np.concatenate(all_stds, axis=0)))\\n\",\n        \"print('Mean deviation per output neuron', np.mean(all_stds[-1]))\\n\",\n        \"print('Deviation at output', np.std(random_input))\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"l3ttY88rYovo\"\n      },\n      \"source\": [\n        \"Now we define the code above as a function and use it on a grid to plot signal propagation at different sparsities.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 1,\n      \"metadata\": {\n        \"executionInfo\": {\n          \"elapsed\": 320,\n          \"status\": \"ok\",\n          \"timestamp\": 1613388807790,\n          \"user\": {\n            \"displayName\": \"\",\n            \"photoUrl\": \"\",\n            \"userId\": \"\"\n          },\n          \"user_tz\": -180\n        },\n        \"id\": \"4rfMGKciOOHf\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"def propagate_signal(sparsity, init_method, batch_size=500):\\n\",\n        \"  model = create_convnet(sparsity=sparsity, 
weight_init_method=init_method)\\n\",\n        \"  random_input = tf.random.normal((batch_size,) + INPUT_SHAPE)\\n\",\n        \"  # print(np.mean(random_input), np.std(random_input))\\n\",\n        \"  output_dict = model(random_input)\\n\",\n        \"  out_std = np.std(output_dict['dense3'])\\n\",\n        \"  all_stds = []\\n\",\n        \"  for k in ['dense1', 'dense2', 'dense3']:\\n\",\n        \"    out_dim = output_dict[k].shape[-1]\\n\",\n        \"    stds = np.std(np.reshape(output_dict[k], (-1,out_dim)),axis=0)\\n\",\n        \"    all_stds.append(stds)\\n\",\n        \"  meanstd = np.mean(np.concatenate(all_stds, axis=0))\\n\",\n        \"  return meanstd, out_std\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"F1rNPLXk7Ins\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import itertools, collections\\n\",\n        \"import numpy as np\\n\",\n        \"all_results = collections.defaultdict(dict)\\n\",\n        \"\\n\",\n        \"N_EXP = 3\\n\",\n        \"for s in np.linspace(0.8,0.98,5):\\n\",\n        \"  print(s)\\n\",\n        \"  for  method, name in zip((None, 'unit', 'layer'), ('Masked Dense', 'Ours', 'Scaled-Init')):\\n\",\n        \"    all_results[name][s] = [propagate_signal(s, method) for _ in range(N_EXP)]\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"id\": \"Sbjc7LxpVGl0\"\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import matplotlib.pyplot as plt\\n\",\n        \"\\n\",\n        \"for k, v in all_results.items():\\n\",\n        \"  # if k == 'Masked Dense':\\n\",\n        \"  #   continue\\n\",\n        \"  x = sorted(v.keys())\\n\",\n        \"  y = [np.mean([vv[1] for vv in v[kk]])+1e-5 for kk in x]\\n\",\n        \"  plt.plot(x, y, label=k)\\n\",\n        \"plt.hlines(y=1, color='r', xmin=0, xmax=1)\\n\",\n        
\"plt.yscale('log')\\n\",\n        \"plt.title('std(output)')\\n\",\n        \"plt.legend()\\n\",\n        \"plt.show()\\n\",\n        \"\\n\",\n        \"for k, v in all_results.items():\\n\",\n        \"  # if k == 'Masked Dense':\\n\",\n        \"  #   continue\\n\",\n        \"  x = sorted(v.keys())\\n\",\n        \"  y = [np.mean([vv[0] for vv in v[kk]])+1e-5 for kk in x]\\n\",\n        \"  plt.plot(x, y, label=k)\\n\",\n        \"plt.yscale('log')\\n\",\n        \"plt.hlines(y=1, color='r', xmin=0, xmax=1)\\n\",\n        \"plt.title('mean(std_per_neuron)')\\n\",\n        \"plt.legend()\\n\",\n        \"plt.show()\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [],\n      \"last_runtime\": {\n        \"build_target\": \"//learning/deepmind/public/tools/ml_python:ml_notebook\",\n        \"kind\": \"private\"\n      },\n      \"name\": \"Mnist propagation init sparse .ipynb\",\n      \"provenance\": [\n        {\n          \"file_id\": \"126QJDydlS0V4tQ-KhiN6bSlCOisqLV-Z\",\n          \"timestamp\": 1612472405306\n        },\n        {\n          \"file_id\": \"137QdNeUdTGoAOEPKpPMC09keiwlu12Bh\",\n          \"timestamp\": 1601472560303\n        }\n      ]\n    },\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"name\": \"python3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/dense.gin",
    "content": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 500 steps.\ntraining.log_freq = 200\nnetwork.network_name = 'lenet5'\nnetwork.weight_decay = 0.0005\n# original_hidden_size/sqrt(20) -> 20 comes from 95% sparsity.\n# following lenet has 2399 params vs 2396 (95% sparse lenet5).\nlenet5.hidden_sizes = (6, 16, 120, 84)\nlenet5.use_batch_norm = False\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.05\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n# NON-DEFAULT\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/grasp.gin",
    "content": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\ntraining.gradient_regularization=0\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.1\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n\n# NON-DEFAULT\nnetwork.weight_decay = 0.0002\n# Disable GMP pruning.\npruning.mode = 'constant'\npruning.final_sparsity = 0.\n# Enable one shot pruning.\ntraining.oneshot_prune_fraction = 0.95\ntraining.val_batch_size = 5000\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n\n# Mask Updates\nmask_updater.update_alg = 'rigl_grasp' # Prune part of rigl_grasp corresponds to grasp.\nmask_updater.last_update_step=0  # Never updates.\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/hessian.gin",
    "content": "hessian.batch_size = 60000\nhessian.rows_at_once = 2\n# range(0,100,5) + range(100,2000,100) + range(2000,11719,500)\nhessian.ckpt_ids = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, 10500, 11000, 11500]\n# range(4000,11719,50)\n# For Rigl updates\n# hessian.ckpt_ids = [-499, -999, -1499, -1999, -2499, -2999, -3499, -3999, -4499, -4999, -5499, -5999, -6499, -6999, -7499, -7999, -8499, -8999, -9499, -9999, -10499, -10999, -11499, -500, -1000, -1500, -2000, -2500, -3000, -3500, -4000, -4500, -5000, -5500, -6000, -6500, -7000, -7500, -8000, -8500, -9000, -9500, -10000, -10500, -11000, -11500]\n# hessian.ckpt_ids = [-100, -99, -199, -200, -500, -499, -999, -1999, -1499, -1500, -1000, -2000]\nhessian.overwrite = True\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/interpolate.gin",
    "content": "interpolate.i_start = -0.20\ninterpolate.i_end = 1.20\ninterpolate.n_interpolation = 29\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/lottery.gin",
    "content": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.05\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_path = '/tmp/sparse_spectrum/ckpt-0'\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/prune.gin",
    "content": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\nnetwork.network_name = 'lenet5'\nnetwork.mask_init_path = None\nnetwork.weight_decay = 0.0005\nlenet5.use_batch_norm = False\nlenet5.hidden_sizes = (6, 16, 120, 84)\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.05\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\npruning.mode = 'prune'\npruning.initial_sparsity = 0.0\npruning.final_sparsity = 0.95\npruning.begin_step = 3000\npruning.end_step = 7000\npruning.frequency = 100\n\n\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/rigl.gin",
    "content": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\ntraining.gradient_regularization=0\n\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.05\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_method = None\nnetwork.weight_decay = 0.0005\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\nunit_scaled_init.method='fanin_normal'\n\n# Mask Updates\nmask_updater.update_alg = 'rigl'\nmask_updater.schedule_alg = 'lr'\nmask_updater.update_freq = 100\nmask_updater.init_drop_fraction = 0.3\nmask_updater.last_update_step=-1\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/scratch.gin",
    "content": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\ntraining.gradient_regularization=0\n\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.05\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_method = None\nnetwork.shuffle_mask = False\nnetwork.weight_decay = 0.0005\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/set.gin",
    "content": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\ntraining.gradient_regularization=0\n\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.05\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_method = None\nnetwork.weight_decay = 0.0005\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\nunit_scaled_init.method='fanin_normal'\n\n# Mask Updates\nmask_updater.update_alg = 'set'\nmask_updater.schedule_alg = 'lr'\nmask_updater.update_freq = 100\nmask_updater.init_drop_fraction = 0.3\nmask_updater.last_update_step=-1\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/small_dense.gin",
    "content": "training.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\nnetwork.network_name = 'lenet5'\nnetwork.weight_decay = 0.0005\n# original_hidden_size/sqrt(20) -> 20 comes from 95% sparsity.\n# following lenet has 2399 params vs 2396 (95% sparse lenet5).\nlenet5.hidden_sizes = (3, 3, 27, 20)\nlenet5.use_batch_norm = False\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.05\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n# NON-DEFAULT\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n"
  },
  {
    "path": "rigl/rigl_tf2/configs/snip.gin",
    "content": "training.use_metainit = False\ntraining.total_steps = 11719 # 6e4/128*25 epochs=11719\ntraining.batch_size = 128\ntraining.save_freq = 500  # Log every 5 steps.\ntraining.log_freq = 200\ntraining.gradient_regularization=0\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.1\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n\n# NON-DEFAULT\nnetwork.weight_decay = 0.0002\n# Disable GMP pruning.\npruning.mode = 'constant'\npruning.final_sparsity = 0.\n# Enable one shot pruning.\ntraining.oneshot_prune_fraction = 0.95\ntraining.val_batch_size = 5000\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n\n# Mask Updates\nmask_updater.update_alg = 'rigl_s' # Prune part of rigl_s corresponds to snip.\nmask_updater.last_update_step=0  # Never updates.\n"
  },
  {
    "path": "rigl/rigl_tf2/init_utils.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Implements initializations for sparse layers.\"\"\"\nimport math\nimport gin\nimport tensorflow as tf\n\n\n@gin.configurable(denylist=['mask'])\ndef unit_scaled_init(mask, method='fanavg_uniform', scale=1.0):\n  \"\"\"Scales the variance of each unit with correct fan_in.\"\"\"\n  mode, distribution = method.strip().split('_')\n  # Lets calculate all fan_ins.\n  if len(mask.shape) == 4:\n    mask_reduced2d = tf.reduce_sum(mask, axis=[0, 1])\n  elif len(mask.shape) == 2:\n    mask_reduced2d = mask\n  else:\n    raise ValueError(f'mask.shape: {mask.shape} must be 4 or 2 dimensional.')\n  fan_ins = tf.reduce_sum(mask_reduced2d, axis=-2)\n  fan_outs = tf.reduce_sum(mask_reduced2d, axis=-1)\n  non_zero_indices = tf.where(mask)  # shape=(NZ, N_dim)\n  # Lets sample each row with the correct fan_in.\n  new_vals = []\n  # Following iterates over each output channel.\n  for index in non_zero_indices:\n    # Get fan_in and out of neurons that the non_zero connection connects.\n    fan_in = fan_ins[index[-1]]\n    fan_out = fan_outs[index[-2]]\n    # Following code is modified from `tensorflow/python/ops/init_ops_v2.py`.\n    if mode == 'fanin':\n      current_scale = scale / max(1., fan_in)\n    elif mode == 'fanout':\n      current_scale = scale / max(1., fan_out)\n    elif mode == 'fanavg':\n      current_scale = scale / max(1., (fan_in + fan_out) / 
2.)\n    else:\n      raise ValueError(f'mode: {mode} must be one of fanin, fanout, fanavg.')\n    if distribution == 'normal':\n      stddev = math.sqrt(current_scale)\n      new_val = tf.random.normal((1,), 0.0, stddev, mask.dtype)\n    elif distribution == 'uniform':\n      limit = math.sqrt(3.0 * current_scale)\n      new_val = tf.random.uniform((1,), -limit, limit, mask.dtype)\n    else:\n      raise ValueError(f'distribution: {distribution} must be normal or uniform.')\n    new_vals.append(new_val)\n  new_vals = tf.concat(new_vals, axis=-1)\n  new_weights = tf.scatter_nd(\n      indices=non_zero_indices,\n      updates=new_vals,\n      shape=mask.shape)\n  return new_weights\n\n\n@gin.configurable(denylist=['mask'])\ndef layer_scaled_init(mask, method='fanavg_uniform', scale=1.0):\n  \"\"\"Scales the layer-wide init variance to compensate for sparsity.\"\"\"\n  mode, distribution = method.strip().split('_')\n  init_factory = tf.keras.initializers.VarianceScaling(\n      mode=mode.replace('fan', 'fan_'), scale=scale, distribution=distribution)\n  dense_init = init_factory(shape=mask.shape, dtype=mask.dtype)\n  fraction_nnz = tf.reduce_sum(mask) / tf.size(mask, out_type=mask.dtype)\n  new_weights = dense_init / tf.math.sqrt(fraction_nnz)\n  return new_weights\n\n\ndef unit_scaled_init_tf1(mask,\n                         method='fanavg_uniform',\n                         scale=1.0,\n                         dtype=tf.float32):\n  \"\"\"Scales the variance of each unit with correct fan_in.\"\"\"\n  mode, distribution = method.strip().split('_')\n  # Let's calculate all fan_ins.\n  if len(mask.shape) == 4:\n    mask_reduced2d = tf.reduce_sum(mask, axis=[0, 1])\n  elif len(mask.shape) == 2:\n    mask_reduced2d = mask\n  else:\n    raise ValueError(f'mask.shape: {mask.shape} must be 4 or 2 dimensional.')\n  fan_ins = tf.reduce_sum(mask_reduced2d, axis=-2)\n  fan_outs = tf.reduce_sum(mask_reduced2d, axis=-1)\n  non_zero_indices = tf.where(mask)  # shape=(NZ, N_dim)\n\n  # Let's sample each 
row with the correct fan_in.\n  def new_val_fn(index):\n    # Get fan_in and out of neurons that the non_zero connection connects.\n    fan_in = fan_ins[index[-1]]\n    fan_out = fan_outs[index[-2]]\n    # Following code is modified from `tensorflow/python/ops/init_ops_v2.py`.\n    if mode == 'fanin':\n      current_scale = scale / tf.math.maximum(1., fan_in)\n    elif mode == 'fanout':\n      current_scale = scale / tf.math.maximum(1., fan_out)\n    elif mode == 'fanavg':\n      current_scale = scale / tf.math.maximum(1., (fan_in + fan_out) / 2.)\n    else:\n      raise ValueError(f'mode: {mode} must be one of fanin, fanout, fanavg.')\n    if distribution == 'normal':\n      stddev = tf.math.sqrt(current_scale)\n      new_val = tf.random.normal((1,), 0.0, stddev, dtype)\n    elif distribution == 'uniform':\n      limit = tf.math.sqrt(3.0 * current_scale)\n      new_val = tf.random.uniform((1,), -limit, limit, dtype)\n    else:\n      raise ValueError(f'distribution: {distribution} must be normal or uniform.')\n    return new_val\n\n  # Following iterates over each output channel.\n  new_vals = tf.squeeze(tf.map_fn(new_val_fn, non_zero_indices, dtype=dtype))\n  new_weights = tf.scatter_nd(\n      indices=non_zero_indices, updates=new_vals, shape=mask.shape)\n  return new_weights\n"
  },
  {
    "path": "rigl/rigl_tf2/interpolate.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Script for interpolating between checkpoints.\n\"\"\"\n\nimport os\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\n\nimport gin\nimport numpy as np\nfrom rigl.rigl_tf2 import utils\nimport tensorflow.compat.v2 as tf\n\nfrom pyglib import timer\nFLAGS = flags.FLAGS\nflags.DEFINE_string('logdir', '/tmp/sparse_spectrum/interpolation',\n                    'Directory to save experiment in.')\nflags.DEFINE_string('ckpt_start', '/tmp/sparse_spectrum/cp-0001.ckpt',\n                    'Directory to save experiment in.')\nflags.DEFINE_string('ckpt_end', '/tmp/sparse_spectrum/cp-0041.ckpt',\n                    'Directory to save experiment in.')\nflags.DEFINE_string(\n    'preload_gin_config', '', 'If non-empty reads a gin file '\n    'before parsing gin_config and bindings. This is useful,'\n    'when you want to start from a configuration of another '\n    'run. Values are then overwritten by additional configs '\n    'and bindings provided.')\nflags.DEFINE_bool('use_tpu', True, 'Whether to run on TPU or not.')\nflags.DEFINE_bool('eval_on_train', True, 'Whether to evaluate on training set.')\nflags.DEFINE_integer('load_mask_from', 0, '0 means start checkpoint, 1 means '\n                     'end checkpoint. 
-1 means no mask loaded.')\nflags.DEFINE_enum('mode', 'train_eval', ('train_eval', 'hessian'),\n                  'Whether to run training/evaluation or hessian computation.')\nflags.DEFINE_string(\n    'tpu_job_name', 'tpu_worker',\n    'Name of the TPU worker job. This is required when having '\n    'multiple TPU worker jobs.')\nflags.DEFINE_string('master', None, 'TPU worker.')\nflags.DEFINE_multi_string('gin_config', [],\n                          'List of paths to the config files.')\nflags.DEFINE_multi_string('gin_bindings', [],\n                          'Newline separated list of Gin parameter bindings.')\n\n\ndef test_model(model, d_test, batch_size=1000):\n  \"\"\"Tests the model and calculates cross entropy loss and accuracy.\"\"\"\n  test_loss = tf.keras.metrics.Mean(name='test_loss')\n  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n      name='test_accuracy')\n  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n  for x, y in d_test.batch(batch_size):\n    predictions = model(x, training=False)\n    batch_loss = loss_object(y, predictions)\n    test_loss.update_state(batch_loss)\n    test_accuracy.update_state(y, predictions)\n  logging.info('Test loss: %f', test_loss.result().numpy())\n  logging.info('Test accuracy: %f', test_accuracy.result().numpy())\n  return test_loss.result().numpy(), test_accuracy.result().numpy()\n\n\n@gin.configurable(\n    'interpolate',\n    denylist=['model_start', 'model_end', 'model_inter', 'd_set'])\ndef interpolate(model_start, model_end, model_inter, d_set,\n                i_start=-0.2, i_end=1.2, n_interpolation=29):\n  \"\"\"Interpolates between 2 sparse networks linearly and evaluates.\"\"\"\n  interpolation_coefs = np.linspace(i_start, i_end, n_interpolation)\n  all_scores = {}\n  for i_coef in interpolation_coefs:\n    logging.info('Interpolating with: %f', i_coef)\n    for var_start, var_end, var_inter in zip(model_start.trainable_variables,\n                                             
model_end.trainable_variables,\n                                             model_inter.trainable_variables):\n      new_value = (1 - i_coef) * var_start + i_coef * var_end\n      var_inter.assign(new_value)\n    scores = test_model(model_inter, d_set)\n    all_scores[i_coef] = scores\n  return all_scores\n\n\ndef main(unused_argv):\n  init_timer = timer.Timer()\n  init_timer.Start()\n  if FLAGS.preload_gin_config:\n    # Load default values from the original experiment, always the first one.\n    with gin.unlock_config():\n      gin.parse_config_file(FLAGS.preload_gin_config, skip_unknown=True)\n    logging.info('Operative Gin configurations loaded from: %s',\n                 FLAGS.preload_gin_config)\n  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)\n\n  data_train, data_test, info = utils.get_dataset()\n  input_shape = info.features['image'].shape\n  num_classes = info.features['label'].num_classes\n  logging.info('Input Shape: %s', input_shape)\n  logging.info('train samples: %s', info.splits['train'].num_examples)\n  logging.info('test samples: %s', info.splits['test'].num_examples)\n  data_eval = data_train if FLAGS.eval_on_train else data_test\n  pruning_params = utils.get_pruning_params(mode='constant')\n  mask_load_dict = {-1: None, 0: FLAGS.ckpt_start, 1: FLAGS.ckpt_end}\n  mask_path = mask_load_dict[FLAGS.load_mask_from]\n  # Currently we interpolate only on the same sparse space.\n  model_start = utils.get_network(\n      pruning_params,\n      input_shape,\n      num_classes,\n      mask_init_path=mask_path,\n      weight_init_path=FLAGS.ckpt_start)\n  model_start.summary()\n  model_end = utils.get_network(\n      pruning_params,\n      input_shape,\n      num_classes,\n      mask_init_path=mask_path,\n      weight_init_path=FLAGS.ckpt_end)\n  model_end.summary()\n\n  # Create a third network for interpolation.\n  model_inter = utils.get_network(\n      pruning_params,\n      input_shape,\n      num_classes,\n      
mask_init_path=mask_path,\n      weight_init_path=FLAGS.ckpt_end)\n  logging.info('Performance at init (model_start):')\n  test_model(model_start, data_eval)\n  logging.info('Performance at init (model_end):')\n  test_model(model_end, data_eval)\n  all_results = interpolate(model_start=model_start, model_end=model_end,\n                            model_inter=model_inter, d_set=data_eval)\n\n  tf.io.gfile.makedirs(FLAGS.logdir)\n  results_path = os.path.join(FLAGS.logdir, 'all_results')\n  with tf.io.gfile.GFile(results_path, 'wb') as f:\n    np.save(f, all_results)\n  logging.info('Total runtime: %.3f s', init_timer.GetDuration())\n  logconfigfile_path = os.path.join(FLAGS.logdir, 'operative_config.gin')\n  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:\n    f.write('# Gin-Config:\\n %s' % gin.config.operative_config_str())\n\n\nif __name__ == '__main__':\n  tf.enable_v2_behavior()\n  app.run(main)\n"
  },
  {
    "path": "rigl/rigl_tf2/mask_updaters.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Implements RigL.\"\"\"\nimport gin\nfrom rigl.rigl_tf2 import utils\nimport tensorflow as tf\n\n\ndef get_all_layers(model, filter_fn=lambda _: True):\n  \"\"\"Gets all layers of a model and layers of a layer if it is a keras.Model.\"\"\"\n  all_layers = []\n  for l in model.layers:\n    if hasattr(l, 'layers'):\n      all_layers.extend(get_all_layers(l, filter_fn=filter_fn))\n    elif filter_fn(l):\n      all_layers.append(l)\n  return all_layers\n\n\ndef is_pruned(layer):\n  return isinstance(layer, utils.PRUNING_WRAPPER) and layer.trainable\n\n\nclass MaskUpdater(object):\n  \"\"\"Base class for mask update algorithms.\n\n    Attributes:\n    model: tf.keras.Model\n    optimizer: tf.train.Optimizer\n    use_stateless: bool, if True stateless operations are used. 
This is\n      important for multi-worker jobs not to diverge.\n    stateless_seed_offset: int, added to the seed of stateless operations.\n      Use this to create randomness without divergence across workers.\n  \"\"\"\n\n  def __init__(self, model, optimizer, use_stateless=True,\n               stateless_seed_offset=0, loss_fn=None):\n    self._model = model\n    self._optimizer = optimizer\n    self._use_stateless = use_stateless\n    self._stateless_seed_offset = stateless_seed_offset\n    self._loss_fn = loss_fn\n    self.val_x = self.val_y = None\n\n  def prune_masks(self, prune_fraction):\n    \"\"\"Updates a fraction of weights in each layer.\"\"\"\n    all_masks, all_vars = self.get_vars_and_masks()\n    drop_scores = self.get_drop_scores(all_vars, all_masks)\n    grow_score = None\n    for mask, var, drop_score in zip(all_masks, all_vars, drop_scores):\n      self.generic_mask_update(mask, var, drop_score, grow_score,\n                               prune_fraction)\n\n  def update_masks(self, drop_fraction):\n    \"\"\"Updates a fraction of weights in each layer.\"\"\"\n    all_masks, all_vars = self.get_vars_and_masks()\n    drop_scores = self.get_drop_scores(all_vars, all_masks)\n    grow_scores = self.get_grow_scores(all_vars, all_masks)\n    for mask, var, drop_score, grow_score in zip(all_masks, all_vars,\n                                                 drop_scores, grow_scores):\n      self.generic_mask_update(mask, var, drop_score, grow_score, drop_fraction)\n\n  def get_all_pruning_layers(self):\n    \"\"\"Returns all pruned layers from the model.\"\"\"\n    if hasattr(self._model, 'layers'):\n      return get_all_layers(self._model, filter_fn=is_pruned)\n    else:\n      return [self._model] if is_pruned(self._model) else []\n\n  def get_vars_and_masks(self):\n    \"\"\"Gets all masked variables and corresponding masks.\"\"\"\n    all_masks = []\n    all_vars = []\n    for layer in self.get_all_pruning_layers():\n      for var, mask, _ in 
layer.pruning_vars:\n        all_vars.append(var)\n        all_masks.append(mask)\n    return all_masks, all_vars\n\n  def get_drop_scores(self, all_vars, all_masks):\n    raise NotImplementedError\n\n  def get_grow_scores(self, all_vars, all_masks):\n    raise NotImplementedError\n\n  def generic_mask_update(self, mask, var, score_drop, score_grow,\n                          drop_fraction, reinit_when_same=False):\n    \"\"\"Prunes+grows connections, all tensors same shape.\"\"\"\n    n_total = tf.size(score_drop)\n    n_ones = tf.cast(tf.reduce_sum(mask), dtype=tf.int32)\n    n_prune = tf.cast(\n        tf.cast(n_ones, dtype=tf.float32) * drop_fraction, tf.int32)\n    n_keep = n_ones - n_prune\n\n    # Sort the entire array since the k needs to be constant for TPU.\n    _, sorted_indices = tf.math.top_k(\n        tf.reshape(score_drop, [-1]), k=n_total)\n    sorted_indices_ex = tf.expand_dims(sorted_indices, 1)\n    # We will have zeros after having `n_keep` many ones.\n    new_values = tf.where(\n        tf.range(n_total) < n_keep,\n        tf.ones_like(sorted_indices, dtype=mask.dtype),\n        tf.zeros_like(sorted_indices, dtype=mask.dtype))\n    mask1 = tf.scatter_nd(sorted_indices_ex, new_values,\n                          new_values.shape)\n    if score_grow is not None:\n      # Flatten the scores.\n      score_grow = tf.reshape(score_grow, [-1])\n      # Set scores of the enabled connections(ones) to min(s) - 1, so that they\n      # have the lowest scores.\n      score_grow_lifted = tf.where(\n          tf.math.equal(mask1, 1),\n          tf.ones_like(mask1) * (tf.reduce_min(score_grow) - 1), score_grow)\n      _, sorted_indices = tf.math.top_k(score_grow_lifted, k=n_total)\n      sorted_indices_ex = tf.expand_dims(sorted_indices, 1)\n      new_values = tf.where(\n          tf.range(n_total) < n_prune,\n          tf.ones_like(sorted_indices, dtype=mask.dtype),\n          tf.zeros_like(sorted_indices, dtype=mask.dtype))\n      mask2 = 
tf.scatter_nd(sorted_indices_ex, new_values, new_values.shape)\n      # Ensure masks are disjoint.\n      tf.debugging.assert_near(tf.reduce_sum(mask1 * mask2), 0.)\n\n      # Let's set the weights of the growed connections.\n      mask2_reshaped = tf.reshape(mask2, mask.shape)\n      # Set the values of the new connections.\n      grow_tensor = tf.zeros_like(var, dtype=var.dtype)\n      if reinit_when_same:\n        # If dropped and grown, we re-initialize.\n        new_connections = tf.math.equal(mask2_reshaped, 1)\n      else:\n        new_connections = tf.math.logical_and(\n            tf.math.equal(mask2_reshaped, 1), tf.math.equal(mask, 0))\n      new_weights = tf.where(new_connections, grow_tensor, var)\n      var.assign(new_weights)\n      # Ensure there is no momentum value for new connections\n      self.reset_momentum(var, new_connections)\n      mask_combined = tf.reshape(mask1 + mask2, mask.shape)\n    else:\n      mask_combined = tf.reshape(mask1, mask.shape)\n    mask.assign(mask_combined)\n\n  def reset_momentum(self, var, new_connections):\n    for s_name in self._optimizer.get_slot_names():\n      # Momentum variable for example, we reset the aggregated values to zero.\n      optim_var = self._optimizer.get_slot(var, s_name)\n      new_values = tf.where(new_connections,\n                            tf.zeros_like(optim_var), optim_var)\n      optim_var.assign(new_values)\n\n  def _random_uniform(self, *args, **kwargs):\n    if self._use_stateless:\n      c_seed = self._stateless_seed_offset + kwargs['seed']\n      kwargs['seed'] = tf.cast(\n          tf.stack([c_seed, self._optimizer.iterations]), tf.int32)\n      return tf.random.stateless_uniform(*args, **kwargs)\n    else:\n      return tf.random.uniform(*args, **kwargs)\n\n  def _random_normal(self, *args, **kwargs):\n    if self._use_stateless:\n      c_seed = self._stateless_seed_offset + kwargs['seed']\n      kwargs['seed'] = tf.cast(\n          tf.stack([c_seed, 
self._optimizer.iterations]), tf.int32)\n      return tf.random.stateless_normal(*args, **kwargs)\n    else:\n      return tf.random.normal(*args, **kwargs)\n\n  def set_validation_data(self, val_x, val_y):\n    self.val_x, self.val_y = val_x, val_y\n\n  def _get_gradients(self, all_vars):\n    \"\"\"Returns the gradients of the given weights using the validation data.\"\"\"\n    with tf.GradientTape() as tape:\n      batch_loss = self._loss_fn(self.val_x, self.val_y)\n    grads = tape.gradient(batch_loss, all_vars)\n    if grads:\n      grads = tf.distribute.get_replica_context().all_reduce('sum', grads)\n    return grads\n\n\nclass SET(MaskUpdater):\n  \"\"\"Implementation of dynamic sparsity optimizers.\n\n  Implementation of SET.\n  See https://www.nature.com/articles/s41467-018-04316-3\n  This optimizer wraps a regular optimizer and performs updates on the masks\n  according to schedule given.\n  \"\"\"\n\n  def get_drop_scores(self, all_vars, all_masks, noise_std=0):\n    def score_fn(mask, var):\n      score = tf.math.abs(mask*var)\n      if noise_std != 0:\n        score += self._random_normal(\n            score.shape, stddev=noise_std, dtype=score.dtype,\n            seed=(hash(var.name + 'drop')))\n      return score\n    return [score_fn(mask, var) for mask, var in zip(all_masks, all_vars)]\n\n  def get_grow_scores(self, all_vars, all_masks):\n    return [self._random_uniform(var.shape, seed=hash(var.name + 'grow'))\n            for var in all_vars]\n\n\nclass RigL(MaskUpdater):\n  \"\"\"Implementation of dynamic sparsity optimizers.\n\n  Implementation of RigL.\n  \"\"\"\n\n  def get_drop_scores(self, all_vars, all_masks, noise_std=0):\n    def score_fn(mask, var):\n      score = tf.math.abs(mask*var)\n      if noise_std != 0:\n        score += self._random_normal(\n            score.shape, stddev=noise_std, dtype=score.dtype,\n            seed=(hash(var.name + 'drop')))\n      return score\n    return [score_fn(mask, var) for mask, var in 
zip(all_masks, all_vars)]\n\n  def get_grow_scores(self, all_vars, all_masks):\n    return [tf.abs(g) for g in self._get_gradients(all_vars)]\n\n\nclass RigLInverted(RigL):\n  \"\"\"Implementation of dynamic sparsity optimizers.\n\n  Implementation of inverted RigL: grows the connections with the smallest\n  gradient magnitude.\n  \"\"\"\n\n  def get_grow_scores(self, all_vars, all_masks):\n    return [-tf.abs(g) for g in self._get_gradients(all_vars)]\n\n\nclass UpdateSchedule(object):\n  \"\"\"Base class for mask update schedules.\n\n    Attributes:\n    mask_updater: MaskUpdater, to invoke.\n    update_freq: int, frequency of mask updates.\n    init_drop_fraction: float, initial drop fraction.\n  \"\"\"\n\n  def __init__(self, mask_updater, init_drop_fraction, update_freq,\n               last_update_step):\n    self._mask_updater = mask_updater\n    self.update_freq = update_freq\n    self.last_update_step = last_update_step\n    self.init_drop_fraction = tf.convert_to_tensor(init_drop_fraction)\n    self.last_drop_fraction = 0\n\n  def get_drop_fraction(self, step):\n    raise NotImplementedError\n\n  def is_update_iter(self, step):\n    \"\"\"Returns true if it is a valid mask update step.\"\"\"\n    # last_update_step < 0 means there is no last step.\n    # last_update_step = 0 means never update.\n    tf.debugging.Assert(step >= 0, [step])\n\n    if self.last_update_step < 0:\n      is_valid_step = True\n    elif self.last_update_step == 0:\n      is_valid_step = False\n    else:\n      is_valid_step = step <= self.last_update_step\n\n    return tf.logical_and(is_valid_step, step % self.update_freq == 0)\n\n  def update(self, step, check_update_iter=True):\n    if check_update_iter:\n      tf.debugging.Assert(self.is_update_iter(step), [step])\n    self.last_drop_fraction = self.get_drop_fraction(step)\n\n    def true_fn():\n      self._mask_updater.update_masks(self.last_drop_fraction)\n\n    tf.cond(self.last_drop_fraction > 0., true_fn, lambda: None)\n\n  def prune(self, prune_fraction):\n    self.last_drop_fraction = 
prune_fraction\n    self._mask_updater.prune_masks(self.last_drop_fraction)\n\n  def set_validation_data(self, val_x, val_y):\n    self._mask_updater.set_validation_data(val_x, val_y)\n\n\nclass ConstantUpdateSchedule(UpdateSchedule):\n  \"\"\"Updates a constant fraction of connections.\"\"\"\n\n  def get_drop_fraction(self, step):\n    return self.init_drop_fraction\n\n\nclass CosineUpdateSchedule(UpdateSchedule):\n  \"\"\"Decays the drop fraction with a cosine schedule.\"\"\"\n\n  def __init__(self, *args, **kwargs):\n    super().__init__(*args, **kwargs)\n    self._drop_fraction_fn = tf.keras.experimental.CosineDecay(\n        self.init_drop_fraction,\n        self.last_update_step,\n        alpha=0.0,\n        name='cosine_drop_fraction')\n\n  def get_drop_fraction(self, step):\n    return self._drop_fraction_fn(step)\n\n\nclass ScaledLRUpdateSchedule(UpdateSchedule):\n  \"\"\"Scales the drop fraction with learning rate.\"\"\"\n\n  def __init__(self, mask_updater, init_drop_fraction, update_freq,\n               last_update_step, optimizer):\n    self._optimizer = optimizer\n    self._initial_lr = self._get_lr(0)\n    super(ScaledLRUpdateSchedule, self).__init__(\n        mask_updater, init_drop_fraction, update_freq, last_update_step)\n\n  def _get_lr(self, step):\n    if isinstance(self._optimizer.lr, tf.Variable):\n      return self._optimizer.lr.numpy()\n    else:\n      return self._optimizer.lr(step)\n\n  def get_drop_fraction(self, step):\n    current_lr = self._get_lr(step)\n    return (self.init_drop_fraction / self._initial_lr) * current_lr\n\n\n@gin.configurable(\n    'mask_updater',\n    allowlist=[\n        'update_alg',\n        'schedule_alg',\n        'update_freq',\n        'init_drop_fraction',\n        'last_update_step',\n        'use_stateless',\n    ])\ndef get_mask_updater(\n    model,\n    optimizer,\n    loss_fn,\n    update_alg='',\n    schedule_alg='lr',\n    update_freq=100,\n    init_drop_fraction=0.3,\n    last_update_step=-1,\n    
use_stateless=True):\n  \"\"\"Retrieves the update algorithm and passes it to the schedule object.\"\"\"\n  if not update_alg:\n    return None\n  elif update_alg == 'set':\n    mask_updater = SET(model, optimizer, use_stateless=use_stateless)\n  elif update_alg == 'rigl':\n    mask_updater = RigL(\n        model, optimizer, loss_fn=loss_fn, use_stateless=use_stateless)\n  elif update_alg == 'rigl_inverted':\n    mask_updater = RigLInverted(\n        model, optimizer, loss_fn=loss_fn, use_stateless=use_stateless)\n  else:\n    raise ValueError('update_alg:%s  is not valid.' % update_alg)\n  if schedule_alg == 'lr':\n    update_schedule = ScaledLRUpdateSchedule(\n        mask_updater, init_drop_fraction, update_freq, last_update_step,\n        optimizer)\n  elif schedule_alg == 'cosine':\n    update_schedule = CosineUpdateSchedule(\n        mask_updater, init_drop_fraction, update_freq, last_update_step)\n  elif schedule_alg == 'constant':\n    update_schedule = ConstantUpdateSchedule(mask_updater, init_drop_fraction,\n                                             update_freq, last_update_step)\n  else:\n    raise ValueError('schedule_alg:%s  is not valid.' % schedule_alg)\n  return update_schedule\n"
  },
  {
    "path": "rigl/rigl_tf2/metainit.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"MetaInit algorithm to dynamically initialize neural nets.\"\"\"\n\nimport numpy as np\nimport tensorflow.compat.v1 as tf1\nimport tensorflow.compat.v2 as tf\n\n\nclass ScaleSGD(tf1.train.Optimizer):\n  \"\"\"SGD optimizer that only trains the scales of the parameters.\n\n  This optimizer only tunes the scales of weight matrices.\n  \"\"\"\n\n  def __init__(self, learning_rate=0.1, momentum=0.9, mindim=3,\n               use_locking=False, name=\"ScaleSGD\"):\n    super(ScaleSGD, self).__init__(use_locking, name)\n    self._lr = learning_rate\n    self._momentum = momentum\n    self._mindim = mindim\n\n    # Tensor versions of the constructor arguments, created in _prepare().\n    self._lr_t = None\n    self._momentum_t = None\n\n  def _prepare(self):\n    self._lr_t = tf1.convert_to_tensor(self._lr, name=\"learning_rate\")\n    self._momentum_t = tf1.convert_to_tensor(self._momentum, name=\"momentum_t\")\n\n  def _create_slots(self, var_list):\n    for v in var_list:\n      self._get_or_make_slot_with_initializer(v,\n                                              tf1.constant_initializer(0),\n                                              tf1.TensorShape([]),\n                                              tf1.float32,\n                                              \"m\",\n                                              self._name)\n\n  def 
_resource_apply_dense(self, grad, handle):\n    var = handle\n    m = self.get_slot(var, \"m\")\n\n    if len(var.shape) < self._mindim:\n      return tf.group(*[var, m])\n    lr_t = tf1.cast(self._lr_t, var.dtype.base_dtype)\n    momentum_t = tf1.cast(self._momentum_t, var.dtype.base_dtype)\n\n    scale = tf1.sqrt(tf1.reduce_sum(var ** 2))\n    dscale = tf1.sign(tf1.reduce_sum(var * grad) / (scale + 1e-12))\n\n    m_t = m.assign(momentum_t * m - lr_t * dscale)\n\n    new_scale = scale + m_t\n    var_update = tf1.assign(var, var * new_scale / (scale + 1e-12))\n    return tf1.group(*[var_update, m_t])\n\n  def _apply_dense(self, grad, var):\n    return self._resource_apply_dense(grad, var)\n\n  def _apply_sparse(self, grad, var):\n    raise NotImplementedError(\"Sparse gradient updates are not supported.\")\n\n\ndef meta_init(model, loss, x_shape, y_shape, n_params, learning_rate=0.001,\n              momentum=0.9, meta_steps=1000, eps=1e-5, mask_gradient_fn=None):\n  \"\"\"Run MetaInit algorithm. 
See `https://papers.nips.cc/paper/9427-metainit-initializing-learning-by-learning-to-initialize`\"\"\"\n  optimizer = ScaleSGD(learning_rate, momentum=momentum)\n\n  for _ in range(meta_steps):\n    x = np.random.normal(0, 1, x_shape)\n    y = np.random.randint(0, y_shape[1], y_shape[0])\n\n    with tf.GradientTape(persistent=True) as tape:\n      batch_loss = loss(y, model(x, training=True))\n      grad = tape.gradient(batch_loss, model.trainable_variables)\n      if mask_gradient_fn is not None:\n        grad = mask_gradient_fn(model, grad, model.trainable_variables)\n      prod = tape.gradient(tf.reduce_sum([tf.reduce_sum(g**2) / 2\n                                          for g in grad]),\n                           model.trainable_variables)\n      if mask_gradient_fn is not None:\n        prod = mask_gradient_fn(model, prod, model.trainable_variables)\n      meta_loss = [tf.abs(1 - ((g - p) / (g + eps * tf.stop_gradient(\n          (2 * tf.cast(tf.greater_equal(g, 0), tf.float32)) - 1))))\n                   for g, p in zip(grad, prod)]\n      if mask_gradient_fn is not None:\n        meta_loss = mask_gradient_fn(model, meta_loss,\n                                     model.trainable_variables)\n      meta_loss = sum([tf.reduce_sum(m) for m in meta_loss]) / n_params\n    tf.summary.scalar(\"meta_loss\", meta_loss)\n\n    gradients = tape.gradient(meta_loss, model.trainable_variables)\n    if mask_gradient_fn is not None:\n      gradients = mask_gradient_fn(model, gradients, model.trainable_variables)\n    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n"
  },
  {
    "path": "rigl/rigl_tf2/mlp_configs/dense.gin",
    "content": "training.total_steps = 11719 # 25 epochs: 6e4/128*25 = 11719 steps.\ntraining.batch_size = 128\ntraining.save_freq = 500  # Save a checkpoint every 500 steps.\ntraining.log_freq = 200\nnetwork.network_name = 'mlp'\nnetwork.weight_decay = 0.0001\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.2\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n# NON-DEFAULT\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n\n"
  },
  {
    "path": "rigl/rigl_tf2/mlp_configs/lottery.gin",
    "content": "# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_path = '/tmp/sparse_spectrum/ckpt-0'\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n\n"
  },
  {
    "path": "rigl/rigl_tf2/mlp_configs/prune.gin",
    "content": "training.total_steps = 11719 # 25 epochs: 6e4/128*25 = 11719 steps.\ntraining.batch_size = 128\ntraining.save_freq = 500  # Save a checkpoint every 500 steps.\ntraining.log_freq = 200\nnetwork.network_name = 'mlp'\nnetwork.mask_init_path = None\nnetwork.weight_decay = 0.0001\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.2\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\npruning.mode = 'prune'\npruning.initial_sparsity = 0.0\npruning.final_sparsity = 0.98\npruning.begin_step = 3000\npruning.end_step = 7000\npruning.frequency = 100\n\n\n"
  },
  {
    "path": "rigl/rigl_tf2/mlp_configs/rigl.gin",
    "content": "training.use_metainit = False\n\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_method = None\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\nunit_scaled_init.method='fanin_normal'\n\n# Mask Updates\nmask_updater.update_alg = 'rigl'\nmask_updater.schedule_alg = 'lr'\nmask_updater.update_freq = 500\nmask_updater.init_drop_fraction = 0.3\nmask_updater.last_update_step=-1\n"
  },
  {
    "path": "rigl/rigl_tf2/mlp_configs/scratch.gin",
    "content": "training.use_metainit = False\n\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_method = None\nnetwork.shuffle_mask = False\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n"
  },
  {
    "path": "rigl/rigl_tf2/mlp_configs/set.gin",
    "content": "training.use_metainit = False\n# NON-DEFAULT\nnetwork.mask_init_path = '/tmp/sparse_spectrum/ckpt-11719'\nnetwork.weight_init_method = None\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\nunit_scaled_init.method='fanin_normal'\n\n# Mask Updates\nmask_updater.update_alg = 'set'\nmask_updater.schedule_alg = 'lr'\nmask_updater.update_freq = 500\nmask_updater.init_drop_fraction = 0.3\nmask_updater.last_update_step=-1\n"
  },
  {
    "path": "rigl/rigl_tf2/mlp_configs/small_dense.gin",
    "content": "training.total_steps = 11719 # 25 epochs: 6e4/128*25 = 11719 steps.\ntraining.batch_size = 128\ntraining.save_freq = 500  # Save a checkpoint every 500 steps.\ntraining.log_freq = 200\nnetwork.network_name = 'mlp'\nnetwork.weight_decay = 0.0001\n# (28*28*300 + 300*100 + 100*10)*0.02 + 410 = 5734 params\n# (28*28*8 + 8*8 + 8*10) + 8+8+10 = 6442\nmlp.hidden_sizes = (8, 8)\noptimizer.name = \"momentum\"\noptimizer.learning_rate = 0.2\noptimizer.momentum = 0.9\noptimizer.clipvalue = None\noptimizer.clipnorm = None\n# NON-DEFAULT\npruning.mode = 'constant'\npruning.final_sparsity = 0.\npruning.begin_step = 100000000 # High begin_step, so it never starts.\n"
  },
  {
    "path": "rigl/rigl_tf2/networks.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"This module has networks used in experiments.\n\"\"\"\nfrom typing import Optional, Tuple  # Non-expensive-to-import types.\nimport gin\n\nimport tensorflow.compat.v2 as tf\n\n\n@gin.configurable(allowlist=['hidden_sizes', 'use_batch_norm'])\ndef lenet5(input_shape,\n           num_classes,\n           activation,\n           kernel_regularizer,\n           use_batch_norm = False,\n           hidden_sizes = (6, 16, 120, 84)):\n  \"\"\"Lenet5 implementation.\"\"\"\n  network = tf.keras.Sequential()\n  kwargs = {\n      'activation': activation,\n      'kernel_regularizer': kernel_regularizer,\n  }\n  def maybe_add_batchnorm():\n    if use_batch_norm:\n      network.add(tf.keras.layers.BatchNormalization())\n  network.add(tf.keras.layers.Conv2D(\n      hidden_sizes[0], 5, input_shape=input_shape, **kwargs))\n  network.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2)))\n  maybe_add_batchnorm()\n  network.add(tf.keras.layers.Conv2D(hidden_sizes[1], 5, **kwargs))\n  network.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2)))\n  maybe_add_batchnorm()\n  network.add(tf.keras.layers.Flatten())\n  network.add(tf.keras.layers.Dense(hidden_sizes[2], **kwargs))\n  maybe_add_batchnorm()\n  network.add(tf.keras.layers.Dense(hidden_sizes[3], **kwargs))\n  maybe_add_batchnorm()\n  kwargs['activation'] = None\n  network.add(tf.keras.layers.Dense(num_classes, 
**kwargs))\n  return network\n\n\n@gin.configurable(allowlist=['hidden_sizes', 'use_batch_norm'])\ndef mlp(input_shape,\n        num_classes,\n        activation,\n        kernel_regularizer,\n        use_batch_norm = False,\n        hidden_sizes = (300, 100)):\n  \"\"\"MLP implementation.\"\"\"\n  network = tf.keras.Sequential()\n  kwargs = {\n      'activation': activation,\n      'kernel_regularizer': kernel_regularizer\n  }\n  def maybe_add_batchnorm():\n    if use_batch_norm:\n      network.add(tf.keras.layers.BatchNormalization())\n  network.add(tf.keras.layers.Flatten(input_shape=input_shape))\n  network.add(tf.keras.layers.Dense(hidden_sizes[0], **kwargs))\n  maybe_add_batchnorm()\n  network.add(tf.keras.layers.Dense(hidden_sizes[1], **kwargs))\n  maybe_add_batchnorm()\n  kwargs['activation'] = None\n  network.add(tf.keras.layers.Dense(num_classes, **kwargs))\n  return network\n"
  },
  {
    "path": "rigl/rigl_tf2/train.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Training script for running experiments.\n\"\"\"\n\nimport os\nfrom typing import List  # Non-expensive-to-import types.\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nimport gin\nimport jax\nfrom jax.scipy.linalg import eigh\nimport numpy as np\nfrom rigl.rigl_tf2 import mask_updaters\nfrom rigl.rigl_tf2 import metainit\nfrom rigl.rigl_tf2 import utils\nimport tensorflow.compat.v2 as tf\nfrom pyglib import timer\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('logdir', '/tmp/sparse_spectrum',\n                    'Directory to save experiment in.')\nflags.DEFINE_string('preload_gin_config', '', 'If non-empty, reads a gin file '\n                    'before parsing gin_config and bindings. This is useful '\n                    'when you want to start from a configuration of another '\n                    'run. Values are then overwritten by additional configs '\n                    'and bindings provided.')\nflags.DEFINE_bool('use_tpu', True, 'Whether to run on TPU or not.')\nflags.DEFINE_enum('mode', 'train_eval', ('train_eval', 'hessian'),\n                  'Whether to run training or the Hessian-spectrum analysis.')\nflags.DEFINE_string(\n    'tpu_job_name', 'tpu_worker',\n    'Name of the TPU worker job. 
This is required when having '\n    'multiple TPU worker jobs.')\nflags.DEFINE_integer('seed', default=0, help=('Sets the random seed.'))\nflags.DEFINE_multi_string('gin_config', [],\n                          'List of paths to the config files.')\nflags.DEFINE_multi_string('gin_bindings', [],\n                          'Newline separated list of Gin parameter bindings.')\n\n\n@tf.function\ndef get_rows(model, variables, masks, ind_l, indices, x_batch, y_batch,\n             is_dense_spectrum):\n  \"\"\"Calculates the rows (given by `ind_l`) of the Hessian.\"\"\"\n  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n  with tf.GradientTape(persistent=True) as tape:\n    predictions = model(x_batch, training=True)\n    loss = loss_object(y_batch, predictions)\n    grads, = tape.gradient(loss, [variables[ind_l]])\n    # Since the variables are masked before not during the forward pass,\n    # gradients are dense. We need to ensure they are sparse.\n    sparse_grads = grads * masks[ind_l]\n    single_grad = tf.reshape(sparse_grads, [-1])\n    s_grads = tf.gather(single_grad, indices)\n\n  flattened_list = []\n  hessians_slice_vars = tape.jacobian(\n      s_grads, variables, experimental_use_pfor=False)\n  for h, m in zip(hessians_slice_vars, masks):\n    if is_dense_spectrum:\n      # We apply the masks since weights are not hard constrained with sparsity.\n      vals = tf.reshape(h * m, (h.shape[0], -1))\n    else:\n      boolean_mask = tf.broadcast_to(tf.equal(m, 1), h.shape)\n      vals = tf.reshape(h[boolean_mask], (h.shape[0], -1))\n    flattened_list.append(vals)\n\n  res = tf.concat(flattened_list, 1)\n  return res\n\n\ndef sparse_hessian_calculator(model,\n                              data,\n                              rows_at_once,\n                              eigvals_path,\n                              overwrite,\n                              is_dense_spectrum=False):\n  \"\"\"Calculates the Hessian of the model parameters. 
Biases are dense.\"\"\"\n  # Read all data at once\n  x_batch, y_batch = list(data.batch(100000))[0]\n\n  if tf.io.gfile.exists(eigvals_path) and overwrite:\n    logging.info('Deleting existing Eigvals: %s', eigvals_path)\n    tf.io.gfile.rmtree(eigvals_path)\n  if tf.io.gfile.exists(eigvals_path):\n    with tf.io.gfile.GFile(eigvals_path, 'rb') as f:\n      eigvals = np.load(f)\n    logging.info('Eigvals exists, skipping :%s', eigvals_path)\n    return eigvals\n\n  # First lets create lists that indicate the valid dimension of each variable.\n  # If we want to calculate sparse spectrum, then we have to omit masked\n  # dimensions. Biases are dense, therefore have masks of 1's.\n  masks = []\n  variables = []\n  layer_group_indices = []\n  for l in model.layers:\n    if isinstance(l, utils.PRUNING_WRAPPER):\n      # TODO following the outcome of b/148083099, update following.\n      # Add the weight, mask and the valid dimensions.\n      weight = l.weights[0]\n      variables.append(weight)\n\n      mask = l.weights[2]\n      masks.append(mask)\n      logging.info(mask.shape)\n\n      if is_dense_spectrum:\n        n_params = tf.size(mask)\n        layer_group_indices.append(tf.range(n_params))\n      else:\n        fmask = tf.reshape(mask, [-1])\n        indices = tf.where(tf.equal(fmask, 1))[:, 0]\n        layer_group_indices.append(indices)\n      # Add the bias mask of ones and all of its dimensions.\n      bias = l.weights[1]\n      variables.append(bias)\n      masks.append(tf.ones_like(bias))\n      layer_group_indices.append(tf.range(tf.size(bias)))\n    else:\n      # For now we assume all parameterized layers are wrapped with\n      # PruneLowMagnitude.\n      assert not l.trainable_variables\n  result_all = []\n  init_timer = timer.Timer()\n  init_timer.Start()\n  n_total = 0\n  logging.info('Calculating Hessian...')\n  for i, inds in enumerate(layer_group_indices):\n    n_split = np.ceil(tf.size(inds).numpy() / rows_at_once)\n    logging.info('Nsplit: 
%d', n_split)\n    for c_slice in np.array_split(inds.numpy(), n_split):\n      res = get_rows(model, variables, masks, i, c_slice, x_batch, y_batch,\n                     is_dense_spectrum)\n      result_all.append(res.numpy())\n      n_total += res.shape[0]\n      target_n = float(res.shape[1])\n    logging.info('%.3f %% ..', (n_total / target_n * 100))\n  # We concatenate in numpy so that it stays on the CPU and we avoid OOM.\n  c_hessian = np.concatenate(result_all, 0)\n  logging.info('Total runtime for hessian: %.3f s', init_timer.GetDuration())\n  init_timer.Start()\n  eigens = jax.jit(eigh, backend='cpu')(c_hessian)\n  eigvals = np.asarray(eigens[0])\n  with tf.io.gfile.GFile(eigvals_path, 'wb') as f:\n    np.save(f, eigvals)\n  logging.info('EigVals saved: %s', eigvals_path)\n  logging.info('Total runtime for eigvals: %.3f s', init_timer.GetDuration())\n  return eigvals\n\n\n@gin.configurable(denylist=['model', 'ds_train', 'logdir'])\ndef hessian(model,\n            ds_train,\n            logdir,\n            ckpt_ids = gin.REQUIRED,\n            overwrite = False,\n            batch_size = 1000,\n            rows_at_once = 10,\n            is_dense_spectrum = False):\n  \"\"\"Loads checkpoints under a folder and calculates their hessian spectrum.\"\"\"\n  # Note that the Hessian is calculated using the same batch in different runs.\n  # This is needed so that, if the job dies and is restarted, the batch stays\n  # the same.\n  data_hessian = ds_train.take(batch_size)\n  for ckpt_id in ckpt_ids:\n    ckpt = tf.train.Checkpoint(model=model)\n    c_path = os.path.join(logdir, 'ckpt-%d' % ckpt_id)\n    ckpt.restore(c_path)\n    logging.info('Loaded from: %s', c_path)\n    eigvals_path = c_path + '.eigvals'\n    sparse_hessian_calculator(\n        model=model, data=data_hessian, eigvals_path=eigvals_path,\n        overwrite=overwrite, is_dense_spectrum=is_dense_spectrum,\n        rows_at_once=rows_at_once)\n\n\ndef update_prune_step(model, 
step):\n  for layer in model.layers:\n    if isinstance(layer, utils.PRUNING_WRAPPER):\n      # Assign iteration count to the layer pruning_step.\n      layer.pruning_step.assign(step)\n\n\ndef log_sparsities(model):\n  for layer in model.layers:\n    if isinstance(layer, utils.PRUNING_WRAPPER):\n      for _, mask, threshold in layer.pruning_vars:\n        scalar_name = f'sparsity/{mask.name}'\n        sparsity = 1 - tf.reduce_mean(mask)\n        tf.summary.scalar(scalar_name, sparsity)\n        tf.summary.scalar(f'threshold/{threshold.name}', threshold)\n\n\ndef cosine_distance(x, y):\n  \"\"\"Calculates the distance between 2 tensors of same shape.\"\"\"\n  normalizedx = tf.math.l2_normalize(x)\n  normalizedy = tf.math.l2_normalize(y)\n  return 1. - tf.reduce_sum(tf.multiply(normalizedx, normalizedy))\n\n\ndef flatten_list_of_vars(var_list):\n  flat_vars = [tf.reshape(v, -1) for v in var_list]\n  return tf.concat(flat_vars, axis=-1)\n\n\ndef var_to_img(tensor):\n  if len(tensor.shape) <= 1:\n    gray_image = tf.reshape(tensor, [1, -1])\n  elif len(tensor.shape) == 2:\n    gray_image = tensor\n  else:\n    gray_image = tf.reshape(tensor, [-1, tensor.shape[-1]])\n  # (H, W) -> (1, H, W, 1)\n  return tf.expand_dims(tf.expand_dims(gray_image, 0), -1)\n\n\ndef mask_gradients(model, gradients, variables):\n  name_to_grad = {var.name: grad for grad, var in zip(gradients, variables)}\n  for layer in model.layers:\n    if isinstance(layer, utils.PRUNING_WRAPPER):\n      for weights, mask, _ in layer.pruning_vars:\n        if weights.name in name_to_grad:\n          name_to_grad[weights.name] = name_to_grad[weights.name] * mask\n  masked_gradients = [name_to_grad[var.name] for var in variables]\n  return masked_gradients\n\n\n@gin.configurable(\n    'training', denylist=['model', 'ds_train', 'ds_test', 'logdir'])\ndef train_model(model,\n                ds_train,\n                ds_test,\n                logdir,\n                total_steps = 5000,\n                
batch_size = 128,\n                val_batch_size = 1000,\n                save_freq = 5,\n                log_freq = 250,\n                use_metainit = False,\n                oneshot_prune_fraction = 0.,\n                gradient_regularization=0):\n  \"\"\"Trains the model on MNIST.\"\"\"\n  logging.info('Writing training logs to %s', logdir)\n  writer = tf.summary.create_file_writer(os.path.join(logdir, 'train_logs'))\n  optimizer = utils.get_optimizer(total_steps)\n  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n  train_batch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n      name='train_batch_accuracy')\n  # Let's create 2 disjoint validation sets.\n  (val_x, val_y), (val2_x, val2_y) = [\n      d for d in ds_train.take(val_batch_size * 2).batch(val_batch_size)\n  ]\n\n  # We use a set separate from the one used in training.\n  def loss_fn(x, y):\n    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)\n    predictions = model(x, training=True)\n    reg_loss = tf.add_n(model.losses) if model.losses else 0\n    return loss_object(y, predictions) + reg_loss\n\n  mask_updater = mask_updaters.get_mask_updater(model, optimizer, loss_fn)\n  if mask_updater:\n    mask_updater.set_validation_data(val2_x, val2_y)\n  update_prune_step(model, 0)\n  if oneshot_prune_fraction > 0:\n    logging.info('Running one-shot pruning at the beginning.')\n    if not mask_updater:\n      raise ValueError('mask_updater does not exist. 
Please set '\n                       'mask_updater.update_alg flag for one-shot pruning.')\n    mask_updater.prune(oneshot_prune_fraction)\n  if use_metainit:\n    n_params = 0\n    for layer in model.layers:\n      if isinstance(layer, utils.PRUNING_WRAPPER):\n        for _, mask, _ in layer.pruning_vars:\n          n_params += tf.reduce_sum(mask)\n    metainit.meta_init(model, loss_object, (128, 28, 28, 1), (128, 10),\n                       n_params, mask_gradient_fn=mask_gradients)\n  # Used to calculate distances from initialization; note this gives incorrect\n  # results when training is restarted from a checkpoint.\n  initial_params = list(map(lambda a: a.numpy(), model.trainable_variables))\n\n  # Create the checkpoint object and restore if there is a checkpoint in the\n  # folder.\n  ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)\n  ckpt_manager = tf.train.CheckpointManager(\n      checkpoint=ckpt, directory=logdir, max_to_keep=None)\n  if ckpt_manager.latest_checkpoint:\n    logging.info('Restored from %s', ckpt_manager.latest_checkpoint)\n    ckpt.restore(ckpt_manager.latest_checkpoint)\n    is_restored = True\n  else:\n    logging.info('Starting from scratch.')\n    is_restored = False\n  # Obtain global_step after loading checkpoint.\n  global_step = optimizer.iterations\n  tf.summary.experimental.set_step(global_step)\n  trainable_vars = model.trainable_variables\n\n  def get_gradients(x, y, log_batch_gradient=False, is_regularized=True):\n    \"\"\"Gets sparse gradients and possibly logs some statistics.\"\"\"\n    is_grad_regularized = gradient_regularization != 0\n    with tf.GradientTape(persistent=is_grad_regularized) as tape:\n      predictions = model(x, training=True)\n      batch_loss = loss_object(y, predictions)\n      if is_regularized and is_grad_regularized:\n        gradients = tape.gradient(batch_loss, trainable_vars)\n        gradients = mask_gradients(model, gradients, trainable_vars)\n        grad_vec = flatten_list_of_vars(gradients)\n        
batch_loss += tf.nn.l2_loss(grad_vec) * gradient_regularization\n      # Regularization might have been disabled.\n      reg_loss = tf.add_n(model.losses) if model.losses else 0\n      if is_regularized:\n        batch_loss += reg_loss\n    gradients = tape.gradient(batch_loss, trainable_vars)\n    # Gradients are dense, we should mask them to ensure updates are sparse;\n    # So is the norm calculation.\n    gradients = mask_gradients(model, gradients, trainable_vars)\n    # If batch gradient log it.\n    if log_batch_gradient:\n      tf.summary.scalar('train_batch_loss', batch_loss)\n      tf.summary.scalar('train_batch_reg_loss', reg_loss)\n      train_batch_accuracy.update_state(y, predictions)\n      tf.summary.scalar('train_batch_accuracy', train_batch_accuracy.result())\n      train_batch_accuracy.reset_states()\n    return gradients\n\n  def log_fn():\n    logging.info('Logging at iter: %d', global_step.numpy())\n    log_sparsities(model)\n    test_loss, test_acc = test_model(model, ds_test)\n    tf.summary.scalar('test_loss', test_loss)\n    tf.summary.scalar('test_acc', test_acc)\n    # Log gradient norm.\n    # We want to obtain/log gradients without regularization term.\n    gradients = get_gradients(val_x, val_y, log_batch_gradient=False,\n                              is_regularized=False)\n    for var, grad in zip(trainable_vars, gradients):\n      tf.summary.scalar(f'gradnorm/{var.name}', tf.norm(grad))\n    # Log all gradients together\n    all_norm = tf.norm(flatten_list_of_vars(gradients))\n    tf.summary.scalar('.allparams/gradnorm', all_norm)\n    # Log momentum values:\n    for s_name in optimizer.get_slot_names():\n      # Currently we only log momentum.\n      if s_name not in ['momentum']:\n        continue\n      all_slots = [optimizer.get_slot(var, s_name) for var in trainable_vars]\n      all_norm = tf.norm(flatten_list_of_vars(all_slots))\n      tf.summary.scalar(f'.allparams/norm_{s_name}', all_norm)\n    # Log distance to init.\n    
for initial_val, val in zip(initial_params, model.trainable_variables):\n      tf.summary.scalar(f'dist_init_l2/{val.name}', tf.norm(initial_val - val))\n      cos_distance = cosine_distance(initial_val, val)\n      tf.summary.scalar(f'dist_init_cosine/{val.name}', cos_distance)\n    # Mask update logs:\n    if mask_updater:\n      tf.summary.scalar('drop_fraction', mask_updater.last_drop_fraction)\n    # Log all distances together.\n    flat_initial = flatten_list_of_vars(initial_params)\n    flat_current = flatten_list_of_vars(model.trainable_variables)\n    tf.summary.scalar('.allparams/dist_init_l2/',\n                      tf.norm(flat_initial - flat_current))\n    tf.summary.scalar('.allparams/dist_init_cosine/',\n                      cosine_distance(flat_initial, flat_current))\n    # Log masks\n    for layer in model.layers:\n      if isinstance(layer, utils.PRUNING_WRAPPER):\n        for _, mask, _ in layer.pruning_vars:\n          tf.summary.image('mask/%s' % mask.name, var_to_img(mask))\n    writer.flush()\n\n  def save_fn(step=None):\n    save_step = step if step else global_step\n    saved_ckpt = ckpt_manager.save(checkpoint_number=save_step)\n    logging.info('Saved checkpoint: %s', saved_ckpt)\n\n  with writer.as_default():\n    for x, y in ds_train.repeat().shuffle(\n        buffer_size=60000).batch(batch_size):\n      if global_step >= total_steps:\n        logging.info('Total steps: %d is completed', global_step.numpy())\n        save_fn()\n        break\n      update_prune_step(model, global_step)\n      if tf.equal(global_step, 0):\n        logging.info('Seed: %s First 10 Label: %s', FLAGS.seed, y[:10])\n      if global_step % save_freq == 0:\n        # If just loaded, don't save it again.\n        if is_restored:\n          is_restored = False\n        else:\n          save_fn()\n      if global_step % log_freq == 0:\n        log_fn()\n      gradients = get_gradients(x, y, log_batch_gradient=True)\n      tf.summary.scalar('lr', 
optimizer.lr(global_step))\n      optimizer.apply_gradients(zip(gradients, trainable_vars))\n      if mask_updater and mask_updater.is_update_iter(global_step):\n        # Save the network before the mask update; negative step numbers mark\n        # these pre-update checkpoints.\n        save_fn(step=(-global_step + 1))\n        # Gradient norm before.\n        gradients = get_gradients(\n            val_x, val_y, log_batch_gradient=False, is_regularized=False)\n        norm_before = tf.norm(flatten_list_of_vars(gradients))\n        results = mask_updater.update(global_step)\n        # Save the network again after the mask update.\n        save_fn(step=-global_step)\n        if results:\n          for mask_name, drop_frac in results.items():\n            tf.summary.scalar('drop_fraction/%s' % mask_name, drop_frac)\n\n        # Gradient norm after mask update.\n        gradients = get_gradients(\n            val_x, val_y, log_batch_gradient=False, is_regularized=False)\n        norm_after = tf.norm(flatten_list_of_vars(gradients))\n        tf.summary.scalar('.allparams/gradnorm_mask_update_improvement',\n                          norm_after - norm_before)\n\n    logging.info('Performance after training:')\n    log_fn()\n  return model\n\n\ndef test_model(model, d_test, batch_size=1000):\n  \"\"\"Tests the model and calculates cross entropy loss and accuracy.\"\"\"\n  test_loss = tf.keras.metrics.Mean(name='test_loss')\n  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n      name='test_accuracy')\n  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n  for x, y in d_test.batch(batch_size):\n    predictions = model(x, training=False)\n    batch_loss = loss_object(y, predictions)\n    test_loss.update_state(batch_loss)\n    test_accuracy.update_state(y, predictions)\n  logging.info('Test loss: %f', test_loss.result().numpy())\n  logging.info('Test accuracy: %f', test_accuracy.result().numpy())\n  return test_loss.result(), test_accuracy.result()\n\n\ndef 
main(unused_argv):\n  tf.random.set_seed(FLAGS.seed)\n  init_timer = timer.Timer()\n  init_timer.Start()\n\n  if FLAGS.mode == 'hessian':\n    # Load default values from the original experiment.\n    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,\n                                            'operative_config.gin')\n\n  # Maybe preload a gin config.\n  if FLAGS.preload_gin_config:\n    config_path = FLAGS.preload_gin_config\n    gin.parse_config_file(config_path)\n    logging.info('Gin configuration pre-loaded from: %s', config_path)\n\n  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)\n  ds_train, ds_test, info = utils.get_dataset()\n  input_shape = info.features['image'].shape\n  num_classes = info.features['label'].num_classes\n  logging.info('Input Shape: %s', input_shape)\n  logging.info('train samples: %s', info.splits['train'].num_examples)\n  logging.info('test samples: %s', info.splits['test'].num_examples)\n\n  pruning_params = utils.get_pruning_params()\n  model = utils.get_network(pruning_params, input_shape, num_classes)\n  model.summary(print_fn=logging.info)\n  if FLAGS.mode == 'train_eval':\n    train_model(model, ds_train, ds_test, FLAGS.logdir)\n  elif FLAGS.mode == 'hessian':\n    test_model(model, ds_test)\n    hessian(model, ds_train, FLAGS.logdir)\n  logging.info('Total runtime: %.3f s', init_timer.GetDuration())\n\n  logconfigfile_path = os.path.join(\n      FLAGS.logdir,\n      ('hessian_' if FLAGS.mode == 'hessian' else '') + 'operative_config.gin')\n  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:\n    f.write('# Gin-Config:\\n %s' % gin.config.operative_config_str())\n\n\nif __name__ == '__main__':\n  tf.enable_v2_behavior()\n  app.run(main)\n"
  },
  {
    "path": "rigl/rigl_tf2/utils.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Utilities for training.\n\"\"\"\nimport functools\nfrom typing import Optional, Tuple\n\nfrom absl import flags\nfrom absl import logging\nimport gin\nfrom rigl.rigl_tf2 import init_utils\nfrom rigl.rigl_tf2 import networks\nimport tensorflow.compat.v2 as tf\nimport tensorflow_datasets as tfds\nfrom tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule\nfrom tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper\n\nFLAGS = flags.FLAGS\nPRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude\nPRUNED_LAYER_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.Dense)\n\n\n@gin.configurable('data')\ndef get_dataset():\n  \"\"\"Loads the MNIST dataset.\"\"\"\n  # The data, shuffled and split between train and test sets.\n  datasets, info = tfds.load('mnist', with_info=True)\n  ds_train, ds_test = datasets['train'].cache(), datasets['test'].cache()\n\n  preprocess_fn = lambda x: (tf.cast(x['image'], tf.float32) / 255., x['label'])\n  ds_train = ds_train.map(preprocess_fn)\n  ds_test = ds_test.map(preprocess_fn)\n  return ds_train, ds_test, info\n\n\n@gin.configurable('pruning')\ndef get_pruning_params(mode='prune',\n                       initial_sparsity=0.0,\n                       final_sparsity=0.8,\n                       begin_step=2000,\n                
       end_step=4000,\n                       frequency=200):\n  \"\"\"Gets pruning hyper-parameters.\"\"\"\n  p_params = {}\n  if mode == 'prune':\n    p_params['pruning_schedule'] = pruning_schedule.PolynomialDecay(\n        initial_sparsity=initial_sparsity,\n        final_sparsity=final_sparsity,\n        begin_step=begin_step,\n        end_step=end_step,\n        frequency=frequency)\n  elif mode == 'constant':\n    p_params['pruning_schedule'] = pruning_schedule.ConstantSparsity(\n        target_sparsity=final_sparsity, begin_step=begin_step)\n  else:\n    raise ValueError('Mode: %s, is not valid' % mode)\n  return p_params\n\n\n# Forked from tensorflow_model_optimization/python/core/sparsity/keras/prune.py\ndef maybe_prune_layer(layer, params, filter_fn):\n  if filter_fn(layer):\n    return PRUNING_WRAPPER(layer, **params)\n  return layer\n\n\n@gin.configurable('network')\ndef get_network(\n    pruning_params,\n    input_shape,\n    num_classes,\n    activation = 'relu',\n    network_name = 'lenet5',\n    mask_init_path = None,\n    shuffle_mask = False,\n    weight_init_path = None,\n    weight_init_method = None,\n    weight_decay = 0.,\n    noise_stddev = 0.,\n    pruned_layer_types = PRUNED_LAYER_TYPES):\n  \"\"\"Creates the network.\"\"\"\n  kernel_regularizer = (\n      tf.keras.regularizers.l2(weight_decay) if (weight_decay > 0) else None)\n  # (1) Create keras model.\n  model = getattr(networks, network_name)(\n      input_shape, num_classes, activation=activation,\n      kernel_regularizer=kernel_regularizer)\n  model.summary(print_fn=logging.info)\n  # (2) Adding wrappers. i.e. 
sparsify if conv or dense.\n  filter_fn = lambda layer: isinstance(layer, pruned_layer_types)\n  clone_fn = functools.partial(maybe_prune_layer,\n                               params=pruning_params,\n                               filter_fn=filter_fn)\n  model = tf.keras.models.clone_model(model, clone_function=clone_fn)\n\n  # (3) Update parameters of the model as necessary.\n  if mask_init_path:\n    logging.info('Loading masks from: %s', mask_init_path)\n    mask_init_model = tf.keras.models.clone_model(model)\n    ckpt = tf.train.Checkpoint(model=mask_init_model)\n    ckpt.restore(mask_init_path)\n    for l_source, l_target in zip(mask_init_model.layers, model.layers):\n      if isinstance(l_source, PRUNING_WRAPPER):\n        # l.pruning_vars[0][1] is the mask.\n        mask = l_target.pruning_vars[0][1]\n        n_active = tf.reduce_sum(mask)\n        n_dense = tf.cast(tf.size(mask), dtype=n_active.dtype)\n        logging.info('Before: %s, %.2f', l_target.name,\n                     (n_active / n_dense).numpy())\n        loaded_mask = l_source.pruning_vars[0][1]\n        if shuffle_mask:\n          # tf shuffle shuffles along the first dim, so we need to flatten.\n          loaded_mask = tf.reshape(\n              tf.random.shuffle(tf.reshape(loaded_mask, -1)), loaded_mask.shape)\n        mask.assign(loaded_mask)\n        n_active = tf.reduce_sum(mask)\n        n_dense = tf.cast(tf.size(mask), dtype=n_active.dtype)\n        logging.info('After: %s, %.2f', l_target.name,\n                     (n_active / n_dense).numpy())\n    del mask_init_model\n  if weight_init_path:\n    logging.info('Loading weights from: %s', weight_init_path)\n    weight_init_model = tf.keras.models.clone_model(model)\n    ckpt = tf.train.Checkpoint(model=weight_init_model)\n    ckpt.restore(weight_init_path)\n    for l_source, l_target in zip(weight_init_model.layers, model.layers):\n      for var_source, var_target in zip(l_source.trainable_variables,\n                                 
       l_target.trainable_variables):\n        var_target.assign(var_source)\n        logging.info('Weight %s loaded from ckpt.', var_target.name)\n    del weight_init_model\n  elif weight_init_method == 'unit_scaled':\n    logging.info('Using unit_scaled initialization.')\n    for layer in model.layers:\n      if isinstance(layer, PRUNING_WRAPPER):\n        # TODO following the outcome of b/148083099, update following.\n        # Add the weight, mask and the valid dimensions.\n        weight = layer.weights[0]\n        mask = layer.weights[2]\n        new_init = init_utils.unit_scaled_init(mask)\n        weight.assign(new_init)\n        logging.info('Weight %s updated init.', weight.name)\n  elif weight_init_method == 'layer_scaled':\n    logging.info('Using layer_scaled initialization.')\n    for layer in model.layers:\n      if isinstance(layer, PRUNING_WRAPPER):\n        # TODO following the outcome of b/148083099, update following.\n        # Add the weight, mask and the valid dimensions.\n        weight = layer.weights[0]\n        mask = layer.weights[2]\n        new_init = init_utils.layer_scaled_init(mask)\n        weight.assign(new_init)\n        logging.info('Weight %s updated init.', weight.name)\n  if noise_stddev > 0.:\n    logging.info('Adding noise to the initial point')\n    for layer in model.layers:\n      for var in layer.trainable_variables:\n        noise = tf.random.normal(var.shape, mean=0, stddev=noise_stddev)\n        var.assign_add(noise)\n  # Do this call to mask the weights with existing masks if it is not done\n  # already. 
This is needed, for example, when we use initial\n  # parameters to calculate distance.\n  model(tf.expand_dims(tf.ones(input_shape), 0))\n  return model\n\n\n@gin.configurable('optimizer', denylist=['total_steps'])\ndef get_optimizer(total_steps,\n                  name='adam',\n                  learning_rate=0.001,\n                  clipnorm=None,\n                  clipvalue=None,\n                  momentum=None):\n  \"\"\"Creates the optimizer according to the arguments.\"\"\"\n  name = name.lower()\n  # We use cosine decay.\n  lr_decayed_fn = tf.keras.experimental.CosineDecay(learning_rate, total_steps)\n  kwargs = {}\n  if clipnorm:\n    # Not a correct implementation; see http://b/152868229 .\n    kwargs['clipnorm'] = clipnorm\n  if clipvalue:\n    kwargs['clipvalue'] = clipvalue\n  if name == 'adam':\n    return tf.keras.optimizers.Adam(lr_decayed_fn, **kwargs)\n  if name == 'momentum':\n    return tf.keras.optimizers.SGD(lr_decayed_fn, momentum=momentum, **kwargs)\n  if name == 'sgd':\n    return tf.keras.optimizers.SGD(lr_decayed_fn, **kwargs)\n  if name == 'rmsprop':\n    return tf.keras.optimizers.RMSprop(\n        lr_decayed_fn, momentum=momentum, **kwargs)\n  raise NotImplementedError(f'Optimizer {name} is not implemented.')\n"
  },
  {
    "path": "rigl/rl/README.md",
    "content": "# The State of Sparse Training in Deep Reinforcement Learning\n[**Paper**] [goo.gle/sparserl-paper](https://goo.gle/sparserl-paper)\n[**Video**] [goo.gle/sparserl-video](https://goo.gle/sparserl-video)\n\nThis code requires TensorFlow 2, so it uses a separate requirements file.\nPlease follow the instructions below.\n\nFirst, clone this repo.\n```bash\ngit clone https://github.com/google-research/rigl.git\ncd rigl\n```\n\nWe use [NeurIPS 2019 MicroNet Challenge](https://micronet-challenge.github.io/)\ncode for counting the operations and size of our networks. Clone the\ngoogle_research repo and add the current folder to the Python path.\n```bash\ngit clone https://github.com/google-research/google-research.git\nmv google-research/ google_research/\nexport PYTHONPATH=$PYTHONPATH:$PWD\n```\n\nNow we can run some tests. The following script creates a virtual environment,\ninstalls the necessary libraries, and runs a few tests.\n```bash\nvirtualenv -p python3 env_sparserl\nsource env_sparserl/bin/activate\n\npip install -r rigl/rl/requirements.txt\npython -m rigl.sparse_utils_test\n```\n\nFollow the instructions here to install MuJoCo: https://github.com/openai/mujoco-py#install-mujoco\n\nTo run PPO:\n\n```bash\npython3 rigl/rl/tfagents/ppo_train_eval.py \\\n--gin_file=rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin \\\n--root_dir=/tmp/sparserl/ --is_mujoco=True\n```\n\nTo run SAC:\n\n```bash\npython3 rigl/rl/tfagents/sac_train_eval.py \\\n--gin_file=rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin \\\n--root_dir=/tmp/sparserl/ --is_mujoco=True\n```\n\n**Citation**:\n```\n@InProceedings{graesser22a,\n  title = \t {The State of Sparse Training in Deep Reinforcement Learning},\n  author =       {Graesser, Laura and Evci, Utku and Elsen, Erich and Castro, Pablo Samuel},\n  booktitle = \t {Proceedings of the 39th International Conference on Machine Learning},\n  pages = \t {7766--7792},\n  year = \t {2022},\n  editor = \t {Chaudhuri, 
Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},\n  volume = \t {162},\n  series = \t {Proceedings of Machine Learning Research},\n  month = \t {17--23 Jul},\n  publisher =    {PMLR},\n  pdf = \t {https://proceedings.mlr.press/v162/graesser22a/graesser22a.pdf},\n  url = \t {https://proceedings.mlr.press/v162/graesser22a.html},\n}\n```\n"
  },
  {
    "path": "rigl/rl/dqn_agents.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Variants of DQN with sparsity.\"\"\"\n\nimport functools\nimport math\nfrom absl import logging\nfrom dopamine.agents.dqn import dqn_agent\nfrom dopamine.discrete_domains import atari_lib\nimport gin\nfrom rigl.rl import sparse_utils\nimport tensorflow as tf\nimport tensorflow.compat.v1 as tf1\n\n\n# one of ('dense', 'prune', 'rigl', 'static', 'set'). If 'dense' no modification\n# done. 
If 'prune', the agent is pruned after training.\n# If ('rigl', 'static', 'set') the corresponding sparse-to-sparse training\n# algorithm is used.\nLEARNER_MODES = ('dense', 'prune', 'rigl', 'static', 'set')\n\n\ndef flatten_list_of_vars(var_list):\n  flat_vars = [tf.reshape(v, [-1]) for v in var_list]\n  return tf.concat(flat_vars, axis=-1)\n\n\ndef _get_bn_layer_name(block_id, i):\n  return f'batch_norm_{block_id},{i}'\n\n\ndef _get_conv_layer_name(block_id, i):\n  return f'conv_{block_id},{i}'\n\n\nclass _Stack(tf.keras.Model):\n  \"\"\"Stack of pooling and convolutional blocks with residual connections.\n  \"\"\"\n\n  def __init__(self,\n               num_ch,\n               num_blocks,\n               use_max_pooling=True,\n               use_batch_norm=False,\n               name='stack'):\n    super(_Stack, self).__init__(name=name)\n    self._conv = tf.keras.layers.Conv2D(num_ch, 3, strides=1, padding='same')\n    self.use_max_pooling = use_max_pooling\n    self.use_batch_norm = use_batch_norm\n    self.num_blocks = num_blocks\n    if self.use_batch_norm:\n      self._batch_norm = tf.keras.layers.BatchNormalization()\n    if self.use_max_pooling:\n      self._max_pool = tf.keras.layers.MaxPool2D(\n          pool_size=3, padding='same', strides=2)\n    for block_id in range(num_blocks):\n      for i in range(2):\n        name = _get_conv_layer_name(block_id, i)\n        layer = tf.keras.layers.Conv2D(\n            num_ch, 3, strides=1, padding='same',\n            name=f'res_{block_id}/conv2d_{i}')\n        setattr(self, name, layer)\n        if self.use_batch_norm:\n          name = _get_bn_layer_name(block_id, i)\n          setattr(self, name, tf.keras.layers.BatchNormalization())\n\n  def call(self, conv_out, training=False):\n    # Downscale.\n    conv_out = self._conv(conv_out)\n    if self.use_max_pooling:\n      conv_out = self._max_pool(conv_out)\n    if self.use_batch_norm:\n      conv_out = self._batch_norm(conv_out, training=training)\n\n    # 
Residual block(s).\n    for block_id in range(self.num_blocks):\n      block_input = conv_out\n      for i in range(2):\n        conv_out = tf.nn.relu(conv_out)\n        conv_layer = getattr(self, _get_conv_layer_name(block_id, i))\n        conv_out = conv_layer(conv_out)\n        if self.use_batch_norm:\n          bn_layer = getattr(self, _get_bn_layer_name(block_id, i))\n          conv_out = bn_layer(conv_out, training=training)\n      conv_out += block_input\n    return conv_out\n\n\n@gin.configurable\nclass ImpalaNetwork(tf.keras.Model):\n  \"\"\"Agent with ResNet, but without LSTM and additional inputs.\n\n  The deep model used for DQN which follows\n  \"IMPALA: Scalable Distributed Deep-RL with Importance Weighted\n  Actor-Learner Architectures\" by Espeholt, Soyer, Munos et al.\n\n  Original implementation by Rishabh Agarwal, with minor modifications as\n  follows:\n  * rename nn_scale to width to fit with the sparserl API\n  * allow for non-integer widths.\n  * add training mode.\n  * removed the option to have multiple heads.\n  * modified the call function to return a compatible type.\n  * added custom logic for sparse training.\n  \"\"\"\n\n  def __init__(self,\n               num_actions,\n               width=1.0,\n               mode='dense',\n               name='impala_deep_network',\n               prune_allow_key='',\n               use_batch_norm=False):\n    super().__init__(name=name)\n    self._width = width\n    self._mode = mode\n\n    def _scale_width(n):\n      return int(math.ceil(n * width))\n\n    self.num_actions = num_actions\n    self.use_batch_norm = use_batch_norm\n    logging.info('Using batch norm in %s: %s', name, use_batch_norm)\n    stack_fn = functools.partial(_Stack, use_batch_norm=use_batch_norm)\n    # Parameters and layers for _torso.\n    self._stacks = [\n        stack_fn(_scale_width(32), 2, name='stack1'),\n        stack_fn(_scale_width(64), 2, name='stack2'),\n        stack_fn(_scale_width(64), 2, name='stack3'),\n   
 ]\n    self._dense1 = tf.keras.layers.Dense(_scale_width(256))\n    self._dense2 = tf.keras.layers.Dense(\n        self.num_actions, name='policy_logits')\n\n    layer_shape_dict = {\n        '_dense1': (7744, 512),\n        '_dense2': (512, self.num_actions),\n    }\n    def add_stack_shapes(name, in_width, out_width):\n      # First conv\n      layer_shape_dict[f'{name}/_conv'] = (3, 3, in_width, out_width)\n      for i in range(2):\n        for j in range(2):\n          l_name = _get_conv_layer_name(i, j)\n          layer_shape_dict[f'{name}/{l_name}'] = (3, 3, out_width, out_width)\n    add_stack_shapes('stack0', 4, _scale_width(32))\n    add_stack_shapes('stack1', _scale_width(32), _scale_width(64))\n    add_stack_shapes('stack2', _scale_width(64), _scale_width(64))\n\n    if mode != 'dense':\n      custom_sparsities = sparse_utils.get_pruning_sparsities(layer_shape_dict)\n      for l_name, sparsity in custom_sparsities.items():\n        logging.info('pruning, layer: %s, sparsity: %.4f', l_name, sparsity)\n        if l_name.startswith('stack'):\n          # stack1 -> 1\n          stack_id = int(l_name[len('stack')])\n          c_module = self._stacks[stack_id]\n          # `stack1/_conv` -> `_conv`\n          l_name = l_name.split('/')[1]\n        else:\n          c_module = self\n        if mode == 'prune':\n          if prune_allow_key and (prune_allow_key not in l_name):\n            sparsity = 0\n            logging.info('%s not pruned since, prune_allow_key: %s', l_name,\n                         prune_allow_key)\n          wrapped_layer = sparse_utils.maybe_prune_layer(\n              getattr(c_module, l_name),\n              params=sparse_utils.get_pruning_params(\n                  mode, final_sparsity=sparsity))\n        else:\n          wrapped_layer = sparse_utils.maybe_prune_layer(\n              getattr(c_module, l_name),\n              params=sparse_utils.get_pruning_params(mode))\n        setattr(c_module, l_name, wrapped_layer)\n\n  def 
get_features(self, state, training=True):\n    x = tf.cast(state, tf.float32)\n    x /= 255\n    conv_out = x\n    for stack in self._stacks:\n      conv_out = stack(conv_out, training=training)\n\n    conv_out = tf.nn.relu(conv_out)\n    conv_out = tf.keras.layers.Flatten()(conv_out)\n\n    out = self._dense1(conv_out)\n    out = tf.nn.relu(out)\n    out = self._dense2(out)\n    return out\n\n  def call(self, state, training=True):\n    out = self.get_features(state, training=training)\n    return atari_lib.DQNNetworkType(out)\n\n\n@gin.configurable\nclass NatureDQNNetwork(tf.keras.Model):\n  \"\"\"The convolutional network used to compute the agent's Q-values.\"\"\"\n\n  def __init__(self, num_actions, width=1, mode='dense', name=None):\n    \"\"\"Creates the layers used for calculating Q-values.\n\n    Args:\n      num_actions: int, number of actions.\n      width: float, Scales the width of the network uniformly.\n      mode: str, one of LEARNER_MODES.\n      name: str, used to create scope for network parameters.\n    \"\"\"\n    super().__init__(name=name)\n    self.num_actions = num_actions\n    self._width = width\n    self._mode = mode\n\n    def _scale_width(n):\n      return int(math.ceil(n * width))\n    # Defining layers.\n    activation_fn = tf.keras.activations.relu\n    # Setting names of the layers manually to make variable names more similar\n    # with tf.slim variable names/checkpoints.\n    self.conv1 = tf.keras.layers.Conv2D(\n        _scale_width(32), [8, 8],\n        strides=4,\n        padding='same',\n        activation=activation_fn,\n        name='Conv')\n    self.conv2 = tf.keras.layers.Conv2D(\n        _scale_width(64), [4, 4],\n        strides=2,\n        padding='same',\n        activation=activation_fn,\n        name='Conv')\n    self.conv3 = tf.keras.layers.Conv2D(\n        _scale_width(64), [3, 3],\n        strides=1,\n        padding='same',\n        activation=activation_fn,\n        name='Conv')\n    self.flatten = 
tf.keras.layers.Flatten()\n    self.dense1 = tf.keras.layers.Dense(\n        _scale_width(512), activation=activation_fn,\n        name='fully_connected')\n    self.dense2 = tf.keras.layers.Dense(num_actions, name='fully_connected')\n\n    layer_shape_dict = {\n        'conv1': (_scale_width(32), 8, 8, 4),\n        'conv2': (_scale_width(64), 4, 4, _scale_width(32)),\n        'conv3': (_scale_width(64), 3, 3, _scale_width(64)),\n        'dense1': (7744, _scale_width(512)),\n        'dense2': (_scale_width(512), num_actions)\n    }\n    if mode == 'dense':\n      pass\n    elif mode == 'prune':\n      custom_sparsities = sparse_utils.get_pruning_sparsities(layer_shape_dict)\n      for l_name, sparsity in custom_sparsities.items():\n        logging.info('pruning, layer: %s, sparsity: %.4f', l_name, sparsity)\n        wrapped_layer = sparse_utils.maybe_prune_layer(\n            getattr(self, l_name),\n            params=sparse_utils.get_pruning_params(\n                mode, final_sparsity=sparsity))\n        setattr(self, l_name, wrapped_layer)\n    else:\n      # static, rigl, set.\n      for l_name in layer_shape_dict:\n        wrapped_layer = sparse_utils.maybe_prune_layer(\n            getattr(self, l_name),\n            params=sparse_utils.get_pruning_params(mode))\n        setattr(self, l_name, wrapped_layer)\n\n  def call(self, state):\n    \"\"\"Creates the output tensor/op given the state tensor as input.\n\n    See https://www.tensorflow.org/api_docs/python/tf/keras/Model for more\n    information on this. 
Note that tf.keras.Model implements `call` which is\n    wrapped by `__call__` function by tf.keras.Model.\n\n    Parameters created here will have scope according to the `name` argument\n    given at `.__init__()` call.\n    Args:\n      state: Tensor, input tensor.\n    Returns:\n      collections.namedtuple, output ops (graph mode) or output tensors (eager).\n    \"\"\"\n    x = tf.cast(state, tf.float32)\n    x = x / 255\n    x = self.conv1(x)\n    x = self.conv2(x)\n    x = self.conv3(x)\n    x = self.flatten(x)\n    x = self.dense1(x)\n    return atari_lib.DQNNetworkType(self.dense2(x))\n\n\n@gin.configurable\nclass SparseDQNAgent(dqn_agent.DQNAgent):\n  \"\"\"A variant of DQN that is trained with sparse backbones.\"\"\"\n\n  def __init__(self,\n               sess,\n               num_actions,\n               mode='dense',\n               weight_decay=0.,\n               summary_writer=None):\n    \"\"\"Initializes the agent and constructs graph components.\n\n    Args:\n      sess: tf.Session, for executing ops.\n      num_actions: int, number of actions the agent can take at any state.\n      mode: str, one of LEARNER_MODES.\n      weight_decay: float, used to regularize online_convnet.\n      summary_writer: tf.SummaryWriter, for Tensorboard.\n    \"\"\"\n    self._weight_decay = weight_decay\n    if mode in LEARNER_MODES:\n      self._mode = mode\n    else:\n      raise ValueError(f'mode:{mode} not one of {LEARNER_MODES}')\n    self._global_step = tf1.train.get_or_create_global_step()\n    # update_period=1, we always update as the supervisor is fixed.\n    super().__init__(\n        sess, num_actions, summary_writer=summary_writer)\n\n  def _create_network(self, name):\n    network = self.network(\n        self.num_actions,\n        name=name + 'learner',\n        mode=self._mode)\n    return network\n\n  def _set_additional_ops(self):\n    if self._mode == 'dense':\n      self.step_update_op = tf.no_op()\n      self.mask_update_op = tf.no_op()\n      
self.mask_init_op = tf.no_op()\n    elif self._mode in ['rigl', 'set', 'static']:\n      self.step_update_op = sparse_utils.update_prune_step(\n          self.online_convnet, self._global_step)\n      # This ensures sparse masks are applied before each run.\n      self.mask_update_op = sparse_utils.update_prune_masks(self.online_convnet)\n      self.mask_init_op = sparse_utils.init_masks(self.online_convnet)\n      # Wrap the optimizer.\n      if self._mode == 'rigl':\n        self.optimizer = sparse_utils.UpdatedRigLOptimizer(self.optimizer)\n        self.optimizer.set_model(self.online_convnet)\n      elif self._mode == 'set':\n        self.optimizer = sparse_utils.UpdatedSETOptimizer(self.optimizer)\n        self.optimizer.set_model(self.online_convnet)\n    elif self._mode == 'prune':\n      self.step_update_op = sparse_utils.update_prune_step(\n          self.online_convnet, self._global_step)\n      self.mask_update_op = sparse_utils.update_prune_masks(self.online_convnet)\n      self.mask_init_op = tf.no_op()\n    else:\n      raise ValueError(f'Invalid mode: {self._mode}')\n\n  def _build_train_op(self):\n    \"\"\"Builds a training op.\n\n    Returns:\n      train_op: An op performing one step of training from replay data.\n    \"\"\"\n    replay_action_one_hot = tf.one_hot(\n        self._replay.actions, self.num_actions, 1., 0., name='action_one_hot')\n    replay_chosen_q = tf.reduce_sum(\n        self._replay_net_outputs.q_values * replay_action_one_hot,\n        axis=1,\n        name='replay_chosen_q')\n\n    target = tf.stop_gradient(self._build_target_q_op())\n    loss = tf1.losses.huber_loss(\n        target, replay_chosen_q, reduction=tf.losses.Reduction.NONE)\n    loss = tf.reduce_mean(loss)\n    if self.summary_writer is not None:\n      tf1.summary.scalar('Losses/HuberLoss', loss)\n\n    reg_loss = 0.\n    if self._weight_decay:\n      for v in self.online_convnet.trainable_variables:\n        if 'bias' not in v.name:\n          reg_loss += 
tf.nn.l2_loss(v) * self._weight_decay\n      loss += reg_loss\n      tf1.summary.scalar('Losses/RegLoss', reg_loss)\n    tf1.summary.scalar('Losses/TotalLoss', loss)\n    sparse_utils.log_sparsities(self.online_convnet)\n    self._set_additional_ops()\n    grads_and_vars = self.optimizer.compute_gradients(loss)\n    train_op = self.optimizer.apply_gradients(\n        grads_and_vars, global_step=self._global_step)\n    self._create_summary_ops(grads_and_vars)\n    return train_op\n\n  def _create_summary_ops(self, grads_and_vars):\n    with tf1.variable_scope('Norm'):\n      all_norm = tf.norm(\n          flatten_list_of_vars(self.online_convnet.trainable_variables))\n      tf1.summary.scalar('online_convnet/weights_norm', all_norm)\n      all_norm = tf.norm(\n          flatten_list_of_vars(self.target_convnet.trainable_variables))\n      tf1.summary.scalar('target_convnet/weights_norm', all_norm)\n      all_grad_norm = tf.norm(\n          flatten_list_of_vars([\n              g for g, v in grads_and_vars\n              if v in self.online_convnet.trainable_variables\n          ]))\n      tf1.summary.scalar('online_convnet/grad_norm', all_grad_norm)\n\n    total_params, nparam_dict = sparse_utils.get_total_params(\n        self.online_convnet)\n    tf1.summary.scalar('params/total', total_params)\n    for k, val in nparam_dict.items():\n      tf1.summary.scalar('params/' + k, val)\n\n    if self._mode == 'rigl':\n      tf1.summary.scalar('drop_fraction', self.optimizer.drop_fraction)\n\n  def update_prune_step(self):\n    self._sess.run(self.step_update_op)\n\n  def maybe_update_and_apply_masks(self):\n    self._sess.run(self.mask_update_op)\n\n  def maybe_init_masks(self):\n    # If `dense`; no initialization.\n    self._sess.run(self.mask_init_op)\n\n  def _train_step(self):\n    if self._replay.memory.add_count > self.min_replay_history:\n      if self.training_steps % self.update_period == 0:\n        self.update_prune_step()\n        
self.maybe_update_and_apply_masks()\n        self._sess.run(self._train_op)\n        c_step = self._sess.run(self._global_step)\n        if (self.summary_writer is not None and\n            self._merged_summaries is not None and\n            c_step % self.summary_writing_frequency == 0):\n          summary = self._sess.run(self._merged_summaries)\n          self.summary_writer.add_summary(summary, c_step)\n      if self.training_steps % self.target_update_period == 0:\n        # Mask weights before syncing.\n        self.maybe_update_and_apply_masks()\n        self._sess.run(self._sync_qt_ops)\n\n    self.training_steps += 1\n\n  def _build_sync_op(self):\n    \"\"\"Builds ops for assigning weights from online to target network.\n\n    Returns:\n      ops: A list of ops assigning weights from online to target network.\n    \"\"\"\n    # Get trainable variables from online and target DQNs.\n    sync_qt_ops = []\n    online_vars = sparse_utils.get_all_variables_and_masks(self.online_convnet)\n    target_vars = sparse_utils.get_all_variables_and_masks(self.target_convnet)\n    for (v_online, v_target) in zip(online_vars, target_vars):\n      # Assign weights from online to target network.\n      sync_qt_ops.append(v_target.assign(v_online, use_locking=True))\n    return sync_qt_ops\n\n  def _build_networks(self):\n    \"\"\"Builds the Q-value network computations needed for acting and training.\n\n    Same as the `super` class except that training=True flags are passed.\n    These are:\n      self.online_convnet: For computing the current state's Q-values.\n      self.target_convnet: For computing the next state's target Q-values.\n      self._net_outputs: The actual Q-values.\n      self._q_argmax: The action maximizing the current state's Q-values.\n      self._replay_net_outputs: The replayed states' Q-values.\n      self._replay_next_target_net_outputs: The replayed next states' target\n        Q-values (see Mnih et al., 2015 for details).\n    \"\"\"\n    
self.online_convnet = self._create_network(name='Online')\n    self.target_convnet = self._create_network(name='Target')\n    self._net_outputs = self.online_convnet(self.state_ph, training=True)\n    self._q_argmax = tf.argmax(self._net_outputs.q_values, axis=1)[0]\n    self._replay_net_outputs = self.online_convnet(self._replay.states,\n                                                   training=True)\n    self._replay_next_target_net_outputs = self.target_convnet(\n        self._replay.next_states)\n"
  },
  {
    "path": "rigl/rl/requirements.txt",
    "content": "absl-py>=0.6.0\ndopamine-rl==4.0.5\ngin-config\nmujoco-py<2.2,>=2.1\nnumpy>=1.15.4\nsix>=1.12.0\ntensorflow==2.9.1  # change to 'tensorflow-gpu' for gpu support\ntensorflow-datasets==2.1\ntensorflow-model-optimization==0.7.2\ntf-agents[reverb]==0.13.0\n"
  },
  {
    "path": "rigl/rl/run.sh",
    "content": "#!/bin/bash\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nset -e\nset -x\n\nvirtualenv -p python3 .\nsource ./bin/activate\n\npip install tensorflow\npip install -r rigl/rl/requirements.txt\npython -m rigl.rl.tfagents.sac_train_eval \\\n  --gin_file=rigl/rl/tfagents/configs/sac_mujoco_sparse_config.gin\n"
  },
  {
    "path": "rigl/rl/run_experiment.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Run policy evaluation as supervised learning, reloading representations.\"\"\"\n\nimport sys\n\nfrom absl import logging\nfrom dopamine.discrete_domains import gym_lib\nfrom dopamine.discrete_domains import run_experiment\nimport gin\nimport numpy as np\nfrom rigl.rl import dqn_agents\nimport tensorflow.compat.v1 as tf1\n\n# Last 10% of the training is averaged to get final reward.\nAVG_REWARD_FRAC = 0.1\n\n\n@gin.configurable\ndef create_sparse_agent(sess, num_actions, agent=None, summary_writer=None):\n  \"\"\"Creates a sparse agent.\n\n  Args:\n    sess: tf.Session.\n    num_actions: int, number of actions.\n    agent: str, type of learner/actor agent to create.\n    summary_writer: tf.SummaryWriter, for Tensorboard.\n\n  Returns:\n    A learner/actor agent.\n  \"\"\"\n  assert agent is not None\n  if agent == 'dqn':\n    return dqn_agents.SparseDQNAgent(\n        sess, num_actions, summary_writer=summary_writer)\n  else:\n    raise ValueError('Unknown learner agent: {}'.format(agent))\n\n\n@gin.configurable\nclass SparseTrainRunner(run_experiment.Runner):\n  \"\"\"Policy evaluation as supervised learning, from a loaded representation.\"\"\"\n\n  def __init__(self,\n               base_dir,\n               agent_type,\n               checkpoint_file_prefix='ckpt',\n               logging_file_prefix='log',\n               log_every_n=1,\n   
            num_iterations=200,\n               training_steps=250000,\n               evaluation_steps=125000,\n               max_steps_per_episode=27000,\n               load_env_fn=gym_lib.create_gym_environment,\n               clip_rewards=True,\n               atari_100k_eval=False,\n               num_eval_episodes=100,\n               observation_noise=None):\n    \"\"\"Initialize SparseTrainRunner in charge of running the experiment.\n\n    Args:\n      base_dir: str, the base directory to host all required sub-directories.\n      agent_type: str, defines the type of targets to be learned. Can be one of\n        {'dqn', 'rainbow'}.\n      checkpoint_file_prefix: str, the prefix to use for checkpoint files.\n      logging_file_prefix: str, prefix to use for the log files.\n      log_every_n: int, the frequency for writing logs.\n      num_iterations: int, the iteration number threshold (must be greater than\n        start_iteration).\n      training_steps: int, the number of training steps to perform.\n      evaluation_steps: int, the number of evaluation steps to perform.\n      max_steps_per_episode: int, maximum number of steps after which an episode\n        terminates.\n      load_env_fn: fn, function which loads and returns an environment.\n      clip_rewards: bool, whether to clip rewards in [-1, 1].\n      atari_100k_eval: bool, whether we are using the eval for Atari 100K.\n      num_eval_episodes: int, the number of full episodes to run during eval,\n        only used if atari_100k_eval is True.\n      observation_noise: float (optional), the stddev to use to add noise to the\n        observations before sending to the agent.\n    \"\"\"\n    self._logging_file_prefix = logging_file_prefix\n    self._log_every_n = log_every_n\n    self._num_iterations = num_iterations\n    self._training_steps = training_steps\n    self._evaluation_steps = evaluation_steps\n    self._max_steps_per_episode = max_steps_per_episode\n    self._clip_rewards = 
clip_rewards\n    self._atari_100k_eval = atari_100k_eval\n    self._num_eval_episodes = num_eval_episodes\n    self._base_dir = base_dir\n    self._create_directories()\n    self._summary_writer = tf1.summary.FileWriter(self._base_dir)\n    self._observation_noise = observation_noise\n\n    self._environment = load_env_fn()\n\n    num_actions = self._environment.action_space.n\n\n    config = tf1.ConfigProto(allow_soft_placement=True)\n    # Allocate only subset of the GPU memory as needed which allows for running\n    # multiple agents/workers on the same GPU.\n    config.gpu_options.allow_growth = True\n    # Set up a session and initialize variables.\n    self._sess = tf1.Session('local', config=config)\n    self._agent = create_sparse_agent(\n        self._sess, num_actions, agent=agent_type,\n        summary_writer=self._summary_writer)\n    self._summary_writer.add_graph(graph=tf1.get_default_graph())\n    self._sess.run(tf1.global_variables_initializer())\n\n    self._initialize_checkpointer_and_maybe_resume(checkpoint_file_prefix)\n\n  def _run_one_phase_fix_episodes(self, max_episodes, statistics):\n    \"\"\"Run one eval phase for the Atari 100k benchmark.\n\n    As opposed to the standard eval phase which runs for a fixed number of\n    steps, this will run for a fixed number of episodes, producing less noisy\n    results.\n\n    Args:\n      max_episodes: int, max number of episodes to run.\n      statistics: `IterationStatistics` object which records the experimental\n        results.\n\n    Returns:\n      Tuple containing the number of steps taken in this phase (int), the sum of\n        returns (float), and the number of episodes performed (int).\n    \"\"\"\n    step_count = 0\n    num_episodes = 0\n    sum_returns = 0.\n\n    while num_episodes < max_episodes:\n      episode_length, episode_return = self._run_one_episode()\n      statistics.append({\n          'eval_episode_lengths': episode_length,\n          'eval_episode_returns': 
episode_return\n      })\n      step_count += episode_length\n      sum_returns += episode_return\n      num_episodes += 1\n      # We use sys.stdout.write instead of logging so as to flush frequently\n      # without generating a line break.\n      sys.stdout.write('Steps executed: {} '.format(step_count) +\n                       'Episode length: {} '.format(episode_length) +\n                       'Num episodes: {} '.format(num_episodes) +\n                       'Return: {}\\r'.format(episode_return))\n      sys.stdout.flush()\n    return step_count, sum_returns, num_episodes\n\n  def _run_eval_phase(self, statistics):\n    if not self._atari_100k_eval:\n      return super()._run_eval_phase(statistics)\n    self._agent.eval_mode = True\n    _, sum_returns, num_episodes = self._run_one_phase_fix_episodes(\n        self._num_eval_episodes, statistics)\n    average_return = sum_returns / num_episodes if num_episodes > 0 else 0.0\n    logging.info('Average undiscounted return per evaluation episode: %.2f',\n                 average_return)\n    statistics.append({'eval_average_return': average_return})\n    return num_episodes, average_return\n\n  def _run_one_step(self, action):\n    \"\"\"Maybe adds noise to observations.\"\"\"\n    observation, reward, is_terminal, _ = self._environment.step(action)\n    if self._observation_noise is not None:\n      observation += np.random.normal(\n          scale=self._observation_noise,\n          size=observation.shape).astype(observation.dtype)\n    return observation, reward, is_terminal\n\n  def run_experiment(self):\n    \"\"\"Runs a full experiment, spread over multiple iterations.\"\"\"\n    logging.info('Beginning training...')\n    if self._num_iterations <= self._start_iteration:\n      logging.warning('num_iterations (%d) <= start_iteration (%d)',\n                      self._num_iterations, self._start_iteration)\n      return\n    self._agent.update_prune_step()\n    self._agent.maybe_init_masks()\n    
all_eval_returns = []\n    for iteration in range(self._start_iteration, self._num_iterations):\n      statistics = self._run_one_iteration(iteration)\n      all_eval_returns.append(statistics['eval_average_return'][-1])\n      self._log_experiment(iteration, statistics)\n      self._checkpoint_experiment(iteration)\n    last_n = int(self._num_iterations * AVG_REWARD_FRAC)\n    avg_return = np.mean(all_eval_returns[-last_n:])\n    logging.info('Step %d, Average Return: %f', iteration, avg_return)\n"
  },
  {
    "path": "rigl/rl/sparse_utils.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Defines pruning and sparse training utilities.\"\"\"\n\nimport functools\nimport re\n\nimport gin\nfrom rigl import sparse_optimizers_base as sparse_opt_base\nfrom rigl import sparse_utils\nfrom rigl.rigl_tf2 import init_utils\nimport tensorflow as tf\nimport tensorflow.compat.v1 as tf1\n\nfrom tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule\nfrom tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper\n\n\nPRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude\nPRUNED_LAYER_TYPES = (tf.keras.layers.Conv2D, tf.keras.layers.Dense)\n\n\ndef get_total_params(model):\n  \"\"\"Obtains total active parameters of a given network.\"\"\"\n  all_layers = get_all_layers(model)\n  total_count = 0.\n  nparams_dict = {}\n  for layer in all_layers:\n    n_param = 0.\n    if isinstance(layer, PRUNING_WRAPPER):\n      mask = layer.pruning_vars[0][1]\n      n_param += tf.reduce_sum(mask)\n      n_param += tf.size(layer.weights[1], out_type=tf.float32)\n    else:\n      for w in layer.weights:\n        n_param += tf.size(w, out_type=tf.float32)\n    nparams_dict[layer.name] = n_param\n    total_count += n_param\n  return total_count, nparams_dict\n\n\n@gin.configurable(denylist=['layer_dict'])\ndef get_pruning_sparsities(\n    layer_dict,\n    mask_init_method='erdos_renyi_kernel',\n    target_sparsity=0.9,\n 
   erk_power_scale=1.,\n    custom_sparsity_map=None):\n  \"\"\"Creates name/sparsity dict using the name/shapes dict (layer_dict).\"\"\"\n  if target_sparsity == 0:\n    return {k: 0 for k in layer_dict.keys()}\n\n  if custom_sparsity_map is None:\n    custom_sparsity_map = {}\n  extract_name_fn = lambda x: re.findall('(.+):0', x)[0]\n  dummy_masks_dict = {k: tf.ones(v) for k, v in layer_dict.items()}\n  reverse_dict = {v.name: k\n                  for k, v in dummy_masks_dict.items()}\n\n  sparsity_dict = sparse_utils.get_sparsities(\n      list(dummy_masks_dict.values()),\n      mask_init_method,\n      target_sparsity,\n      custom_sparsity_map,\n      extract_name_fn=extract_name_fn,\n      erk_power_scale=erk_power_scale)\n  renamed_sparsity_dict = {reverse_dict[k]: float(v)\n                           for k, v in sparsity_dict.items()}\n  return renamed_sparsity_dict\n\n\n@gin.configurable('pruning')\ndef get_pruning_params(mode,\n                       initial_sparsity=0.0,\n                       final_sparsity=0.95,\n                       begin_step=30000,\n                       end_step=100000,\n                       frequency=1000):\n  \"\"\"Gets pruning hyper-parameters.\"\"\"\n  p_params = {}\n  if mode == 'prune':\n    p_params['pruning_schedule'] = pruning_schedule.PolynomialDecay(\n        initial_sparsity=initial_sparsity,\n        final_sparsity=final_sparsity,\n        begin_step=begin_step,\n        end_step=end_step,\n        frequency=frequency)\n  elif mode in ('rigl', 'static', 'set'):\n    # For sparse training methods we don't use the pruning library to update the\n    # masks. Therefore we need to disable it. Following `pruning` flags serve\n    # that purpose.\n    # 1B. 
High begin_step, so it never starts.\n    p_params['pruning_schedule'] = pruning_schedule.ConstantSparsity(\n        target_sparsity=0, begin_step=1000000000)\n  else:\n    raise ValueError('Mode: %s, is not valid' % mode)\n  return p_params\n\n\ndef maybe_prune_layer(layer, params, filter_fn=None):\n  \"\"\"Wraps the layer with the pruning wrapper when filter_fn accepts it.\"\"\"\n  if filter_fn is None:\n    filter_fn = lambda l: isinstance(l, PRUNED_LAYER_TYPES)\n  if filter_fn(layer):\n    return PRUNING_WRAPPER(layer, **params)\n  return layer\n\n\ndef get_wrap_fn(mode):\n  \"\"\"Creates a function that wraps a given layer conditionally.\n\n  Args:\n    mode: str, if 'dense' no modification is done; otherwise the layer is\n      pruned.\n\n  Returns:\n    function that accepts a layer and returns a possibly wrapped one.\n  \"\"\"\n  if mode == 'dense':\n    # Do not wrap the layer.\n    wrap_fn = lambda x: x\n  else:\n    wrap_fn = functools.partial(\n        maybe_prune_layer, params=get_pruning_params(mode))\n  return wrap_fn\n\n\ndef update_prune_step(model, step):\n  \"\"\"Updates the pruning steps of each pruning layer.\"\"\"\n  assign_ops = []\n  for layer in get_all_pruning_layers(model):\n    # Assign iteration count to the layer pruning_step.\n    # pruning wrapper requires step to be >0.\n    assign_op = tf1.assign(layer.pruning_step, tf.maximum(step, 1))\n    assign_ops.append(assign_op)\n  return tf.group(assign_ops)\n\n\ndef update_prune_masks(model):\n  \"\"\"Updates the masks if it is an update iteration.\"\"\"\n  update_ops = [op for op in model.updates\n                if 'prune_low_magnitude' in op.name]\n  return tf.group(update_ops)\n\n\ndef get_all_layers(model, filter_fn=lambda _: True):\n  \"\"\"Gets all layers of a model and layers of a layer if it is a keras.Model.\"\"\"\n  all_layers = []\n  for l in model.layers:\n    if hasattr(l, 'layers'):\n      all_layers.extend(get_all_layers(l, filter_fn=filter_fn))\n    elif filter_fn(l):\n      all_layers.append(l)\n  return all_layers\n\n\ndef get_all_variables_and_masks(model):\n  
\"\"\"Gets all trainable variables (+their masks) of a model.\"\"\"\n  all_layers = get_all_layers(model)\n  all_variables = []\n  for l in all_layers:\n    all_variables.extend(l.trainable_variables)\n    if isinstance(l, PRUNING_WRAPPER):\n      all_variables.append(l.pruning_vars[0][1])  # Adding mask.\n  return all_variables\n\n\ndef get_all_pruning_layers(model):\n  \"\"\"Gets all pruned layers of a model and layers of a layer if keras.Model.\"\"\"\n  return get_all_layers(\n      model, filter_fn=lambda l: isinstance(l, PRUNING_WRAPPER))\n\n\ndef log_sparsities(model):\n  for layer in get_all_pruning_layers(model):\n    for _, mask, threshold in layer.pruning_vars:\n      scalar_name = f'sparsity/{mask.name}'\n      sparsity = 1 - tf.reduce_mean(mask)\n      if len(mask.shape) == 2:\n        reshaped_mask = tf.expand_dims(tf.expand_dims(mask, 0), -1)\n        tf1.summary.image(f'img/{mask.name}', reshaped_mask)\n      tf1.summary.scalar(scalar_name, sparsity)\n      tf1.summary.scalar(f'threshold/{threshold.name}', threshold)\n\n\nclass SparseOptTf2Mixin:\n  \"\"\"Tf2 model_optimization pruning library specific variable retrieval.\"\"\"\n\n  def compute_gradients(self, *args, **kwargs):\n    \"\"\"Wraps the compute gradient of passed optimizer.\"\"\"\n    return self._optimizer.compute_gradients(*args, **kwargs)\n\n  def set_model(self, model):\n    self.model = model\n\n  def get_weights(self):\n    all_weights = [\n        layer.pruning_vars[0][0] for layer in get_all_pruning_layers(self.model)\n    ]\n    return all_weights\n\n  def get_masks(self):\n    all_masks = [\n        layer.pruning_vars[0][1] for layer in get_all_pruning_layers(self.model)\n    ]\n    return all_masks\n\n  def get_masked_weights(self):\n    all_masked_weights = [\n        w * m for w, m in zip(self.get_weights(), self.get_masks())\n    ]\n    return all_masked_weights\n\n\n@gin.configurable()\nclass UpdatedSETOptimizer(SparseOptTf2Mixin,\n                          
sparse_opt_base.SparseSETOptimizerBase):\n\n  def _before_apply_gradients(self, grads_and_vars):\n    return tf1.no_op()\n\n\n@gin.configurable()\nclass UpdatedRigLOptimizer(SparseOptTf2Mixin,\n                           sparse_opt_base.SparseRigLOptimizerBase):\n\n  def _before_apply_gradients(self, grads_and_vars):\n    \"\"\"Updates momentum before updating the weights with gradient.\"\"\"\n    self._weight2masked_grads = {w.name: g for g, w in grads_and_vars}\n    return tf1.no_op()\n\n\n@gin.configurable()\ndef init_masks(model,\n               mask_init_method='random',\n               sparsity=0.9,\n               erk_power_scale=1.,\n               custom_sparsity_map=None,\n               fixed_sparse_init=False):\n  \"\"\"Inits the masks randomly according to the given sparsity.\"\"\"\n  if sparsity == 0:\n    return None\n\n  if custom_sparsity_map is None:\n    custom_sparsity_map = {}\n  all_masks = [\n      layer.pruning_vars[0][1] for layer in get_all_pruning_layers(model)\n  ]\n\n  assigner = sparse_utils.get_mask_init_fn(\n      all_masks,\n      mask_init_method,\n      sparsity,\n      custom_sparsity_map,\n      erk_power_scale=erk_power_scale)\n  if fixed_sparse_init:\n    all_weights = [\n        layer.pruning_vars[0][0] for layer in get_all_pruning_layers(model)\n    ]\n    with tf.control_dependencies([assigner]):\n      assign_ops = []\n      for param, mask in zip(all_weights, all_masks):\n        new_init = init_utils.unit_scaled_init_tf1(mask)\n        assign_ops.append(tf1.assign(param, new_init))\n      assigner = tf.group(assign_ops)\n  return assigner\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_dense.gin",
    "content": "include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin'\n\nimport rigl.rl.dqn_agents\n\nDQNAgent.network = @dqn_agents.NatureDQNNetwork\nDQNAgent.optimizer = @tf.train.AdamOptimizer()\ntf.train.AdamOptimizer.learning_rate = 0.00025\n\nWrappedReplayBuffer.batch_size = 32  # Same as original\n\nSparseDQNAgent.mode = 'dense'\nSparseDQNAgent.weight_decay = 0.0\n\natari_lib.create_atari_environment.game_name = 'Pong'\nSparseTrainRunner.load_env_fn = @atari_lib.create_atari_environment\n\nSparseTrainRunner.agent_type = 'dqn'\nSparseTrainRunner.num_iterations = 40\nSparseTrainRunner.training_steps = 250000\nSparseTrainRunner.evaluation_steps = 125000\nSparseTrainRunner.max_steps_per_episode = 27000  # Default max episode length.\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin",
    "content": "include 'third_party/py/dopamine/agents/dqn/configs/dqn.gin'\n\nimport rigl.rl.dqn_agents\n\nDQNAgent.network = @dqn_agents.ImpalaNetwork\nDQNAgent.optimizer = @tf.train.AdamOptimizer()\ntf.train.AdamOptimizer.learning_rate = 0.0001\ntf.train.AdamOptimizer.epsilon = 0.0003125\n\nWrappedReplayBuffer.batch_size = 32  # Same as original\n\nSparseDQNAgent.mode = 'dense'\nSparseDQNAgent.weight_decay = 1e-05\n\natari_lib.create_atari_environment.game_name = 'Pong'\nSparseTrainRunner.load_env_fn = @atari_lib.create_atari_environment\n\nSparseTrainRunner.agent_type = 'dqn'\nSparseTrainRunner.num_iterations = 40\nSparseTrainRunner.training_steps = 250000\nSparseTrainRunner.evaluation_steps = 125000\nSparseTrainRunner.max_steps_per_episode = 27000  # Default max episode length.\n\nImpalaNetwork.use_batch_norm = False\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_prune.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'prune'\n\n\nget_pruning_sparsities.target_sparsity = 0.95\nget_pruning_sparsities.mask_init_method = 'erdos_renyi_kernel'\n\npruning.initial_sparsity = 0.0\n# 0.5M = 20% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\npruning.begin_step = 500000  # 500k\n# 2M = 80% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\npruning.end_step = 2000000    # 2M\npruning.frequency = 5000\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_prune_impala_net.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'prune'\n\n\nget_pruning_sparsities.target_sparsity = 0.95\nget_pruning_sparsities.mask_init_method = 'erdos_renyi_kernel'\n\npruning.initial_sparsity = 0.0\n# 0.5M = 20% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\npruning.begin_step = 500000  # 500k\n# 2M = 80% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\npruning.end_step = 2000000    # 2M\npruning.frequency = 5000\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_rigl.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'rigl'\n\n# For sparse training methods we don't use the pruning library to update the\n# masks. Therefore we need to disable it. Following `pruning` flags serve that\n# purpose.\npruning.final_sparsity = 0.\npruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.\n\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.fixed_sparse_init = True\ninit_masks.sparsity = 0.9\n\nUpdatedRigLOptimizer.begin_step = 0\n# 2M = 80% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\nUpdatedRigLOptimizer.end_step = 2000000\nUpdatedRigLOptimizer.frequency = 5000\nUpdatedRigLOptimizer.drop_fraction_anneal = 'cosine'\nUpdatedRigLOptimizer.drop_fraction = 0.3\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_rigl_impala_net.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'rigl'\n\n# For sparse training methods we don't use the pruning library to update the\n# masks. Therefore we need to disable it. Following `pruning` flags serve that\n# purpose.\npruning.final_sparsity = 0.\npruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.\n\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.fixed_sparse_init = True\ninit_masks.sparsity = 0.9\n\nUpdatedRigLOptimizer.begin_step = 0\n# 2M = 80% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\nUpdatedRigLOptimizer.end_step = 2000000\nUpdatedRigLOptimizer.frequency = 5000\nUpdatedRigLOptimizer.drop_fraction_anneal = 'cosine'\nUpdatedRigLOptimizer.drop_fraction = 0.3\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_set.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'set'\n\n# For sparse training methods we don't use the pruning library to update the\n# masks. Therefore we need to disable it. Following `pruning` flags serve that\n# purpose.\npruning.final_sparsity = 0.\npruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.\n\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.fixed_sparse_init = True\ninit_masks.sparsity = 0.9\n\nUpdatedSETOptimizer.begin_step = 0\n# 2M = 80% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\nUpdatedSETOptimizer.end_step = 2000000\nUpdatedSETOptimizer.frequency = 5000\nUpdatedSETOptimizer.drop_fraction_anneal = 'cosine'\nUpdatedSETOptimizer.drop_fraction = 0.3\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_set_impala_net.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'set'\n\n# For sparse training methods we don't use the pruning library to update the\n# masks. Therefore we need to disable it. Following `pruning` flags serve that\n# purpose.\npruning.final_sparsity = 0.\npruning.begin_step = 1000000000 # 1B. High begin_step, so it never starts.\n\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.fixed_sparse_init = True\ninit_masks.sparsity = 0.9\n\nUpdatedSETOptimizer.begin_step = 0\n# 2M = 80% optimizer steps when training for 40M env steps with a frame skip\n# of 4 (= 10M transitions), and training every 4th env transition (2.5M train\n# steps in total).\nUpdatedSETOptimizer.end_step = 2000000\nUpdatedSETOptimizer.frequency = 5000\nUpdatedSETOptimizer.drop_fraction_anneal = 'cosine'\nUpdatedSETOptimizer.drop_fraction = 0.3\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_static.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense.gin'\n\nSparseDQNAgent.mode = 'static'\n\n# For sparse training methods we don't use the pruning library to update the\n# masks. Therefore we need to disable it. Following `pruning` flags serve that\n# purpose.\npruning.final_sparsity = 0.\npruning.begin_step = 1000000000  # 1B. High begin_step, so it never starts.\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.sparsity = 0.95\n"
  },
  {
    "path": "rigl/rl/sparsetrain_configs/dqn_atari_static_impala_net.gin",
    "content": "include 'rigl/rl/sparsetrain_configs/dqn_atari_dense_impala_net.gin'\n\nSparseDQNAgent.mode = 'static'\n\n# For sparse training methods we don't use the pruning library to update the\n# masks. Therefore we need to disable it. Following `pruning` flags serve that\n# purpose.\npruning.final_sparsity = 0.\npruning.begin_step = 1000000000  # 1B. High begin_step, so it never starts.\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.sparsity = 0.95\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/dqn_gym_dense_config.gin",
    "content": "# Configs to run DQN training for dense networks on classic control environments.\n\ntrain_eval.env_name='CartPole-v0'\ntrain_eval.fc_layer_params = (512, 512)\ntrain_eval.target_update_period = 100\ntrain_eval.batch_size = 128\n# Environment:train steps ratio is 1:1\ntrain_eval.num_iterations = 100000\ntrain_eval.weight_decay = 1e-6\ntrain_eval.width = 1.0\ntrain_eval.policy_save_interval = 10000\ntrain_eval.epsilon_greedy = 0.01\ntrain_eval.eval_interval = 2000\ntrain_eval.eval_episodes = 20\n\ntrain_eval.sparse_output_layer = False\ntrain_eval.train_mode = 'dense'\n\nmask_updater.update_alg = ''\nmask_updater.schedule_alg = ''\nlog_snr.freq=5000\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/dqn_gym_pruning_config.gin",
    "content": "include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin'\n\n# Configs to run DQN training for pruning on classic control environments.\n\ntrain_eval.sparse_output_layer = True\ntrain_eval.train_mode = 'sparse'\n\n# This must be set to 0 when pruning to avoid\n# initializing the masks\ninit_masks.sparsity = 0.0\n\nwrap_all_layers.mode = 'prune'\nwrap_all_layers.initial_sparsity = 0.0\nwrap_all_layers.final_sparsity = 0.9\nwrap_all_layers.mask_init_method = 'erdos_renyi_kernel'\n# Environment:train steps ratio is 1:1\n# We start pruning after 20% training (20,000) and stop after 75% (75,000)\nwrap_all_layers.begin_step = 20000\nwrap_all_layers.end_step = 75000\nwrap_all_layers.frequency = 1000\n\nlog_sparsities.log_images = False\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/dqn_gym_sparse_config.gin",
    "content": "include 'rigl/rl/tfagents/configs/dqn_gym_dense_config.gin'\n\n# Configs to run DQN training for static, set, and rigl on classic control\n# environments.\n\ntrain_eval.sparse_output_layer = True\ntrain_eval.train_mode = 'sparse'\n\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.fixed_sparse_init = True\ninit_masks.sparsity = 0.9\n\n# For static, set this to ''\n# For rigl set this to 'rigl'\n# For set set this to 'set'\nmask_updater.update_alg = ''\nmask_updater.schedule_alg = 'cosine'\nmask_updater.update_freq = 1000\nmask_updater.init_drop_fraction = 0.5\n# Environment:train steps ratio is 1:1, we stop after 75% training = 75,000\nmask_updater.last_update_step = 75000\nmask_updater.use_stateless = False\n\nwrap_all_layers.mode = 'constant'\n\nlog_sparsities.log_images = False\n\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin",
    "content": "# Config to run PPO training for dense networks on mujoco environments.\n\ntrain_eval.env_name='HalfCheetah-v2'\ntrain_eval.actor_fc_layers = (64, 64)\ntrain_eval.value_fc_layers = (64, 64)\n# In order to execute ~1M environment steps, we run 489 iterations\n# (`--num_iterations=489`) which results in 1,001,472 environment steps. Each\n# iteration results in 320 training steps (or 320 gradient updates; this is\n# calculated from environment_steps * num_epochs / minibatch_size) and 2,048\n# environment steps. Thus 489 * 2,048 = 1,001,472 environment steps and\n# 489 * 320 = 156,480 training steps.\ntrain_eval.num_iterations = 489\ntrain_eval.weight_decay = 1e-6\ntrain_eval.width = 1.0\ntrain_eval.policy_save_interval = 51000\ntrain_eval.num_epochs = 10\ntrain_eval.eval_interval = 2000\ntrain_eval.eval_episodes = 20\n\ntrain_eval.sparse_output_layer = False\ntrain_eval.train_mode_actor = 'dense'\ntrain_eval.train_mode_value = 'dense'\n\nmask_updater.update_alg = ''\nmask_updater.schedule_alg = ''\nlog_snr.freq=5000\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/ppo_mujoco_pruning_config.gin",
    "content": "include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin'\n\ntrain_eval.sparse_output_layer = True\ntrain_eval.train_mode_actor = 'sparse'\ntrain_eval.train_mode_value = 'sparse'\n\n# This must be set to 0 when pruning to avoid\n# initializing the masks\ninit_masks.sparsity = 0.0\n\nwrap_all_layers.mode = 'prune'\nwrap_all_layers.initial_sparsity = 0.0\nwrap_all_layers.final_sparsity = 0.9\nwrap_all_layers.mask_init_method = 'erdos_renyi_kernel'\n# 156,480 steps total\n# Start at ~20% = 31,296\n# End at ~75% = 117,360\nwrap_all_layers.begin_step = 32000\nwrap_all_layers.end_step = 120000\nwrap_all_layers.frequency = 500\n\nlog_sparsities.log_images = False\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/ppo_mujoco_sparse_config.gin",
    "content": "include 'rigl/rl/tfagents/configs/ppo_mujoco_dense_config.gin'\n\n# Config to run PPO training for static, set, and rigl on mujoco environments.\n\ntrain_eval.sparse_output_layer = True\ntrain_eval.train_mode_actor = 'sparse'\ntrain_eval.train_mode_value = 'sparse'\ntrain_eval.weight_decay = 1e-4\n\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.fixed_sparse_init = True\ninit_masks.sparsity = 0.9\n\n# For static, set this to ''\n# For rigl set this to 'rigl'\n# For set set this to 'set'\nmask_updater.update_alg = ''\nmask_updater.schedule_alg = 'cosine'\nmask_updater.update_freq = 250\nmask_updater.init_drop_fraction = 0.3\n# 156,480 steps total, end at 75% = 117,360\nmask_updater.last_update_step = 120000\nmask_updater.use_stateless = False\n\nwrap_all_layers.mode = 'constant'\n\nlog_sparsities.log_images = False\n\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin",
    "content": "# Config to run SAC training for dense on mujoco environments.\n\ntrain_eval.env_name = 'Humanoid-v2'\ntrain_eval.initial_collect_steps = 1000\ntrain_eval.num_iterations = 1000000  # 1M\ntrain_eval.width = 1.0\ntrain_eval.weight_decay = 1e-4\n\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/sac_mujoco_pruning_config.gin",
    "content": "include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin'\n\n# Configs to run SAC training for pruning on mujoco environments.\ntrain_eval.train_mode_actor = 'sparse'\n# Both critics\ntrain_eval.train_mode_value = 'sparse'\n\ntrain_eval.sparse_output_layer = True\n\ninit_masks.fixed_sparse_init = True\n# This must be set to 0 when pruning to avoid\n# initializing the masks\ninit_masks.sparsity = 0.0\n\nwrap_all_layers.mode = 'prune'\nwrap_all_layers.initial_sparsity = 0.0\nwrap_all_layers.final_sparsity = 0.9\nwrap_all_layers.mask_init_method = 'erdos_renyi_kernel'\n# 1M steps total\n# Start at 20%, end at 80%\nwrap_all_layers.begin_step = 200000\nwrap_all_layers.end_step = 800000\nwrap_all_layers.frequency = 1000\n\nlog_sparsities.log_images = False\n"
  },
  {
    "path": "rigl/rl/tfagents/configs/sac_mujoco_sparse_config.gin",
    "content": "include 'rigl/rl/tfagents/configs/sac_mujoco_dense_config.gin'\n\n# Configs to run SAC training for static, set, and rigl on mujoco\n# environments.\n\ntrain_eval.sparse_output_layer = True\ntrain_eval.train_mode_actor = 'sparse'\n# Both critics\ntrain_eval.train_mode_value = 'sparse'\ntrain_eval.actor_critic_sparsities_str = ''\ntrain_eval.weight_decay = 1e-6\n\ninit_masks.mask_init_method = 'erdos_renyi_kernel'\ninit_masks.fixed_sparse_init = True\ninit_masks.sparsity = 0.9\n\nmask_updater.update_alg = ''\nmask_updater.schedule_alg = 'cosine'\nmask_updater.update_freq = 1000\nmask_updater.init_drop_fraction = 0.5\n# 1M / train_eval.num_iterations * 0.8\nmask_updater.last_update_step = 800000\nmask_updater.use_stateless = False\n\nwrap_all_layers.mode = 'constant'\n\nlog_sparsities.log_images = False\n"
  },
  {
    "path": "rigl/rl/tfagents/dqn_train_eval.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Sparse training DQN using actor/learner in a gym environment.\n\"\"\"\nimport functools\nimport os\n\nfrom typing import Tuple\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\n\nimport gin\nimport numpy as np\nimport reverb\nfrom rigl.rigl_tf2 import mask_updaters\nfrom rigl.rl import sparse_utils\nfrom rigl.rl.tfagents import tf_sparse_utils\nimport tensorflow.compat.v2 as tf\n\nfrom tf_agents.agents.dqn import dqn_agent\nfrom tf_agents.environments import suite_atari\nfrom tf_agents.environments import suite_gym\nfrom tf_agents.metrics import py_metrics\nfrom tf_agents.networks import sequential\nfrom tf_agents.policies import py_tf_eager_policy\nfrom tf_agents.policies import random_py_policy\nfrom tf_agents.replay_buffers import reverb_replay_buffer\nfrom tf_agents.replay_buffers import reverb_utils\nfrom tf_agents.specs import tensor_spec\nfrom tf_agents.system import system_multiprocessing as multiprocessing\nfrom tf_agents.train import actor\nfrom tf_agents.train import learner\nfrom tf_agents.train import triggers\nfrom tf_agents.train.utils import train_utils\nfrom tf_agents.utils import common\nfrom tf_agents.utils import eager_utils\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),\n                    'Root directory for writing 
logs/summaries/checkpoints.')\nflags.DEFINE_integer(\n    'reverb_port', None,\n    'Port for reverb server, if None, use a randomly chosen unused port.')\n\nflags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')\nflags.DEFINE_multi_string(\n    'gin_bindings', [],\n    'Gin bindings to override the values set in the config files '\n    '(e.g. \"train_eval.env_name=Acrobot-v1\",'\n    '      \"init_masks.sparsity=0.9\").')\nflags.DEFINE_float(\n    'average_last_fraction', 0.1,\n    'Fraction of the latest evaluation scores that are averaged. This is used'\n    ' to reduce variance.')\n\n\n@gin.configurable\nclass SparseDqnAgent(dqn_agent.DqnAgent):\n  \"\"\"Wrapped DqnAgent that supports sparse training.\"\"\"\n\n  def __init__(self, *args, **kwargs):\n    super().__init__(*args, **kwargs)\n    _ = sparse_utils.init_masks(self._q_network)\n    def loss_fn(experience_data, weights_data):\n      # The following is just to fit the existing API.\n      loss_info = self._loss(\n          experience_data,\n          td_errors_loss_fn=self._td_errors_loss_fn,\n          gamma=self._gamma,\n          reward_scale_factor=self._reward_scale_factor,\n          weights=weights_data,\n          training=True)\n      return loss_info.extra.td_loss\n    # Create the mask updater, if one is configured (otherwise None).\n    self._mask_updater = mask_updaters.get_mask_updater(\n        self._q_network, self._optimizer, loss_fn)\n\n  def _train(self, experience, weights):\n    tf.compat.v2.summary.experimental.set_step(self.train_step_counter)\n\n    tf_sparse_utils.update_prune_step(self._q_network, self._train_step_counter)\n    with tf.GradientTape(persistent=True) as tape:\n      loss_info = self._loss(\n          experience,\n          td_errors_loss_fn=self._td_errors_loss_fn,\n          gamma=self._gamma,\n          reward_scale_factor=self._reward_scale_factor,\n          weights=weights,\n          training=True)\n    tf.debugging.check_numerics(loss_info.loss, 'Loss is 
inf or nan')\n    variables_to_train = self._q_network.trainable_weights\n    non_trainable_weights = self._q_network.non_trainable_weights\n    assert list(variables_to_train), \"No variables in the agent's q_network.\"\n    grads = tape.gradient(loss_info.loss, variables_to_train)\n\n    tf_sparse_utils.log_snr(tape, loss_info.extra.td_loss,\n                            self.train_step_counter, variables_to_train)\n\n    # Tuple is used for py3, where zip is a generator producing values once.\n    grads_and_vars = list(zip(grads, variables_to_train))\n\n    def _mask_update_step():\n      # Second argument is not used.\n      self._mask_updater.set_validation_data(experience, weights)\n      self._mask_updater.update(self.train_step_counter)\n      with tf.name_scope('/'):\n        tf.summary.scalar(\n            name='drop_fraction', data=self._mask_updater.last_drop_fraction)\n\n    tf_sparse_utils.log_sparsities(self._q_network)\n    if self._mask_updater is not None:\n      is_update = self._mask_updater.is_update_iter(self.train_step_counter)\n      tf.cond(is_update, _mask_update_step, lambda: None)\n\n    if self._gradient_clipping is not None:\n      grads_and_vars = eager_utils.clip_gradient_norms(grads_and_vars,\n                                                       self._gradient_clipping)\n\n    if self._summarize_grads_and_vars:\n      grads_and_vars_with_non_trainable = (\n          grads_and_vars + [(None, v) for v in non_trainable_weights])\n      eager_utils.add_variables_summaries(grads_and_vars_with_non_trainable,\n                                          self.train_step_counter)\n      eager_utils.add_gradients_summaries(grads_and_vars,\n                                          self.train_step_counter)\n    self._optimizer.apply_gradients(grads_and_vars)\n    self.train_step_counter.assign_add(1)\n\n    self._update_target()\n\n    return loss_info\n\n\ndef _scale_width(num_units, width):\n  assert width > 0\n  return int(max(1, num_units * 
width))\n\n\ndef build_network(\n    fc_layer_params,\n    num_actions,\n    is_sparse,\n    input_dim,\n    width = 1.0,\n    weight_decay = 0.0,\n    sparse_output_layer = True\n    ):\n  \"\"\"Builds a Sequential model.\"\"\"\n\n  def dense_layer(num_units):\n    return tf.keras.layers.Dense(\n        num_units,\n        activation=tf.keras.activations.relu,\n        kernel_initializer=tf.keras.initializers.VarianceScaling(\n            scale=2.0, mode='fan_in', distribution='truncated_normal'),\n        kernel_regularizer=tf.keras.regularizers.L2(weight_decay),)\n\n  # QNetwork consists of a sequence of Dense layers followed by a dense layer\n  # with `num_actions` units to generate one q_value per available action as\n  # its output.\n  all_layers = [\n      dense_layer(_scale_width(num_units, width=width)\n                  ) for num_units in fc_layer_params]\n  all_layers.append(\n      tf.keras.layers.Dense(\n          num_actions,\n          activation=None,\n          kernel_initializer=tf.keras.initializers.RandomUniform(\n              minval=-0.03, maxval=0.03),\n          bias_initializer=tf.keras.initializers.Constant(-0.2)))\n  if is_sparse:\n    if sparse_output_layer:\n      all_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim)\n    else:\n      all_layers = (tf_sparse_utils.wrap_all_layers(all_layers[:-1], input_dim)\n                    + all_layers[-1:])\n  return sequential.Sequential(all_layers)\n\n\n\n\n@gin.configurable\ndef train_eval(\n    root_dir,\n    env_name='CartPole-v0',\n    # Training params\n    update_frequency=1,\n    initial_collect_steps=1000,\n    num_iterations=100000,\n    fc_layer_params=(100,),\n    # Agent params\n    epsilon_greedy=0.1,\n    epsilon_decay_period=250000,\n    batch_size=64,\n    learning_rate=1e-3,\n    n_step_update=1,\n    gamma=0.99,\n    target_update_tau=1.0,\n    target_update_period=100,\n    reward_scale_factor=1.0,\n    # Replay params\n    reverb_port=None,\n    
replay_capacity=100000,\n    # Others\n    policy_save_interval=1000,\n    eval_interval=1000,\n    eval_episodes=10,\n    weight_decay=0.0,\n    width=1.0,\n    debug_summaries=False,\n    sparse_output_layer=True,\n    train_mode='dense'):\n  \"\"\"Trains and evaluates DQN.\"\"\"\n\n  logging.info('DQN params: Fc layer params: %s', fc_layer_params)\n  logging.info('DQN params: Train mode: %s', train_mode)\n  logging.info('DQN params: Target update period: %s', target_update_period)\n  logging.info('DQN params: Policy save interval: %s', policy_save_interval)\n  logging.info('DQN params: Eval interval: %s', eval_interval)\n  logging.info('DQN params: Environment name: %s', env_name)\n  logging.info('DQN params: Weight decay: %s', weight_decay)\n  logging.info('DQN params: Width: %s', width)\n  logging.info('DQN params: Batch size: %s', batch_size)\n  logging.info('DQN params: Learning rate: %s', learning_rate)\n  logging.info('DQN params: Num iterations: %s', num_iterations)\n  logging.info('DQN params: Sparse output layer: %s', sparse_output_layer)\n\n  collect_env = suite_gym.load(env_name)\n  eval_env = suite_gym.load(env_name)\n  logging.info('Collect env: %s', collect_env)\n  logging.info('Eval env: %s', eval_env)\n\n  time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())\n  action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())\n  train_step = train_utils.create_train_step()\n  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1\n  observation_shape = collect_env.observation_spec().shape\n  # Build network and get pruning params\n  is_atari = False\n  if not is_atari:\n    q_net = build_network(\n        fc_layer_params=fc_layer_params,\n        num_actions=num_actions,\n        is_sparse=(train_mode == 'sparse'),\n        # observation_shape is 1-dimensional. 
We need this so that we can\n        # calculate the dimensions of the first layer.\n        input_dim=observation_shape[-1],\n        width=width,\n        weight_decay=weight_decay,\n        sparse_output_layer=sparse_output_layer)\n    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n    loss = common.element_wise_squared_loss\n    decay_fn = epsilon_greedy\n\n  agent = SparseDqnAgent(\n      time_step_tensor_spec,\n      action_tensor_spec,\n      q_network=q_net,\n      epsilon_greedy=decay_fn,\n      n_step_update=n_step_update,\n      target_update_tau=target_update_tau,\n      target_update_period=target_update_period,\n      optimizer=optimizer,\n      td_errors_loss_fn=loss,\n      gamma=gamma,\n      reward_scale_factor=reward_scale_factor,\n      train_step_counter=train_step,\n      debug_summaries=debug_summaries)\n  table_name = 'uniform_table'\n  table = reverb.Table(\n      table_name,\n      max_size=replay_capacity,\n      sampler=reverb.selectors.Uniform(),\n      remover=reverb.selectors.Fifo(),\n      rate_limiter=reverb.rate_limiters.MinSize(1))\n  reverb_server = reverb.Server([table], port=reverb_port)\n  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(\n      agent.collect_data_spec,\n      sequence_length=2,\n      table_name=table_name,\n      local_server=reverb_server)\n  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(\n      reverb_replay.py_client, table_name,\n      sequence_length=2,\n      stride_length=1)\n\n  dataset = reverb_replay.as_dataset(\n      num_parallel_calls=3, sample_batch_size=batch_size,\n      num_steps=2).prefetch(3)\n  experience_dataset_fn = lambda: dataset\n\n  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)\n  env_step_metric = py_metrics.EnvironmentSteps()\n\n  learning_triggers = [\n      triggers.PolicySavedModelTrigger(\n          saved_model_dir,\n          agent,\n          train_step,\n          interval=policy_save_interval,\n          
metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),\n      triggers.StepPerSecondLogTrigger(train_step, interval=100),\n  ]\n\n  dqn_learner = learner.Learner(\n      root_dir,\n      train_step,\n      agent,\n      experience_dataset_fn,\n      triggers=learning_triggers,\n      run_optimizer_variable_init=False)\n\n  # If we haven't trained yet make sure we collect some random samples first to\n  # fill up the Replay Buffer with some experience.\n  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),\n                                                  collect_env.action_spec())\n  initial_collect_actor = actor.Actor(\n      collect_env,\n      random_policy,\n      train_step,\n      steps_per_run=initial_collect_steps,\n      observers=[rb_observer])\n  logging.info('Doing initial collect.')\n  initial_collect_actor.run()\n\n  tf_collect_policy = agent.collect_policy\n  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,\n                                                      use_tf_function=True)\n\n  collect_actor = actor.Actor(\n      collect_env,\n      collect_policy,\n      train_step,\n      steps_per_run=update_frequency,\n      observers=[rb_observer, env_step_metric],\n      metrics=actor.collect_metrics(10),\n      reference_metrics=[env_step_metric],\n      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),\n  )\n\n  tf_greedy_policy = agent.policy\n  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,\n                                                     use_tf_function=True)\n\n  eval_actor = actor.Actor(\n      eval_env,\n      greedy_policy,\n      train_step,\n      episodes_per_run=eval_episodes,\n      metrics=actor.eval_metrics(eval_episodes),\n      reference_metrics=[env_step_metric],\n      summary_dir=os.path.join(root_dir, 'eval'),\n  )\n\n  average_returns = []\n  if eval_interval:\n    logging.info('Evaluating.')\n    eval_actor.run_and_log()\n    for 
metric in eval_actor.metrics:\n      if isinstance(metric, py_metrics.AverageReturnMetric):\n        average_returns.append(metric._buffer.mean())\n\n  logging.info('Training.')\n  for _ in range(num_iterations):\n    collect_actor.run()\n    dqn_learner.run(iterations=1)\n\n    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:\n      logging.info('Evaluating.')\n      eval_actor.run_and_log()\n      for metric in eval_actor.metrics:\n        if isinstance(metric, py_metrics.AverageReturnMetric):\n          average_returns.append(metric._buffer.mean())\n\n  # Log the last section of evaluation scores for the final metric. Guard\n  # against idx == 0, since average_returns[-0:] would be the whole list.\n  idx = max(1, int(FLAGS.average_last_fraction * len(average_returns)))\n  avg_return = np.mean(average_returns[-idx:])\n  logging.info('Step %d, Average Return: %f', env_step_metric.result(),\n               avg_return)\n\n  rb_observer.close()\n  reverb_server.stop()\n\n\ndef main(_):\n  tf.config.experimental_run_functions_eagerly(False)\n  logging.set_verbosity(logging.INFO)\n  tf.enable_v2_behavior()\n\n  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)\n  logging.info('Gin bindings: %s', FLAGS.gin_bindings)\n\n  train_eval(\n      FLAGS.root_dir,\n      reverb_port=FLAGS.reverb_port)\n\n\nif __name__ == '__main__':\n  flags.mark_flag_as_required('root_dir')\n  multiprocessing.handle_main(functools.partial(app.run, main))\n"
  },
  {
    "path": "rigl/rl/tfagents/ppo_train_eval.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Sparse training PPO using actor/learner in a gym environment.\n\"\"\"\n\nimport collections\nimport functools\nimport os\nfrom typing import Optional\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\n\nimport gin\nimport numpy as np\nimport reverb\nfrom rigl.rigl_tf2 import mask_updaters\nfrom rigl.rl import sparse_utils\nfrom rigl.rl.tfagents import sparse_ppo_actor_network\nfrom rigl.rl.tfagents import sparse_ppo_discrete_actor_network\nfrom rigl.rl.tfagents import sparse_value_network\nfrom rigl.rl.tfagents import tf_sparse_utils\nimport tensorflow.compat.v2 as tf\n\nfrom tf_agents.agents import tf_agent\nfrom tf_agents.agents.ppo import ppo_clip_agent\nfrom tf_agents.agents.ppo import ppo_utils\nfrom tf_agents.environments import suite_gym\nfrom tf_agents.environments import suite_mujoco\nfrom tf_agents.metrics import py_metrics\nfrom tf_agents.networks import network\nfrom tf_agents.policies import py_tf_eager_policy\nfrom tf_agents.replay_buffers import reverb_replay_buffer\nfrom tf_agents.replay_buffers import reverb_utils\nfrom tf_agents.specs import tensor_spec\nfrom tf_agents.system import system_multiprocessing as multiprocessing\nfrom tf_agents.train import actor\nfrom tf_agents.train import learner\nfrom tf_agents.train import ppo_learner\nfrom tf_agents.train import triggers\nfrom tf_agents.train.utils 
import spec_utils\nfrom tf_agents.train.utils import train_utils\nfrom tf_agents.trajectories import time_step as ts\nfrom tf_agents.typing import types\nfrom tf_agents.utils import common\nfrom tf_agents.utils import eager_utils\nfrom tf_agents.utils import nest_utils\nfrom tf_agents.utils import object_identity\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),\n                    'Root directory for writing logs/summaries/checkpoints.')\nflags.DEFINE_integer(\n    'reverb_port', None,\n    'Port for reverb server, if None, use a randomly chosen unused port.')\n\nflags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')\nflags.DEFINE_multi_string(\n    'gin_bindings', [],\n    'Gin bindings to override the values set in the config files '\n    '(e.g. \"train_eval.env_name=Acrobot-v1\",'\n    '      \"init_masks.sparsity=0.9\").')\n\n# Env params\nflags.DEFINE_bool('is_atari', False, 'Whether the env is an atari game.')\nflags.DEFINE_bool('is_mujoco', False, 'Whether the env is a mujoco game.')\nflags.DEFINE_bool('is_classic', False,\n                  'Whether the env is a classic control game.')\nflags.DEFINE_float(\n    'average_last_fraction', 0.1,\n    'Fraction of the latest evaluation scores that are averaged. 
This is used'\n    ' to reduce variance.')\n\nSparsePPOLossInfo = collections.namedtuple('SparsePPOLossInfo', (\n    'policy_gradient_loss',\n    'value_estimation_loss',\n    'l2_regularization_loss',\n    'entropy_regularization_loss',\n    'kl_penalty_loss',\n    'total_loss_per_sample',\n))\n\n\ndef _normalize_advantages(advantages, axes=(0,), variance_epsilon=1e-8):\n  adv_mean, adv_var = tf.nn.moments(advantages, axes=axes, keepdims=True)\n  normalized_advantages = tf.nn.batch_normalization(\n      advantages,\n      adv_mean,\n      adv_var,\n      offset=None,\n      scale=None,\n      variance_epsilon=variance_epsilon)\n  return normalized_advantages\n\n\n@gin.configurable\nclass SparsePPOAgent(ppo_clip_agent.PPOClipAgent):\n  \"\"\"Wrapped PPOClipAgent that supports sparse training.\"\"\"\n\n  def __init__(self,\n               *args,\n               policy_l2_reg=0.0,\n               value_function_l2_reg=0.0,\n               shared_vars_l2_reg=0.0,\n               **kwargs):\n    super().__init__(*args,\n                     policy_l2_reg=policy_l2_reg,\n                     value_function_l2_reg=value_function_l2_reg,\n                     shared_vars_l2_reg=shared_vars_l2_reg,\n                     **kwargs)\n    # Name scoping has been removed here, so\n    # debug_summaries are permanently disabled. 
To restore with proper\n    # scoping.\n    self._debug_summaries = False\n    # Pruning layer requires the pruning_step to be >1 during forward pass.\n    tf_sparse_utils.update_prune_step(\n        self._actor_net, self.train_step_counter + 1)\n    tf_sparse_utils.update_prune_step(\n        self._value_net, self.train_step_counter + 1)\n    _ = sparse_utils.init_masks(self._actor_net)\n    _ = sparse_utils.init_masks(self._value_net)\n\n    # BEGIN: sparse training create mask updaters\n    def loss_fn(experience_data, weights_data):\n      # The following is just to fit to the existing API.\n      (time_steps, actions, old_act_log_probs, returns, normalized_advantages,\n       old_action_distribution_parameters, masked_weights,\n       old_value_predictions) = self._process_experience_weights(\n           experience_data, weights_data)\n      loss_info = self.get_loss(\n          time_steps,\n          actions,\n          old_act_log_probs,\n          returns,\n          normalized_advantages,\n          old_action_distribution_parameters,\n          masked_weights,\n          self.train_step_counter,\n          False,\n          old_value_predictions=old_value_predictions,\n          training=True)\n      return loss_info.extra.total_loss_per_sample\n    self._mask_updater_actor = mask_updaters.get_mask_updater(\n        self._actor_net, self._optimizer, loss_fn)\n    self._mask_updater_value = mask_updaters.get_mask_updater(\n        self._value_net, self._optimizer, loss_fn)\n    # END: sparse training create mask updaters\n    logging.info('SparsePPOAgent: policy_l2_reg %.5f.', policy_l2_reg)\n    logging.info('SparsePPOAgent: value_function_l2_reg %.5f.',\n                 value_function_l2_reg)\n    logging.info('SparsePPOAgent: shared_vars_l2_reg %.5f.', shared_vars_l2_reg)\n\n  def _process_experience_weights(self, experience, weights):\n    experience = self._as_trajectory(experience)\n\n    if self._compute_value_and_advantage_in_train:\n      
processed_experience = self._preprocess(experience)\n    else:\n      processed_experience = experience\n\n    # Mask trajectories that cannot be used for training.\n    valid_mask = ppo_utils.make_trajectory_mask(processed_experience)\n    if weights is None:\n      masked_weights = valid_mask\n    else:\n      masked_weights = weights * valid_mask\n\n    # Reconstruct per-timestep policy distribution from stored distribution\n    #   parameters.\n    old_action_distribution_parameters = processed_experience.policy_info[\n        'dist_params']\n\n    old_actions_distribution = (\n        ppo_utils.distribution_from_spec(\n            self._action_distribution_spec,\n            old_action_distribution_parameters,\n            legacy_distribution_network=isinstance(\n                self._actor_net, network.DistributionNetwork)))\n\n    # Compute log probability of actions taken during data collection, using the\n    #   collect policy distribution.\n    old_act_log_probs = common.log_probability(old_actions_distribution,\n                                               processed_experience.action,\n                                               self._action_spec)\n\n    if self._debug_summaries and not tf.config.list_logical_devices('TPU'):\n      actions_list = tf.nest.flatten(processed_experience.action)\n      show_action_index = len(actions_list) != 1\n      for i, single_action in enumerate(actions_list):\n        action_name = ('actions_{}'.format(i)\n                       if show_action_index else 'actions')\n        tf.compat.v2.summary.histogram(\n            name=action_name, data=single_action, step=self.train_step_counter)\n\n    time_steps = ts.TimeStep(\n        step_type=processed_experience.step_type,\n        reward=processed_experience.reward,\n        discount=processed_experience.discount,\n        observation=processed_experience.observation)\n    actions = processed_experience.action\n    returns = 
processed_experience.policy_info['return']\n    advantages = processed_experience.policy_info['advantage']\n\n    normalized_advantages = _normalize_advantages(advantages,\n                                                  variance_epsilon=1e-8)\n\n    if self._debug_summaries and not tf.config.list_logical_devices('TPU'):\n      tf.compat.v2.summary.histogram(\n          name='advantages_normalized',\n          data=normalized_advantages,\n          step=self.train_step_counter)\n    old_value_predictions = processed_experience.policy_info['value_prediction']\n\n    return (time_steps, actions, old_act_log_probs, returns,\n            normalized_advantages, old_action_distribution_parameters,\n            masked_weights, old_value_predictions)\n\n  def _train(self, experience, weights):\n    tf.compat.v2.summary.experimental.set_step(self.train_step_counter)\n\n    (time_steps, actions, old_act_log_probs, returns, normalized_advantages,\n     old_action_distribution_parameters, masked_weights,\n     old_value_predictions) = self._process_experience_weights(\n         experience, weights)\n\n    if self._compute_value_and_advantage_in_train:\n      processed_experience = self._preprocess(experience)\n    else:\n      processed_experience = experience\n\n    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]\n    # Loss tensors across batches will be aggregated for summaries.\n    policy_gradient_losses = []\n    value_estimation_losses = []\n    l2_regularization_losses = []\n    entropy_regularization_losses = []\n    kl_penalty_losses = []\n\n    loss_info = None\n    variables_to_train = list(\n        object_identity.ObjectIdentitySet(self._actor_net.trainable_weights +\n                                          self._value_net.trainable_weights))\n    # Sort to ensure tensors on different processes end up in same order.\n    variables_to_train = sorted(variables_to_train, key=lambda x: x.name)\n\n    for _ in 
range(self._num_epochs):\n      # Name scoping has been removed here, so\n      # debug_summaries are permanently disabled. To restore, add proper\n      # scoping.\n      debug_summaries = False\n\n      with tf.GradientTape(persistent=True) as tape:\n        loss_info = self.get_loss(\n            time_steps,\n            actions,\n            old_act_log_probs,\n            returns,\n            normalized_advantages,\n            old_action_distribution_parameters,\n            masked_weights,\n            self.train_step_counter,\n            debug_summaries,\n            old_value_predictions=old_value_predictions,\n            training=True)\n\n      grads = tape.gradient(loss_info.loss, variables_to_train)\n\n      tf_sparse_utils.log_snr(tape, loss_info.extra.total_loss_per_sample,\n                              self.train_step_counter, variables_to_train)\n\n      # BEGIN sparse training mask update\n      # We use the latest set of gradients to update the masks for sparse\n      # training. 
Note, we do this before gradient clipping.\n      def _mask_update_step(mask_updater, updater_name):\n        mask_updater.set_validation_data(experience, weights)\n        mask_updater.update(self.train_step_counter)\n        with tf.name_scope('Drop_fraction/'):\n          tf.summary.scalar(\n              name=f'{updater_name}',\n              data=mask_updater.last_drop_fraction)\n\n      mask_update_step_actor = functools.partial(\n          _mask_update_step, self._mask_updater_actor, 'actor')\n      mask_update_step_value = functools.partial(\n          _mask_update_step, self._mask_updater_value, 'value')\n\n      tf_sparse_utils.log_sparsities(self._actor_net, 'actor')\n      tf_sparse_utils.log_sparsities(self._value_net, 'value')\n      tf_sparse_utils.log_total_params([self._actor_net, self._value_net])\n      if self._mask_updater_actor is not None:\n        is_update_actor = self._mask_updater_actor.is_update_iter(\n            self.train_step_counter)\n\n        tf.cond(is_update_actor, mask_update_step_actor, lambda: None)\n\n      if self._mask_updater_value is not None:\n        is_update_value = self._mask_updater_value.is_update_iter(\n            self.train_step_counter)\n\n        tf.cond(is_update_value, mask_update_step_value, lambda: None)\n      # END sparse training mask update\n\n      if self._gradient_clipping > 0:\n        grads, _ = tf.clip_by_global_norm(grads, self._gradient_clipping)\n\n      # Tuple is used for py3, where zip is a generator producing values once.\n      grads_and_vars = tuple(zip(grads, variables_to_train))\n\n      # If summarize_gradients, create functions for summarizing both\n      # gradients and variables.\n      if self._summarize_grads_and_vars and debug_summaries:\n        eager_utils.add_gradients_summaries(grads_and_vars,\n                                            self.train_step_counter)\n        eager_utils.add_variables_summaries(grads_and_vars,\n                                            
self.train_step_counter)\n\n      self._optimizer.apply_gradients(grads_and_vars)\n      self.train_step_counter.assign_add(1)\n\n      policy_gradient_losses.append(loss_info.extra.policy_gradient_loss)\n      value_estimation_losses.append(loss_info.extra.value_estimation_loss)\n      l2_regularization_losses.append(loss_info.extra.l2_regularization_loss)\n      entropy_regularization_losses.append(\n          loss_info.extra.entropy_regularization_loss)\n      kl_penalty_losses.append(loss_info.extra.kl_penalty_loss)\n\n    if self._initial_adaptive_kl_beta > 0:\n      # After update epochs, update adaptive kl beta, then update observation\n      #   normalizer and reward normalizer.\n      policy_state = self._collect_policy.get_initial_state(batch_size)\n      # Compute the mean kl from previous action distribution.\n      kl_divergence = self._kl_divergence(\n          time_steps, old_action_distribution_parameters,\n          self._collect_policy.distribution(time_steps, policy_state).action)\n      self.update_adaptive_kl_beta(kl_divergence)\n\n    if self.update_normalizers_in_train:\n      self.update_observation_normalizer(time_steps.observation)\n      self.update_reward_normalizer(processed_experience.reward)\n\n    loss_info = tf.nest.map_structure(tf.identity, loss_info)\n\n    # Make summaries for total loss averaged across all epochs.\n    # The *_losses lists will have been populated by\n    #   calls to self.get_loss. 
Assumes all the losses have same length.\n    with tf.name_scope('Losses/'):\n      num_epochs = len(policy_gradient_losses)\n      total_policy_gradient_loss = tf.add_n(policy_gradient_losses) / num_epochs\n      total_value_estimation_loss = tf.add_n(\n          value_estimation_losses) / num_epochs\n      total_l2_regularization_loss = tf.add_n(\n          l2_regularization_losses) / num_epochs\n      total_entropy_regularization_loss = tf.add_n(\n          entropy_regularization_losses) / num_epochs\n      total_kl_penalty_loss = tf.add_n(kl_penalty_losses) / num_epochs\n      tf.compat.v2.summary.scalar(\n          name='policy_gradient_loss',\n          data=total_policy_gradient_loss,\n          step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='value_estimation_loss',\n          data=total_value_estimation_loss,\n          step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='l2_regularization_loss',\n          data=total_l2_regularization_loss,\n          step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='entropy_regularization_loss',\n          data=total_entropy_regularization_loss,\n          step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='kl_penalty_loss',\n          data=total_kl_penalty_loss,\n          step=self.train_step_counter)\n\n      total_abs_loss = (\n          tf.abs(total_policy_gradient_loss) +\n          tf.abs(total_value_estimation_loss) +\n          tf.abs(total_entropy_regularization_loss) +\n          tf.abs(total_l2_regularization_loss) + tf.abs(total_kl_penalty_loss))\n\n      tf.compat.v2.summary.scalar(\n          name='total_abs_loss',\n          data=total_abs_loss,\n          step=self.train_step_counter)\n\n    with tf.name_scope('LearningRate/'):\n      learning_rate = ppo_utils.get_learning_rate(self._optimizer)\n      tf.compat.v2.summary.scalar(\n          name='learning_rate',\n          
data=learning_rate,\n          step=self.train_step_counter)\n\n    if self._summarize_grads_and_vars and not tf.config.list_logical_devices(\n        'TPU'):\n      with tf.name_scope('Variables/'):\n        all_vars = (\n            self._actor_net.trainable_weights +\n            self._value_net.trainable_weights)\n        for var in all_vars:\n          tf.compat.v2.summary.histogram(\n              name=var.name.replace(':', '_'),\n              data=var,\n              step=self.train_step_counter)\n\n    return loss_info\n\n  def get_loss(self,\n               time_steps,\n               actions,\n               act_log_probs,\n               returns,\n               normalized_advantages,\n               action_distribution_parameters,\n               weights,\n               train_step,\n               debug_summaries,\n               old_value_predictions = None,\n               training = False):\n    \"\"\"Compute the loss and create optimization op for one training epoch.\n\n    All tensors should have a single batch dimension.\n\n    Args:\n      time_steps: A minibatch of TimeStep tuples.\n      actions: A minibatch of actions.\n      act_log_probs: A minibatch of action probabilities (probability under the\n        sampling policy).\n      returns: A minibatch of per-timestep returns.\n      normalized_advantages: A minibatch of normalized per-timestep advantages.\n      action_distribution_parameters: Parameters of data-collecting action\n        distribution. Needed for KL computation.\n      weights: Optional scalar or element-wise (per-batch-entry) importance\n        weights.  
Includes a mask for invalid timesteps.\n      train_step: A train_step variable to increment for each train step.\n        Typically the global_step.\n      debug_summaries: True if debug summaries should be created.\n      old_value_predictions: (Optional) The saved value predictions, used for\n        calculating the value estimation loss when value clipping is performed.\n      training: Whether this loss is being used for training.\n\n    Returns:\n      A tf_agent.LossInfo named tuple with the total_loss and all intermediate\n        losses in the extra field contained in a PPOLossInfo named tuple.\n    \"\"\"\n    # Evaluate the current policy on timesteps.\n\n    # batch_size from time_steps\n    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]\n    policy_state = self._collect_policy.get_initial_state(batch_size)\n    # We must use _distribution because the distribution API doesn't pass down\n    # the training= kwarg.\n    distribution_step = self._collect_policy._distribution(\n        time_steps,\n        policy_state,\n        training=training)\n    current_policy_distribution = distribution_step.action\n\n    # Call all loss functions and add all loss values.\n    (value_estimation_loss,\n     value_estimation_loss_per_sample) = self.value_estimation_loss(\n         time_steps=time_steps,\n         returns=returns,\n         old_value_predictions=old_value_predictions,\n         weights=weights,\n         debug_summaries=debug_summaries,\n         training=training)\n    (policy_gradient_loss,\n     policy_gradient_loss_per_sample) = self.policy_gradient_loss(\n         time_steps,\n         actions,\n         tf.stop_gradient(act_log_probs),\n         tf.stop_gradient(normalized_advantages),\n         current_policy_distribution,\n         weights,\n         debug_summaries=debug_summaries)\n\n    if (self._policy_l2_reg > 0.0 or self._value_function_l2_reg > 0.0 or\n        self._shared_vars_l2_reg > 0.0):\n      
l2_regularization_loss = self.l2_regularization_loss(debug_summaries)\n    else:\n      l2_regularization_loss = tf.zeros_like(policy_gradient_loss)\n    l2_regularization_loss_per_sample = tf.repeat(\n        l2_regularization_loss / tf.cast(batch_size, tf.float32), batch_size)\n\n    if self._entropy_regularization > 0.0:\n      (entropy_regularization_loss, entropy_regularization_loss_per_sample\n      ) = self.entropy_regularization_loss(time_steps,\n                                           current_policy_distribution, weights,\n                                           debug_summaries)\n    else:\n      entropy_regularization_loss = tf.zeros_like(policy_gradient_loss)\n      entropy_regularization_loss_per_sample = tf.repeat(\n          tf.constant(0, dtype=tf.float32), batch_size)\n\n    if self._initial_adaptive_kl_beta == 0:\n      kl_penalty_loss = tf.zeros_like(policy_gradient_loss)\n    else:\n      kl_penalty_loss = self.kl_penalty_loss(time_steps,\n                                             action_distribution_parameters,\n                                             current_policy_distribution,\n                                             weights, debug_summaries)\n    kl_penalty_loss_per_sample = tf.repeat(\n        kl_penalty_loss / tf.cast(batch_size, tf.float32), batch_size)\n\n    total_loss = (\n        policy_gradient_loss + value_estimation_loss + l2_regularization_loss +\n        entropy_regularization_loss + kl_penalty_loss)\n    total_loss_per_sample = (\n        policy_gradient_loss_per_sample + value_estimation_loss_per_sample +\n        l2_regularization_loss_per_sample +\n        entropy_regularization_loss_per_sample + kl_penalty_loss_per_sample)\n\n    return tf_agent.LossInfo(\n        total_loss,\n        SparsePPOLossInfo(\n            policy_gradient_loss=policy_gradient_loss,\n            value_estimation_loss=value_estimation_loss,\n            l2_regularization_loss=l2_regularization_loss,\n            
entropy_regularization_loss=entropy_regularization_loss,\n            kl_penalty_loss=kl_penalty_loss,\n            total_loss_per_sample=total_loss_per_sample\n            ))\n\n  def value_estimation_loss(self,\n                            time_steps,\n                            returns,\n                            weights,\n                            old_value_predictions = None,\n                            debug_summaries = False,\n                            training = False):\n    \"\"\"Computes the value estimation loss for actor-critic training.\n\n    All tensors should have a single batch dimension.\n\n    Args:\n      time_steps: A batch of timesteps.\n      returns: Per-timestep returns for value function to predict. (Should come\n        from TD-lambda computation.)\n      weights: Optional scalar or element-wise (per-batch-entry) importance\n        weights.  Includes a mask for invalid timesteps.\n      old_value_predictions: (Optional) The saved value predictions from\n        policy_info, required when self._value_clipping > 0.\n      debug_summaries: True if debug summaries should be created.\n      training: Whether this loss is going to be used for training.\n\n    Returns:\n      A tuple of (value_estimation_loss, value_estimation_loss_per_sample),\n        where value_estimation_loss is a scalar and the second element holds\n        the per-sample losses.\n\n    Raises:\n      ValueError: If old_value_predictions was not passed in, but value clipping\n        was performed.\n    \"\"\"\n    observation = time_steps.observation\n    if debug_summaries and not tf.config.list_logical_devices('TPU'):\n      observation_list = tf.nest.flatten(observation)\n      show_observation_index = len(observation_list) != 1\n      for i, single_observation in enumerate(observation_list):\n        observation_name = ('observations_{}'.format(i)\n                            if show_observation_index else 'observations')\n        tf.compat.v2.summary.histogram(\n            name=observation_name,\n            data=single_observation,\n            
step=self.train_step_counter)\n\n    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]\n    value_state = self._collect_policy.get_initial_value_state(batch_size)\n\n    value_preds, _ = self._collect_policy.apply_value_network(\n        time_steps.observation,\n        time_steps.step_type,\n        value_state=value_state,\n        training=training)\n    value_estimation_error = tf.math.squared_difference(returns, value_preds)\n\n    if self._value_clipping > 0:\n      if old_value_predictions is None:\n        raise ValueError(\n            'old_value_predictions is None but needed for value clipping.')\n      clipped_value_preds = old_value_predictions + tf.clip_by_value(\n          value_preds - old_value_predictions, -self._value_clipping,\n          self._value_clipping)\n      clipped_value_estimation_error = tf.math.squared_difference(\n          returns, clipped_value_preds)\n      value_estimation_error = tf.maximum(value_estimation_error,\n                                          clipped_value_estimation_error)\n\n    if self._aggregate_losses_across_replicas:\n      value_estimation_loss = (\n          common.aggregate_losses(\n              per_example_loss=value_estimation_error,\n              sample_weight=weights).total_loss * self._value_pred_loss_coef)\n    else:\n      value_estimation_loss = tf.math.reduce_mean(\n          value_estimation_error * weights) * self._value_pred_loss_coef\n\n    value_estimation_loss_per_sample = tf.reduce_mean(value_estimation_error,\n                                                      axis=0)\n    if debug_summaries:\n      tf.compat.v2.summary.scalar(\n          name='value_pred_avg',\n          data=tf.reduce_mean(input_tensor=value_preds),\n          step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='value_actual_avg',\n          data=tf.reduce_mean(input_tensor=returns),\n          step=self.train_step_counter)\n      
tf.compat.v2.summary.scalar(\n          name='value_estimation_loss',\n          data=value_estimation_loss,\n          step=self.train_step_counter)\n      if not tf.config.list_logical_devices('TPU'):\n        tf.compat.v2.summary.histogram(\n            name='value_preds', data=value_preds, step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='value_estimation_error',\n            data=value_estimation_error,\n            step=self.train_step_counter)\n\n    if self._check_numerics:\n      value_estimation_loss = tf.debugging.check_numerics(\n          value_estimation_loss, 'value_estimation_loss')\n      value_estimation_loss_per_sample = tf.debugging.check_numerics(\n          value_estimation_loss_per_sample, 'value_estimation_loss_per_sample')\n\n    return value_estimation_loss, value_estimation_loss_per_sample\n\n  def policy_gradient_loss(\n      self,\n      time_steps,\n      actions,\n      sample_action_log_probs,\n      advantages,\n      current_policy_distribution,\n      weights,\n      debug_summaries = False):\n    \"\"\"Create tensor for policy gradient loss.\n\n    All tensors should have a single batch dimension.\n\n    Args:\n      time_steps: TimeSteps with observations for each timestep.\n      actions: Tensor of actions for timesteps, aligned on index.\n      sample_action_log_probs: Tensor of sample probability of each action.\n      advantages: Tensor of advantage estimate for each timestep, aligned on\n        index. Works better when advantage estimates are normalized.\n      current_policy_distribution: The policy distribution, evaluated on all\n        time_steps.\n      weights: Optional scalar or element-wise (per-batch-entry) importance\n        weights.  
Includes a mask for invalid timesteps.\n      debug_summaries: True if debug summaries should be created.\n\n    Returns:\n      policy_gradient_loss: A tensor that will contain policy gradient loss for\n        the on-policy experience.\n    \"\"\"\n    nest_utils.assert_same_structure(time_steps, self.time_step_spec)\n    action_log_prob = common.log_probability(current_policy_distribution,\n                                             actions, self._action_spec)\n    action_log_prob = tf.cast(action_log_prob, tf.float32)\n    if self._log_prob_clipping > 0.0:\n      action_log_prob = tf.clip_by_value(action_log_prob,\n                                         -self._log_prob_clipping,\n                                         self._log_prob_clipping)\n    if self._check_numerics:\n      action_log_prob = tf.debugging.check_numerics(action_log_prob,\n                                                    'action_log_prob')\n\n    # Prepare both clipped and unclipped importance ratios.\n    importance_ratio = tf.exp(action_log_prob - sample_action_log_probs)\n    importance_ratio_clipped = tf.clip_by_value(\n        importance_ratio, 1 - self._importance_ratio_clipping,\n        1 + self._importance_ratio_clipping)\n\n    if self._check_numerics:\n      importance_ratio = tf.debugging.check_numerics(importance_ratio,\n                                                     'importance_ratio')\n      if self._importance_ratio_clipping > 0.0:\n        importance_ratio_clipped = tf.debugging.check_numerics(\n            importance_ratio_clipped, 'importance_ratio_clipped')\n\n    # Pessimistically choose the minimum objective value for clipped and\n    #   unclipped importance ratios.\n    per_timestep_objective = importance_ratio * advantages\n    per_timestep_objective_clipped = importance_ratio_clipped * advantages\n    per_timestep_objective_min = tf.minimum(per_timestep_objective,\n                                            per_timestep_objective_clipped)\n\n    if 
self._importance_ratio_clipping > 0.0:\n      policy_gradient_loss = -per_timestep_objective_min\n    else:\n      policy_gradient_loss = -per_timestep_objective\n\n    policy_gradient_loss_per_sample = tf.reduce_mean(policy_gradient_loss,\n                                                     axis=0)\n\n    if self._aggregate_losses_across_replicas:\n      policy_gradient_loss = common.aggregate_losses(\n          per_example_loss=policy_gradient_loss,\n          sample_weight=weights).total_loss\n    else:\n      policy_gradient_loss = tf.math.reduce_mean(policy_gradient_loss * weights)\n\n    if debug_summaries:\n      if self._importance_ratio_clipping > 0.0:\n        clip_fraction = tf.reduce_mean(\n            input_tensor=tf.cast(\n                tf.greater(\n                    tf.abs(importance_ratio -\n                           1.0), self._importance_ratio_clipping), tf.float32))\n        tf.compat.v2.summary.scalar(\n            name='clip_fraction',\n            data=clip_fraction,\n            step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='importance_ratio_mean',\n          data=tf.reduce_mean(input_tensor=importance_ratio),\n          step=self.train_step_counter)\n      entropy = common.entropy(current_policy_distribution, self.action_spec)\n      tf.compat.v2.summary.scalar(\n          name='policy_entropy_mean',\n          data=tf.reduce_mean(input_tensor=entropy),\n          step=self.train_step_counter)\n      if not tf.config.list_logical_devices('TPU'):\n        tf.compat.v2.summary.histogram(\n            name='action_log_prob',\n            data=action_log_prob,\n            step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='action_log_prob_sample',\n            data=sample_action_log_probs,\n            step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='importance_ratio',\n            data=importance_ratio,\n            
step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='importance_ratio_clipped',\n            data=importance_ratio_clipped,\n            step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='per_timestep_objective',\n            data=per_timestep_objective,\n            step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='per_timestep_objective_clipped',\n            data=per_timestep_objective_clipped,\n            step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='per_timestep_objective_min',\n            data=per_timestep_objective_min,\n            step=self.train_step_counter)\n\n        tf.compat.v2.summary.histogram(\n            name='policy_entropy', data=entropy, step=self.train_step_counter)\n        for i, (single_action, single_distribution) in enumerate(\n            zip(\n                tf.nest.flatten(self.action_spec),\n                tf.nest.flatten(current_policy_distribution))):\n          # Categorical distribution (used for discrete actions) doesn't have a\n          # mean.\n          distribution_index = '_{}'.format(i) if i > 0 else ''\n          if not tensor_spec.is_discrete(single_action):\n            tf.compat.v2.summary.histogram(\n                name='actions_distribution_mean' + distribution_index,\n                data=single_distribution.mean(),\n                step=self.train_step_counter)\n            tf.compat.v2.summary.histogram(\n                name='actions_distribution_stddev' + distribution_index,\n                data=single_distribution.stddev(),\n                step=self.train_step_counter)\n        tf.compat.v2.summary.histogram(\n            name='policy_gradient_loss',\n            data=policy_gradient_loss,\n            step=self.train_step_counter)\n\n    if self._check_numerics:\n      policy_gradient_loss = tf.debugging.check_numerics(\n          
policy_gradient_loss, 'policy_gradient_loss')\n      policy_gradient_loss_per_sample = tf.debugging.check_numerics(\n          policy_gradient_loss_per_sample, 'policy_gradient_loss_per_sample')\n\n    return policy_gradient_loss, policy_gradient_loss_per_sample\n\n  def entropy_regularization_loss(\n      self,\n      time_steps,\n      current_policy_distribution,\n      weights,\n      debug_summaries = False):\n    \"\"\"Create regularization loss tensor based on agent parameters.\"\"\"\n    if self._entropy_regularization > 0:\n      nest_utils.assert_same_structure(time_steps, self.time_step_spec)\n      with tf.name_scope('entropy_regularization'):\n        entropy = tf.cast(\n            common.entropy(current_policy_distribution, self.action_spec),\n            tf.float32)\n\n        if self._aggregate_losses_across_replicas:\n          entropy_reg_loss = common.aggregate_losses(\n              per_example_loss=-entropy,\n              sample_weight=weights).total_loss * self._entropy_regularization\n        else:\n          entropy_reg_loss = (\n              tf.math.reduce_mean(-entropy * weights) *\n              self._entropy_regularization)\n\n        if self._check_numerics:\n          entropy_reg_loss = tf.debugging.check_numerics(\n              entropy_reg_loss, 'entropy_reg_loss')\n\n        if debug_summaries and not tf.config.list_logical_devices('TPU'):\n          tf.compat.v2.summary.histogram(\n              name='entropy_reg_loss',\n              data=entropy_reg_loss,\n              step=self.train_step_counter)\n    else:\n      raise ValueError('This is not allowed, this is handled at loss level.')\n\n    entropy_reg_loss_per_sample = -entropy\n    if self._check_numerics:\n      entropy_reg_loss_per_sample = tf.debugging.check_numerics(\n          entropy_reg_loss_per_sample, 'entropy_reg_loss_per_sample')\n\n    return entropy_reg_loss, entropy_reg_loss_per_sample\n\n\nclass 
ReverbFixedLengthSequenceObserver(reverb_utils.ReverbAddTrajectoryObserver\n                                       ):\n  \"\"\"Reverb fixed length sequence observer.\n\n  This is a specialized observer similar to ReverbAddTrajectoryObserver but each\n  sequence contains a fixed number of steps and can span multiple episodes. This\n  implementation is consistent with (Schulman, 17).\n\n  **Note**: Counting of steps in drivers does not include boundary steps. To\n  guarantee only 1 item is pushed to the replay when collecting n steps with a\n  `sequence_length` of n, make sure to set the `stride_length` to n as well.\n  \"\"\"\n\n  def __call__(self, trajectory):\n    \"\"\"Writes the trajectory into the underlying replay buffer.\n\n    Allows trajectory to be a flattened trajectory. No batch dimension allowed.\n\n    Args:\n      trajectory: The trajectory to be written, which could be a (possibly\n        nested) trajectory object or a flattened version of a trajectory. It\n        assumes there is *no* batch dimension.\n    \"\"\"\n    self._writer.append(trajectory)\n    self._cached_steps += 1\n\n    self._write_cached_steps()\n\n\n@gin.configurable\ndef train_eval(\n    root_dir,\n    env_name='HalfCheetah-v2',\n    # Training params\n    num_iterations=1600,\n    actor_fc_layers=(64, 64),\n    value_fc_layers=(64, 64),\n    learning_rate=3e-4,\n    collect_sequence_length=2048,\n    minibatch_size=64,\n    num_epochs=10,\n    # Agent params\n    importance_ratio_clipping=0.2,\n    lambda_value=0.95,\n    discount_factor=0.99,\n    entropy_regularization=0.,\n    value_pred_loss_coef=0.5,\n    use_gae=True,\n    use_td_lambda_return=True,\n    gradient_clipping=0.5,\n    value_clipping=None,\n    # Replay params\n    reverb_port=None,\n    replay_capacity=10000,\n    # Others\n    policy_save_interval=5000,\n    summary_interval=1000,\n    eval_interval=10000,\n    eval_episodes=100,\n    debug_summaries=False,\n    summarize_grads_and_vars=False,\n    
train_mode_actor='dense',\n    train_mode_value='dense',\n    sparse_output_layer=True,\n    weight_decay=0.0,\n    width=1.0):\n  \"\"\"Trains and evaluates PPO.\"\"\"\n\n  logging.info('Actor fc layer params: %s', actor_fc_layers)\n  logging.info('Value fc layer params: %s', value_fc_layers)\n  logging.info('Policy save interval: %s', policy_save_interval)\n  logging.info('Eval interval: %s', eval_interval)\n  logging.info('Environment name: %s', env_name)\n  logging.info('Learning rate: %s', learning_rate)\n  logging.info('Num iterations: %s', num_iterations)\n  logging.info('Sparse output layer: %s', sparse_output_layer)\n  logging.info('Train mode actor: %s', train_mode_actor)\n  logging.info('Train mode value: %s', train_mode_value)\n  logging.info('Width: %s', width)\n  logging.info('Weight decay: %s', weight_decay)\n\n  if FLAGS.is_mujoco:\n    collect_env = suite_mujoco.load(env_name)\n    eval_env = suite_mujoco.load(env_name)\n    logging.info('Loaded Mujoco environment %s', env_name)\n  elif FLAGS.is_classic:\n    collect_env = suite_gym.load(env_name)\n    eval_env = suite_gym.load(env_name)\n    logging.info('Loaded Classic control environment %s', env_name)\n  else:\n    raise ValueError('Environment init for Atari not supported yet.')\n\n  num_environments = 1\n\n  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (\n      spec_utils.get_tensor_specs(collect_env))\n  observation_tensor_spec = tf.TensorSpec(\n      dtype=tf.float32, shape=observation_tensor_spec.shape)\n\n  train_step = train_utils.create_train_step()\n\n  if FLAGS.is_classic:\n    actor_net_constructor = sparse_ppo_discrete_actor_network.PPODiscreteActorNetwork\n  else:\n    actor_net_constructor = sparse_ppo_actor_network.PPOActorNetwork\n\n  actor_net_builder = actor_net_constructor(\n            is_sparse=train_mode_actor == 'sparse',\n            sparse_output_layer=sparse_output_layer,\n            weight_decay=0,\n            width=width)\n  actor_net = 
actor_net_builder.create_sequential_actor_net(\n      actor_fc_layers, action_tensor_spec,\n      input_dim=time_step_tensor_spec.observation.shape[0])\n\n  value_net = sparse_value_network.ValueNetwork(\n      observation_tensor_spec,\n      fc_layer_params=value_fc_layers,\n      kernel_initializer=tf.keras.initializers.Orthogonal(),\n      is_sparse=train_mode_value == 'sparse',\n      sparse_output_layer=sparse_output_layer,\n      weight_decay=0,\n      width=width)\n  logging.info('Train eval: weight decay %.5f.', weight_decay)\n\n  current_iteration = tf.Variable(0, dtype=tf.int64)\n  def learning_rate_fn():\n    # Linearly decay the learning rate.\n    return learning_rate * (1 - current_iteration / num_iterations)\n\n  agent = SparsePPOAgent(\n      time_step_tensor_spec,\n      action_tensor_spec,\n      optimizer=tf.compat.v1.train.AdamOptimizer(\n          learning_rate=learning_rate_fn, epsilon=1e-5),\n      actor_net=actor_net,\n      value_net=value_net,\n      importance_ratio_clipping=importance_ratio_clipping,\n      lambda_value=lambda_value,\n      discount_factor=discount_factor,\n      entropy_regularization=entropy_regularization,\n      value_pred_loss_coef=value_pred_loss_coef,\n      policy_l2_reg=weight_decay,\n      value_function_l2_reg=weight_decay,\n      shared_vars_l2_reg=weight_decay,\n      # This is a legacy argument for the number of times we repeat the data\n      # inside of the train function, incompatible with mini batch learning.\n      # We set the epoch number from the replay buffer and tf.Data instead.\n      num_epochs=1,\n      use_gae=use_gae,\n      use_td_lambda_return=use_td_lambda_return,\n      gradient_clipping=gradient_clipping,\n      value_clipping=value_clipping,\n      compute_value_and_advantage_in_train=False,\n      # Skips updating normalizers in the agent, as it's handled in the learner.\n      update_normalizers_in_train=False,\n      debug_summaries=debug_summaries,\n      
summarize_grads_and_vars=summarize_grads_and_vars,\n      train_step_counter=train_step)\n  agent.initialize()\n\n  reverb_server = reverb.Server(\n      [\n          reverb.Table(  # Replay buffer storing experience for training.\n              name='training_table',\n              sampler=reverb.selectors.Fifo(),\n              remover=reverb.selectors.Fifo(),\n              rate_limiter=reverb.rate_limiters.MinSize(1),\n              max_size=replay_capacity,\n              max_times_sampled=1,\n          ),\n          reverb.Table(  # Replay buffer storing experience for normalization.\n              name='normalization_table',\n              sampler=reverb.selectors.Fifo(),\n              remover=reverb.selectors.Fifo(),\n              rate_limiter=reverb.rate_limiters.MinSize(1),\n              max_size=replay_capacity,\n              max_times_sampled=1,\n          )\n      ],\n      port=reverb_port)\n\n  # Create the replay buffer.\n  reverb_replay_train = reverb_replay_buffer.ReverbReplayBuffer(\n      agent.collect_data_spec,\n      sequence_length=collect_sequence_length,\n      table_name='training_table',\n      server_address='localhost:{}'.format(reverb_server.port),\n      # The only collected sequence is used to populate the batches.\n      max_cycle_length=1,\n      num_workers_per_iterator=1,\n      max_samples_per_stream=1,\n      rate_limiter_timeout_ms=1000)\n  reverb_replay_normalization = reverb_replay_buffer.ReverbReplayBuffer(\n      agent.collect_data_spec,\n      sequence_length=collect_sequence_length,\n      table_name='normalization_table',\n      server_address='localhost:{}'.format(reverb_server.port),\n      # The only collected sequence is used to populate the batches.\n      max_cycle_length=1,\n      num_workers_per_iterator=1,\n      max_samples_per_stream=1,\n      rate_limiter_timeout_ms=1000)\n\n  rb_observer = ReverbFixedLengthSequenceObserver(\n      reverb_replay_train.py_client, ['training_table', 
'normalization_table'],\n      sequence_length=collect_sequence_length,\n      stride_length=collect_sequence_length)\n\n  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)\n  collect_env_step_metric = py_metrics.EnvironmentSteps()\n  learning_triggers = [\n      triggers.PolicySavedModelTrigger(\n          saved_model_dir,\n          agent,\n          train_step,\n          interval=policy_save_interval,\n          metadata_metrics={\n              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric\n          }),\n      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),\n  ]\n\n  def training_dataset_fn():\n    return reverb_replay_train.as_dataset(\n        sample_batch_size=num_environments,\n        sequence_preprocess_fn=agent.preprocess_sequence)\n\n  def normalization_dataset_fn():\n    return reverb_replay_normalization.as_dataset(\n        sample_batch_size=num_environments,\n        sequence_preprocess_fn=agent.preprocess_sequence)\n\n  agent_learner = ppo_learner.PPOLearner(\n      root_dir,\n      train_step,\n      agent,\n      experience_dataset_fn=training_dataset_fn,\n      normalization_dataset_fn=normalization_dataset_fn,\n      num_samples=1,\n      summary_interval=10,\n      num_epochs=num_epochs,\n      minibatch_size=minibatch_size,\n      shuffle_buffer_size=collect_sequence_length,\n      triggers=learning_triggers)\n\n  tf_collect_policy = agent.collect_policy\n  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(\n      tf_collect_policy, use_tf_function=True)\n\n  collect_actor = actor.Actor(\n      collect_env,\n      collect_policy,\n      train_step,\n      steps_per_run=collect_sequence_length,\n      observers=[rb_observer, collect_env_step_metric],\n      metrics=actor.collect_metrics(buffer_size=10) + [collect_env_step_metric],\n      reference_metrics=[collect_env_step_metric],\n      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),\n      
summary_interval=summary_interval)\n\n  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(\n      agent.policy, use_tf_function=True)\n\n  average_returns = []\n  if eval_interval:\n    logging.info('Initial evaluation.')\n    eval_actor = actor.Actor(\n        eval_env,\n        eval_greedy_policy,\n        train_step,\n        metrics=actor.eval_metrics(eval_episodes),\n        reference_metrics=[collect_env_step_metric],\n        summary_dir=os.path.join(root_dir, 'eval'),\n        episodes_per_run=eval_episodes)\n\n    eval_actor.run_and_log()\n    for metric in eval_actor.metrics:\n      if isinstance(metric, py_metrics.AverageReturnMetric):\n        average_returns.append(metric._buffer.mean())\n\n  logging.info('Training on %s', env_name)\n  last_eval_step = 0\n  for i in range(num_iterations):\n    logging.info('collect_actor.run')\n    collect_actor.run()\n    # Reset the reverb observer to make sure the data collected is flushed and\n    # written to the RB.\n    # At this point, there are a small number of steps left in the cache because\n    # the actor does not count a boundary step as a step, whereas it still gets\n    # added to Reverb for training. 
We throw away those extra steps without\n    # padding to align with the paper implementation which never collects them\n    # in the first place.\n    logging.info('rb_observer.reset')\n    rb_observer.reset(write_cached_steps=False)\n    logging.info('reverb_replay_normalization.size: %d',\n                 reverb_replay_normalization.get_table_info().current_size)\n    logging.info('reverb_replay_train.size: %d',\n                 reverb_replay_train.get_table_info().current_size)\n    logging.info('agent_learner.run')\n    agent_learner.run()\n    logging.info('reverb_replay_train.clear')\n    reverb_replay_train.clear()\n    logging.info('reverb_replay_normalization.clear')\n    reverb_replay_normalization.clear()\n    current_iteration.assign_add(1)\n\n    # Eval only if `eval_interval` has been set. Then, eval if the current train\n    # step is equal or greater than the `last_eval_step` + `eval_interval` or if\n    # this is the last iteration. This logic exists because agent_learner.run()\n    # does not return after every train step.\n    if (eval_interval and\n        (agent_learner.train_step_numpy >= eval_interval + last_eval_step\n         or i == num_iterations - 1)):\n      logging.info('Evaluating.')\n      eval_actor.run_and_log()\n      last_eval_step = agent_learner.train_step_numpy\n      for metric in eval_actor.metrics:\n        if isinstance(metric, py_metrics.AverageReturnMetric):\n          average_returns.append(metric._buffer.mean())\n\n  # Log last section of evaluation scores for the final metric.\n  idx = int(FLAGS.average_last_fraction * len(average_returns))\n  avg_return = np.mean(average_returns[-idx:])\n  logging.info('Step %d, Average Return: %f', collect_env_step_metric.result(),\n               avg_return)\n\n  rb_observer.close()\n  reverb_server.stop()\n\n\ndef main(_):\n  tf.config.experimental_run_functions_eagerly(False)\n  logging.set_verbosity(logging.INFO)\n  tf.enable_v2_behavior()\n\n  
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)\n  logging.info('Gin bindings: %s', FLAGS.gin_bindings)\n\n  train_eval(\n      FLAGS.root_dir,\n      reverb_port=FLAGS.reverb_port)\n\n\nif __name__ == '__main__':\n  flags.mark_flag_as_required('root_dir')\n  multiprocessing.handle_main(functools.partial(app.run, main))\n"
  },
  {
    "path": "rigl/rl/tfagents/sac_train_eval.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Train and Eval SAC.\n\"\"\"\n\nimport functools\nimport os\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\n\nimport gin\nimport numpy as np\nimport reverb\nfrom rigl.rigl_tf2 import mask_updaters\nfrom rigl.rl import sparse_utils\nfrom rigl.rl.tfagents import sparse_tanh_normal_projection_network\nfrom rigl.rl.tfagents import tf_sparse_utils\nimport tensorflow as tf\nfrom tf_agents.agents import tf_agent\nfrom tf_agents.agents.sac import sac_agent\nfrom tf_agents.environments import suite_mujoco\nfrom tf_agents.keras_layers import inner_reshape\nfrom tf_agents.metrics import py_metrics\nfrom tf_agents.networks import nest_map\nfrom tf_agents.networks import sequential\nfrom tf_agents.policies import greedy_policy\nfrom tf_agents.policies import py_tf_eager_policy\nfrom tf_agents.policies import random_py_policy\nfrom tf_agents.replay_buffers import reverb_replay_buffer\nfrom tf_agents.replay_buffers import reverb_utils\nfrom tf_agents.train import actor\nfrom tf_agents.train import learner\nfrom tf_agents.train import triggers\nfrom tf_agents.train.utils import spec_utils\nfrom tf_agents.train.utils import strategy_utils\nfrom tf_agents.train.utils import train_utils\nfrom tf_agents.utils import common\nfrom tf_agents.utils import object_identity\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('root_dir', 
os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),\n                    'Root directory for writing logs/summaries/checkpoints.')\nflags.DEFINE_integer(\n    'reverb_port', None,\n    'Port for the reverb server; if None, use a randomly chosen unused port.')\nflags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.')\nflags.DEFINE_multi_string('gin_bindings', [], 'Gin binding parameters.')\n\n# Env params\nflags.DEFINE_bool('is_atari', False, 'Whether the env is an atari game.')\nflags.DEFINE_bool('is_mujoco', False, 'Whether the env is a mujoco game.')\nflags.DEFINE_bool('is_classic', False,\n                  'Whether the env is a classic control game.')\nflags.DEFINE_float(\n    'average_last_fraction', 0.1,\n    'The fraction of the latest evaluation scores that is averaged. This is'\n    ' used to reduce variance.')\n\ndense = functools.partial(\n    tf.keras.layers.Dense,\n    activation=tf.keras.activations.relu,\n    kernel_initializer='glorot_uniform')\n\n\ndef create_fc_layers(layer_units, width=1.0, weight_decay=0):\n  layers = [\n      dense(tf_sparse_utils.scale_width(num_units, width=width),\n            kernel_regularizer=tf.keras.regularizers.L2(weight_decay))\n      for num_units in layer_units\n  ]\n  return layers\n\n\ndef create_identity_layer():\n  return tf.keras.layers.Lambda(lambda x: x)\n\n\ndef create_sequential_critic_network(obs_fc_layer_units,\n                                     action_fc_layer_units,\n                                     joint_fc_layer_units,\n                                     input_dim,\n                                     is_sparse = False,\n                                     width = 1.0,\n                                     weight_decay = 0.0,\n                                     sparse_output_layer = True):\n  \"\"\"Create a sequential critic network.\"\"\"\n  # Split the inputs into observations and actions.\n  def split_inputs(inputs):\n    return {'observation': inputs[0], 'action': inputs[1]}\n\n  # 
Create the observation network layers.\n  obs_network_layers = (\n      create_fc_layers(obs_fc_layer_units, width=width,\n                       weight_decay=weight_decay)\n      if obs_fc_layer_units else None)\n\n  # Create the action network layers.\n  action_network_layers = (\n      create_fc_layers(action_fc_layer_units, width=width,\n                       weight_decay=weight_decay)\n      if action_fc_layer_units else None)\n\n  # Create the joint network layers.\n  joint_network_layers = (\n      create_fc_layers(joint_fc_layer_units, width=width,\n                       weight_decay=weight_decay)\n      if joint_fc_layer_units else None)\n\n  # Final layer.\n  value_layer = tf.keras.layers.Dense(\n      1, kernel_initializer='glorot_uniform',\n      kernel_regularizer=tf.keras.regularizers.L2(weight_decay))\n\n  layer_list = [obs_network_layers, action_network_layers,\n                joint_network_layers]\n  if is_sparse:\n    # We need to process all layers together to distribute sparsities for\n    # pruning.\n    all_layers = []\n    for layers in layer_list:\n      if layers is not None:\n        all_layers += layers\n    if sparse_output_layer:\n      all_layers.append(value_layer)\n      new_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim)\n      value_layer = new_layers[-1]\n      new_layers = new_layers[:-1]\n    else:\n      new_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim)\n    # Split the layers back into their own groups.\n    c_index = 0\n    new_layer_list = []\n    for layers in layer_list:\n      if layers is None:\n        new_layer_list.append(None)\n      else:\n        new_layer_list.append(new_layers[c_index:c_index + len(layers)])\n        c_index += len(layers)\n    layer_list = new_layer_list\n  # Convert layer_list to sequential or identity lambdas:\n  module_list = [create_identity_layer() if layers is None else\n                 sequential.Sequential(layers)\n                 for layers in layer_list]\n  
obs_network, action_network, joint_network = module_list\n\n  return sequential.Sequential([\n      tf.keras.layers.Lambda(split_inputs),\n      nest_map.NestMap({\n          'observation': obs_network,\n          'action': action_network\n      }),\n      nest_map.NestFlatten(),\n      tf.keras.layers.Concatenate(),\n      joint_network,\n      value_layer,\n      inner_reshape.InnerReshape(current_shape=[1], new_shape=[])\n  ], name='sequential_critic')\n\n\nclass _TanhNormalProjectionNetworkWrapper(\n    sparse_tanh_normal_projection_network.SparseTanhNormalProjectionNetwork):\n  \"\"\"Wrapper to pass predefined `outer_rank` to underlying projection net.\"\"\"\n\n  def __init__(self, sample_spec, predefined_outer_rank=1, weight_decay=0.0):\n    super(_TanhNormalProjectionNetworkWrapper, self).__init__(\n        sample_spec=sample_spec,\n        weight_decay=weight_decay)\n    self.predefined_outer_rank = predefined_outer_rank\n\n  def call(self, inputs, network_state=(), **kwargs):\n    kwargs['outer_rank'] = self.predefined_outer_rank\n    if 'step_type' in kwargs:\n      del kwargs['step_type']\n    return super(_TanhNormalProjectionNetworkWrapper,\n                 self).call(inputs, **kwargs)\n\n\ndef create_sequential_actor_network(actor_fc_layers,\n                                    action_tensor_spec,\n                                    input_dim,\n                                    is_sparse = False,\n                                    width = 1.0,\n                                    weight_decay = 0.0,\n                                    sparse_output_layer = True):\n  \"\"\"Create a sequential actor network.\"\"\"\n  def tile_as_nest(non_nested_output):\n    return tf.nest.map_structure(lambda _: non_nested_output,\n                                 action_tensor_spec)\n\n  dense_layers = [\n      dense(tf_sparse_utils.scale_width(num_units, width=width),\n            kernel_regularizer=tf.keras.regularizers.L2(weight_decay))\n      for num_units 
in actor_fc_layers\n  ]\n  tanh_normal_projection_network_fn = functools.partial(\n      _TanhNormalProjectionNetworkWrapper,\n      weight_decay=weight_decay)\n  last_layer = nest_map.NestMap(\n      tf.nest.map_structure(tanh_normal_projection_network_fn,\n                            action_tensor_spec))\n  if is_sparse:\n    if sparse_output_layer:\n      dense_layers.append(last_layer.layers[0]._projection_layer)\n      new_layers = tf_sparse_utils.wrap_all_layers(dense_layers, input_dim)\n      dense_layers = new_layers[:-1]\n      last_layer.layers[0]._projection_layer = new_layers[-1]\n    else:\n      dense_layers = tf_sparse_utils.wrap_all_layers(dense_layers, input_dim)\n\n  return sequential.Sequential(\n      dense_layers +\n      [tf.keras.layers.Lambda(tile_as_nest)] + [last_layer])\n\n\n@gin.configurable\nclass SparseSacAgent(sac_agent.SacAgent):\n  \"\"\"Wrapped SacAgent that supports sparse training.\"\"\"\n\n  def __init__(self,\n               time_step_spec,\n               action_spec,\n               *args,\n               actor_sparsity=None,\n               critic_sparsity=None,\n               **kwargs):\n    super().__init__(time_step_spec,\n                     action_spec,\n                     *args,\n                     **kwargs)\n    # Pruning layer requires the pruning_step to be >1 during forward pass.\n    tf_sparse_utils.update_prune_step(\n        self._critic_network_1, self.train_step_counter + 1)\n    tf_sparse_utils.update_prune_step(\n        self._critic_network_2, self.train_step_counter + 1)\n    tf_sparse_utils.update_prune_step(\n        self._actor_network, self.train_step_counter + 1)\n\n    if critic_sparsity is not None:\n      _ = sparse_utils.init_masks(self._critic_network_1,\n                                  sparsity=critic_sparsity)\n      _ = sparse_utils.init_masks(self._critic_network_2,\n                                  sparsity=critic_sparsity)\n    else:  # Uses init_mask.sparsity value. 
Either the default or set via gin.\n      _ = sparse_utils.init_masks(self._critic_network_1)\n      _ = sparse_utils.init_masks(self._critic_network_2)\n\n    if actor_sparsity is not None:\n      _ = sparse_utils.init_masks(self._actor_network,\n                                  sparsity=actor_sparsity)\n    else:\n      _ = sparse_utils.init_masks(self._actor_network)\n\n    net_observation_spec = time_step_spec.observation\n    critic_spec = (net_observation_spec, action_spec)\n    self._target_critic_network_1 = (\n        common.maybe_copy_target_network_with_checks(\n            self._critic_network_1,\n            None,\n            input_spec=critic_spec,\n            name='TargetCriticNetwork1'))\n    self._target_critic_network_2 = (\n        common.maybe_copy_target_network_with_checks(\n            self._critic_network_2,\n            None,\n            input_spec=critic_spec,\n            name='TargetCriticNetwork2'))\n\n    def critic_loss_fn(experience, weights):\n      # The following is just to fit the existing API.\n      transition = self._as_transition(experience)\n      time_steps, policy_steps, next_time_steps = transition\n      actions = policy_steps.action\n      return self._critic_loss_weight * self.critic_loss(\n          time_steps,\n          actions,\n          next_time_steps,\n          td_errors_loss_fn=self._td_errors_loss_fn,\n          gamma=self._gamma,\n          reward_scale_factor=self._reward_scale_factor,\n          weights=weights,\n          training=True)\n\n    def actor_loss_fn(experience, weights):\n      # The following is just to fit the existing API.\n      transition = self._as_transition(experience)\n      time_steps, _, _ = transition\n      return self._actor_loss_weight*self.actor_loss(\n          time_steps, weights=weights, training=True)\n\n    # Create the mask updaters for sparse training (each may be None).\n    self._mask_updater_critic_1 = mask_updaters.get_mask_updater(\n        self._critic_network_1, 
self._critic_optimizer, critic_loss_fn)\n    self._mask_updater_critic_2 = mask_updaters.get_mask_updater(\n        self._critic_network_2, self._critic_optimizer, critic_loss_fn)\n    self._mask_updater_actor = mask_updaters.get_mask_updater(\n        self._actor_network, self._actor_optimizer, actor_loss_fn)\n\n  def _train(self, experience, weights):\n    \"\"\"Returns a train op to update the agent's networks.\n\n    This method trains with the provided batched experience.\n\n    Args:\n      experience: A time-stacked trajectory object.\n      weights: Optional scalar or elementwise (per-batch-entry) importance\n        weights.\n\n    Returns:\n      A train_op.\n\n    Raises:\n      ValueError: If optimizers are None and no default value was provided to\n        the constructor.\n    \"\"\"\n    tf.summary.experimental.set_step(self.train_step_counter)\n    transition = self._as_transition(experience)\n    time_steps, policy_steps, next_time_steps = transition\n    actions = policy_steps.action\n\n    trainable_critic_variables = list(object_identity.ObjectIdentitySet(\n        self._critic_network_1.trainable_variables +\n        self._critic_network_2.trainable_variables))\n\n    with tf.GradientTape(watch_accessed_variables=False) as tape:\n      assert trainable_critic_variables, ('No trainable critic variables to '\n                                          'optimize.')\n      tape.watch(trainable_critic_variables)\n      critic_loss = self._critic_loss_weight*self.critic_loss(\n          time_steps,\n          actions,\n          next_time_steps,\n          td_errors_loss_fn=self._td_errors_loss_fn,\n          gamma=self._gamma,\n          reward_scale_factor=self._reward_scale_factor,\n          weights=weights,\n          training=True)\n\n    tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')\n    critic_grads = tape.gradient(critic_loss, trainable_critic_variables)\n    self._apply_gradients(critic_grads, 
trainable_critic_variables,\n                          self._critic_optimizer)\n\n    trainable_actor_variables = self._actor_network.trainable_variables\n    with tf.GradientTape(watch_accessed_variables=False) as tape:\n      assert trainable_actor_variables, ('No trainable actor variables to '\n                                         'optimize.')\n      tape.watch(trainable_actor_variables)\n      actor_loss = self._actor_loss_weight*self.actor_loss(\n          time_steps, weights=weights, training=True)\n    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')\n    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)\n    self._apply_gradients(actor_grads, trainable_actor_variables,\n                          self._actor_optimizer)\n\n    # BEGIN sparse training mask update\n    # We use the latest set of gradients to update the masks for sparse\n    # training. Note that we do this before gradient clipping.\n\n    # Define helper methods.\n    def _mask_update_step(mask_updater, updater_name):\n      mask_updater.set_validation_data(experience, weights)\n      mask_updater.update(self.train_step_counter)\n      with tf.name_scope('Drop_fraction/'):\n        tf.summary.scalar(\n            name=f'{updater_name}',\n            data=mask_updater.last_drop_fraction)\n\n    mask_update_step_critic_1 = functools.partial(_mask_update_step,\n                                                  self._mask_updater_critic_1,\n                                                  'critic_1')\n    mask_update_step_critic_2 = functools.partial(_mask_update_step,\n                                                  self._mask_updater_critic_2,\n                                                  'critic_2')\n    mask_update_step_actor = functools.partial(_mask_update_step,\n                                               self._mask_updater_actor,\n                                               'actor')\n\n    # Log sparsities every 1000 train steps.\n    def 
_log_sparsities():\n      tf_sparse_utils.log_sparsities(self._critic_network_1, 'critic_1')\n      tf_sparse_utils.log_sparsities(self._critic_network_2, 'critic_2')\n      tf_sparse_utils.log_sparsities(self._actor_network, 'actor')\n      tf_sparse_utils.log_total_params(\n          [self._critic_network_1,\n           self._critic_network_2,\n           self._actor_network])\n    tf.cond(self.train_step_counter % 1000 == 0, _log_sparsities, lambda: None)\n\n    # Update critics\n    if self._mask_updater_critic_1 is not None:\n      is_update_critic_1 = self._mask_updater_critic_1.is_update_iter(\n          self.train_step_counter)\n      tf.cond(is_update_critic_1, mask_update_step_critic_1, lambda: None)\n\n    if self._mask_updater_critic_2 is not None:\n      is_update_critic_2 = self._mask_updater_critic_2.is_update_iter(\n          self.train_step_counter)\n      tf.cond(is_update_critic_2, mask_update_step_critic_2, lambda: None)\n\n    # Update actor\n    if self._mask_updater_actor is not None:\n      is_update_actor = self._mask_updater_actor.is_update_iter(\n          self.train_step_counter)\n      tf.cond(is_update_actor, mask_update_step_actor, lambda: None)\n    # END sparse training mask update\n\n    alpha_variable = [self._log_alpha]\n    with tf.GradientTape(watch_accessed_variables=False) as tape:\n      assert alpha_variable, 'No alpha variable to optimize.'\n      tape.watch(alpha_variable)\n      alpha_loss = self._alpha_loss_weight * self.alpha_loss(\n          time_steps, weights=weights, training=True)\n    tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')\n    alpha_grads = tape.gradient(alpha_loss, alpha_variable)\n    self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)\n\n    with tf.name_scope('Losses'):\n      tf.compat.v2.summary.scalar(\n          name='critic_loss', data=critic_loss, step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='actor_loss', 
data=actor_loss, step=self.train_step_counter)\n      tf.compat.v2.summary.scalar(\n          name='alpha_loss', data=alpha_loss, step=self.train_step_counter)\n\n    self.train_step_counter.assign_add(1)\n    self._update_target()\n\n    total_loss = critic_loss + actor_loss + alpha_loss\n\n    extra = sac_agent.SacLossInfo(\n        critic_loss=critic_loss, actor_loss=actor_loss, alpha_loss=alpha_loss)\n\n    return tf_agent.LossInfo(loss=total_loss, extra=extra)\n\n\n@gin.configurable\ndef train_eval(\n    root_dir,\n    strategy,\n    env_name='HalfCheetah-v2',\n    # Training params\n    initial_collect_steps=10000,\n    num_iterations=1000000,\n    actor_fc_layers=(256, 256),\n    critic_obs_fc_layers=None,\n    critic_action_fc_layers=None,\n    critic_joint_fc_layers=(256, 256),\n    # Agent params\n    batch_size=256,\n    actor_learning_rate=3e-4,\n    critic_learning_rate=3e-4,\n    alpha_learning_rate=3e-4,\n    gamma=0.99,\n    target_update_tau=0.005,\n    target_update_period=1,\n    reward_scale_factor=0.1,\n    # Replay params\n    reverb_port=None,\n    replay_capacity=1000000,\n    # Others\n    policy_save_interval=10000,\n    replay_buffer_save_interval=100000,\n    eval_interval=10000,\n    eval_episodes=30,\n    debug_summaries=False,\n    summarize_grads_and_vars=False,\n    sparse_output_layer = False,\n    width = 1.0,\n    train_mode_actor = 'dense',\n    train_mode_value = 'dense',\n    weight_decay = 0.0,\n    actor_critic_sparsities_str = '',\n    actor_critic_widths_str = ''):\n  \"\"\"Trains and evaluates SAC.\"\"\"\n  assert FLAGS.is_mujoco\n\n  if actor_critic_widths_str:\n    actor_critic_widths = [float(s) for s in actor_critic_widths_str.split('_')]\n    width_actor = actor_critic_widths[0]\n    width_value = actor_critic_widths[1]\n  else:\n    width_actor = width\n    width_value = width\n\n  if actor_critic_sparsities_str:\n    actor_critic_sparsities = [\n        float(s) for s in actor_critic_sparsities_str.split('_')\n    
]\n  else:\n    # init_mask.sparsity value will be used. Either the default or set via gin.\n    actor_critic_sparsities = [None, None]\n\n  logging.info('Training SAC on: %s', env_name)\n  logging.info('SAC params: train mode actor: %s', train_mode_actor)\n  logging.info('SAC params: train mode value: %s', train_mode_value)\n  logging.info('SAC params: sparse_output_layer: %s', sparse_output_layer)\n  logging.info('SAC params: width: %s', width)\n  logging.info('SAC params: actor_critic_widths_str: %s',\n               actor_critic_widths_str)\n  logging.info('SAC params: width_actor: %s', width_actor)\n  logging.info('SAC params: width_value: %s', width_value)\n  logging.info('SAC params: weight_decay: %s', weight_decay)\n  logging.info('SAC params: actor_critic_sparsities_str %s type %s',\n               actor_critic_sparsities_str, type(actor_critic_sparsities_str))\n  logging.info('SAC params: actor_sparsity: %s', actor_critic_sparsities[0])\n  logging.info('SAC params: critic_sparsity: %s', actor_critic_sparsities[1])\n\n  collect_env = suite_mujoco.load(env_name)\n  eval_env = suite_mujoco.load(env_name)\n\n  _, action_tensor_spec, time_step_tensor_spec = (\n      spec_utils.get_tensor_specs(collect_env))\n\n  actor_net = create_sequential_actor_network(\n      actor_fc_layers=actor_fc_layers, action_tensor_spec=action_tensor_spec,\n      input_dim=time_step_tensor_spec.observation.shape[0],\n      is_sparse=(train_mode_actor == 'sparse'),\n      width=width_actor,\n      weight_decay=weight_decay,\n      sparse_output_layer=sparse_output_layer)\n\n  critic_input_dim = (\n      action_tensor_spec.shape[0] + time_step_tensor_spec.observation.shape[0])\n  critic_net = create_sequential_critic_network(\n      obs_fc_layer_units=critic_obs_fc_layers,\n      action_fc_layer_units=critic_action_fc_layers,\n      joint_fc_layer_units=critic_joint_fc_layers,\n      input_dim=critic_input_dim,\n      is_sparse=(train_mode_value == 'sparse'),\n      
width=width_value,\n      weight_decay=weight_decay,\n      sparse_output_layer=sparse_output_layer)\n\n  with strategy.scope():\n    train_step = train_utils.create_train_step()\n    agent = SparseSacAgent(\n        time_step_spec=time_step_tensor_spec,\n        action_spec=action_tensor_spec,\n        actor_sparsity=actor_critic_sparsities[0],\n        critic_sparsity=actor_critic_sparsities[1],\n        actor_network=actor_net,\n        critic_network=critic_net,\n        actor_optimizer=tf.keras.optimizers.Adam(\n            learning_rate=actor_learning_rate),\n        critic_optimizer=tf.keras.optimizers.Adam(\n            learning_rate=critic_learning_rate),\n        alpha_optimizer=tf.keras.optimizers.Adam(\n            learning_rate=alpha_learning_rate),\n        target_update_tau=target_update_tau,\n        target_update_period=target_update_period,\n        td_errors_loss_fn=tf.math.squared_difference,\n        gamma=gamma,\n        reward_scale_factor=reward_scale_factor,\n        gradient_clipping=None,\n        debug_summaries=debug_summaries,\n        summarize_grads_and_vars=summarize_grads_and_vars,\n        train_step_counter=train_step)\n    agent.initialize()\n  table_name = 'uniform_table'\n  table = reverb.Table(\n      table_name,\n      max_size=replay_capacity,\n      sampler=reverb.selectors.Uniform(),\n      remover=reverb.selectors.Fifo(),\n      rate_limiter=reverb.rate_limiters.MinSize(1))\n\n  reverb_checkpoint_dir = os.path.join(root_dir, learner.TRAIN_DIR,\n                                       learner.REPLAY_BUFFER_CHECKPOINT_DIR)\n  reverb_checkpointer = reverb.platform.checkpointers_lib.DefaultCheckpointer(\n      path=reverb_checkpoint_dir)\n  reverb_server = reverb.Server([table],\n                                port=reverb_port,\n                                checkpointer=reverb_checkpointer)\n  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(\n      agent.collect_data_spec,\n      sequence_length=2,\n      
table_name=table_name,\n      local_server=reverb_server)\n  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(\n      reverb_replay.py_client,\n      table_name,\n      sequence_length=2,\n      stride_length=1)\n\n  def experience_dataset_fn():\n    return reverb_replay.as_dataset(\n        sample_batch_size=batch_size, num_steps=2).prefetch(50)\n\n  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)\n  env_step_metric = py_metrics.EnvironmentSteps()\n  learning_triggers = [\n      triggers.PolicySavedModelTrigger(\n          saved_model_dir,\n          agent,\n          train_step,\n          interval=policy_save_interval,\n          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),\n      triggers.ReverbCheckpointTrigger(\n          train_step,\n          interval=replay_buffer_save_interval,\n          reverb_client=reverb_replay.py_client),\n      triggers.StepPerSecondLogTrigger(train_step, interval=1000),\n  ]\n\n  agent_learner = learner.Learner(\n      root_dir,\n      train_step,\n      agent,\n      experience_dataset_fn,\n      triggers=learning_triggers,\n      strategy=strategy)\n\n  random_policy = random_py_policy.RandomPyPolicy(\n      collect_env.time_step_spec(), collect_env.action_spec())\n  initial_collect_actor = actor.Actor(\n      collect_env,\n      random_policy,\n      train_step,\n      steps_per_run=initial_collect_steps,\n      observers=[rb_observer])\n  logging.info('Doing initial collect.')\n  initial_collect_actor.run()\n\n  tf_collect_policy = agent.collect_policy\n  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(\n      tf_collect_policy, use_tf_function=True)\n\n  collect_actor = actor.Actor(\n      collect_env,\n      collect_policy,\n      train_step,\n      steps_per_run=1,\n      metrics=actor.collect_metrics(10),\n      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),\n      observers=[rb_observer, env_step_metric])\n\n  tf_greedy_policy = 
greedy_policy.GreedyPolicy(agent.policy)\n  eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(\n      tf_greedy_policy, use_tf_function=True)\n\n  eval_actor = actor.Actor(\n      eval_env,\n      eval_greedy_policy,\n      train_step,\n      episodes_per_run=eval_episodes,\n      metrics=actor.eval_metrics(eval_episodes),\n      summary_dir=os.path.join(root_dir, 'eval'),\n  )\n\n  average_returns = []\n  if eval_interval:\n    logging.info('Evaluating.')\n    eval_actor.run_and_log()\n    for metric in eval_actor.metrics:\n      if isinstance(metric, py_metrics.AverageReturnMetric):\n        average_returns.append(metric._buffer.mean())\n\n  logging.info('Training.')\n  for _ in range(num_iterations):\n    collect_actor.run()\n    agent_learner.run(iterations=1)\n\n    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:\n      logging.info('Evaluating.')\n      eval_actor.run_and_log()\n      for metric in eval_actor.metrics:\n        if isinstance(metric, py_metrics.AverageReturnMetric):\n          average_returns.append(metric._buffer.mean())\n\n  # Average the last fraction of evaluation scores for the final metric; use\n  # at least one score so the slice is never empty.\n  idx = max(1, int(FLAGS.average_last_fraction * len(average_returns)))\n  avg_return = np.mean(average_returns[-idx:])\n  logging.info('Step %d, Average Return: %f', env_step_metric.result(),\n               avg_return)\n\n  rb_observer.close()\n  reverb_server.stop()\n\n\ndef main(_):\n  tf.config.run_functions_eagerly(False)\n  logging.set_verbosity(logging.INFO)\n  tf.compat.v1.enable_v2_behavior()\n  strategy = strategy_utils.get_strategy(FLAGS.tpu, FLAGS.use_gpu)\n\n  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings)\n  logging.info('Gin bindings: %s', FLAGS.gin_bindings)\n  logging.info('# Gin-Config:\\n %s', gin.config.operative_config_str())\n\n  train_eval(\n      FLAGS.root_dir,\n      strategy=strategy,\n      reverb_port=FLAGS.reverb_port)\n\n\nif __name__ == '__main__':\n  
flags.mark_flag_as_required('root_dir')\n  app.run(main)\n"
  },
  {
    "path": "rigl/rl/tfagents/sparse_encoding_network.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Keras Encoding Network.\n\nImplements a network that will generate the following layers:\n\n  [optional]: preprocessing_layers  # preprocessing_layers\n  [optional]: (Add | Concat(axis=-1) | ...)  # preprocessing_combiner\n  [optional]: Conv2D # conv_layer_params\n  Flatten\n  [optional]: Dense  # fc_layer_params\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl import logging\nimport gin\nfrom rigl.rl.tfagents import tf_sparse_utils\nfrom six.moves import zip\nimport tensorflow as tf\n\nfrom tf_agents.keras_layers import permanent_variable_rate_dropout\nfrom tf_agents.networks import network\nfrom tf_agents.networks import utils\nfrom tf_agents.utils import nest_utils\n\nCONV_TYPE_2D = '2d'\nCONV_TYPE_1D = '1d'\n\n\ndef _copy_layer(layer):\n  \"\"\"Create a copy of a Keras layer with identical parameters.\n\n  The new layer will not share weights with the old one.\n\n  Args:\n    layer: An instance of `tf.keras.layers.Layer`.\n\n  Returns:\n    A new keras layer.\n\n  Raises:\n    TypeError: If `layer` is not a keras layer.\n    ValueError: If `layer` cannot be correctly cloned.\n  \"\"\"\n  if not isinstance(layer, tf.keras.layers.Layer):\n    raise TypeError('layer is not a keras layer: %s' % str(layer))\n\n  # pylint:disable=unidiomatic-typecheck\n  if 
type(layer) == tf.compat.v1.keras.layers.DenseFeatures:\n    raise ValueError('DenseFeatures V1 is not supported. '\n                     'Use tf.compat.v2.keras.layers.DenseFeatures instead.')\n  if layer.built:\n    logging.warning(\n        'Beware: Copying a layer that has already been built: \\'%s\\'.  '\n        'This can lead to subtle bugs because the original layer\\'s weights '\n        'will not be used in the copy.', layer.name)\n  # Get a fresh copy so we don't modify an incoming layer in place.  Weights\n  # will not be shared.\n  return type(layer).from_config(layer.get_config())\n\n\n@gin.configurable\nclass EncodingNetwork(network.Network):\n  \"\"\"Feed Forward network with CNN and FNN layers.\"\"\"\n\n  def __init__(self,\n               input_tensor_spec,\n               preprocessing_layers=None,\n               preprocessing_combiner=None,\n               conv_layer_params=None,\n               fc_layer_params=None,\n               dropout_layer_params=None,\n               activation_fn=tf.keras.activations.relu,\n               weight_decay_params=None,\n               kernel_initializer=None,\n               batch_squash=True,\n               dtype=tf.float32,\n               name='EncodingNetwork',\n               conv_type=CONV_TYPE_2D,\n               width=1.0):\n    \"\"\"Creates an instance of `EncodingNetwork`.\n\n    Network supports calls with shape outer_rank + input_tensor_spec.shape. Note\n    outer_rank must be at least 1.\n\n    For example an input tensor spec with shape `(2, 3)` will require\n    inputs with at least a batch size, the input shape is `(?, 2, 3)`.\n\n    Input preprocessing is possible via `preprocessing_layers` and\n    `preprocessing_combiner` Layers.  
If the `preprocessing_layers` nest is\n    shallower than `input_tensor_spec`, then the layers will get the subnests.\n    For example, if:\n\n    ```python\n    input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5)\n    preprocessing_layers = (Layer1(), Layer2())\n    ```\n\n    then preprocessing will call:\n\n    ```python\n    preprocessed = [preprocessing_layers[0](observations[0]),\n                    preprocessing_layers[1](observations[1])]\n    ```\n\n    However, if\n\n    ```python\n    preprocessing_layers = ([Layer1() for _ in range(2)],\n                            [Layer2() for _ in range(5)])\n    ```\n\n    then preprocessing will call:\n\n    ```python\n    preprocessed = [\n      layer(obs) for layer, obs in zip(flatten(preprocessing_layers),\n                                       flatten(observations))\n    ]\n    ```\n\n    **NOTE** `preprocessing_layers` and `preprocessing_combiner` are not allowed\n    to have already been built.  This ensures calls to `network.copy()` in the\n    future always have an unbuilt, fresh set of parameters.  Furthermore,\n    a shallow copy of the layers is always created by the Network, so the\n    layer objects passed to the network are never modified.  For more details\n    of the semantics of `copy`, see the docstring of\n    `tf_agents.networks.Network.copy`.\n\n    Args:\n      input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the\n        input observations.\n      preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer`\n        representing preprocessing for the different observations. All of these\n        layers must not be already built.\n      preprocessing_combiner: (Optional.) A keras layer that takes a flat list\n        of tensors and combines them.  Good options include\n        `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`. 
This\n        layer must not be already built.\n      conv_layer_params: Optional list of convolution layer parameters, where\n        each item is either a length-three tuple indicating\n        `(filters, kernel_size, stride)` or a length-four tuple indicating\n        `(filters, kernel_size, stride, dilation_rate)`.\n      fc_layer_params: Optional list of fully_connected parameters, where each\n        item is the number of units in the layer.\n      dropout_layer_params: Optional list of dropout layer parameters; each item\n        is the fraction of input units to drop or a dictionary of parameters\n        according to the keras.Dropout documentation. The additional parameter\n        `permanent`, if set to True, allows applying dropout at inference for\n        approximate Bayesian inference. The dropout layers are interleaved with\n        the fully connected layers; there is a dropout layer after each fully\n        connected layer, except if the entry in the list is None. This list must\n        have the same length as fc_layer_params, or be None.\n      activation_fn: Activation function, e.g. tf.keras.activations.relu.\n      weight_decay_params: Optional list of weight decay parameters for the\n        fully connected layers.\n      kernel_initializer: Initializer to use for the kernels of the conv and\n        dense layers. If none is provided, a default\n        variance_scaling_initializer is used.\n      batch_squash: If True, the outer_ranks of the observation are squashed\n        into the batch dimension. This allows encoding networks to be used with\n        observations with shape [BxTx...].\n      dtype: The dtype to use by the convolution and fully connected layers.\n      name: A string representing the name of the network.\n      conv_type: string, '1d' or '2d'. 
Convolution layers will be 1D or 2D,\n        respectively.\n      width: Scaling factor to apply to the layers.\n\n    Raises:\n      ValueError: If any of `preprocessing_layers` is already built.\n      ValueError: If `preprocessing_combiner` is already built.\n      ValueError: If the number of dropout layer parameters does not match the\n        number of fully connected layer parameters.\n      ValueError: If conv_layer_params tuples do not have 3 or 4 elements each.\n    \"\"\"\n    self._width = width\n    flat_preprocessing_layers = None\n\n    if (len(tf.nest.flatten(input_tensor_spec)) > 1 and\n        preprocessing_combiner is None):\n      raise ValueError(\n          'preprocessing_combiner layer is required when more than 1 '\n          'input_tensor_spec is provided.')\n\n    if preprocessing_combiner is not None:\n      preprocessing_combiner = _copy_layer(preprocessing_combiner)\n\n    if not kernel_initializer:\n      kernel_initializer = tf.compat.v1.variance_scaling_initializer(\n          scale=2.0, mode='fan_in', distribution='truncated_normal')\n\n    layers = []\n\n    if conv_layer_params:\n      if conv_type == '2d':\n        conv_layer_type = tf.keras.layers.Conv2D\n      elif conv_type == '1d':\n        conv_layer_type = tf.keras.layers.Conv1D\n      else:\n        raise ValueError('unsupported conv type of %s. 
Use 1d or 2d' % (\n            conv_type))\n\n      for config in conv_layer_params:\n        if len(config) == 4:\n          (filters, kernel_size, strides, dilation_rate) = config\n        elif len(config) == 3:\n          (filters, kernel_size, strides) = config\n          dilation_rate = (1, 1) if conv_type == '2d' else (1,)\n        else:\n          raise ValueError(\n              'only 3 or 4 elements permitted in conv_layer_params tuples')\n\n        kernel_regularizer = None\n        # We use the first weight decay param for all conv layers.\n        weight_decay = weight_decay_params[0]\n        if weight_decay is not None:\n          kernel_regularizer = tf.keras.regularizers.l2(weight_decay)\n\n        filters = tf_sparse_utils.scale_width(filters, self._width)\n        layers.append(\n            conv_layer_type(\n                filters=filters,\n                kernel_size=kernel_size,\n                strides=strides,\n                dilation_rate=dilation_rate,\n                activation=activation_fn,\n                kernel_initializer=kernel_initializer,\n                kernel_regularizer=kernel_regularizer,\n                dtype=dtype))\n\n    layers.append(tf.keras.layers.Flatten())\n\n    if fc_layer_params:\n      if dropout_layer_params is None:\n        dropout_layer_params = [None] * len(fc_layer_params)\n      else:\n        if len(dropout_layer_params) != len(fc_layer_params):\n          raise ValueError('Dropout and fully connected layer parameter lists '\n                           'have different lengths (%d vs. %d.)' %\n                           (len(dropout_layer_params), len(fc_layer_params)))\n      if weight_decay_params is None:\n        weight_decay_params = [None] * len(fc_layer_params)\n      else:\n        if len(weight_decay_params) != len(fc_layer_params):\n          raise ValueError('Weight decay and fully connected layer parameter '\n                           'lists have different lengths (%d vs. 
%d.)' %\n                           (len(weight_decay_params), len(fc_layer_params)))\n\n      for num_units, dropout_params, weight_decay in zip(\n          fc_layer_params, dropout_layer_params, weight_decay_params):\n        kernel_regularizer = None\n        if weight_decay is not None:\n          kernel_regularizer = tf.keras.regularizers.l2(weight_decay)\n        layers.append(\n            tf.keras.layers.Dense(\n                tf_sparse_utils.scale_width(num_units, self._width),\n                activation=activation_fn,\n                kernel_initializer=kernel_initializer,\n                kernel_regularizer=kernel_regularizer,\n                dtype=dtype))\n        if not isinstance(dropout_params, dict):\n          dropout_params = {'rate': dropout_params} if dropout_params else None\n\n        if dropout_params is not None:\n          layers.append(\n              permanent_variable_rate_dropout.PermanentVariableRateDropout(\n                  **dropout_params))\n\n    super(EncodingNetwork, self).__init__(\n        input_tensor_spec=input_tensor_spec, state_spec=(), name=name)\n\n    # Pull out the nest structure of the preprocessing layers. 
This avoids\n    # saving the original kwarg layers as a class attribute which Keras would\n    # then track.\n    self._preprocessing_nest = tf.nest.map_structure(lambda l: None,\n                                                     preprocessing_layers)\n    self._flat_preprocessing_layers = flat_preprocessing_layers\n    self._preprocessing_combiner = preprocessing_combiner\n    self._postprocessing_layers = layers\n    self._batch_squash = batch_squash\n    self.built = True  # Allow access to self.variables\n\n  def call(self, observation, step_type=None, network_state=(), training=False):\n    del step_type  # unused.\n\n    if self._batch_squash:\n      outer_rank = nest_utils.get_outer_rank(\n          observation, self.input_tensor_spec)\n      batch_squash = utils.BatchSquash(outer_rank)\n      observation = tf.nest.map_structure(batch_squash.flatten, observation)\n\n    if self._flat_preprocessing_layers is None:\n      processed = observation\n    else:\n      raise ValueError('Flat preprocessing layers should be None.')\n\n    states = processed\n\n    if self._preprocessing_combiner is not None:\n      states = self._preprocessing_combiner(states)\n\n    for layer in self._postprocessing_layers:\n      states = layer(states, training=training)\n\n    if self._batch_squash:\n      states = tf.nest.map_structure(batch_squash.unflatten, states)\n\n    return states, network_state\n"
  },
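The `width` scaling used throughout these networks is delegated to `tf_sparse_utils.scale_width`, which multiplies a layer's unit count by the factor, truncates to an integer, and never drops below one unit. A minimal standalone sketch of that helper's behavior (the real implementation lives in `rigl/rl/tfagents/tf_sparse_utils.py`):

```python
def scale_width(num_units, width):
  """Scale a layer's unit count by `width`, keeping at least 1 unit.

  Mirrors rigl.rl.tfagents.tf_sparse_utils.scale_width: the product is
  clamped from below at 1, then truncated toward zero by int().
  """
  assert width > 0
  return int(max(1, num_units * width))

# A width of 2.0 doubles the units; tiny widths still leave one unit.
assert scale_width(75, 2.0) == 150
assert scale_width(75, 0.5) == 37   # 37.5 truncates to 37
assert scale_width(1, 0.01) == 1    # clamped at a minimum of 1 unit
```

This is why the tests below expect, e.g., a first Dense layer with 2 units when `fc_layer_units=(1,)` and `width=2.0`.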
  {
    "path": "rigl/rl/tfagents/sparse_ppo_actor_network.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Sequential Actor Network for PPO.\"\"\"\nimport sys\n\nimport numpy as np\nfrom rigl.rl.tfagents import tf_sparse_utils\nimport tensorflow.compat.v2 as tf\nimport tensorflow_probability as tfp\n\nfrom tf_agents.keras_layers import bias_layer\n\nfrom tf_agents.networks import nest_map\nfrom tf_agents.networks import sequential\n\n\ndef tanh_and_scale_to_spec(inputs, spec):\n  \"\"\"Maps inputs with arbitrary range to range defined by spec using `tanh`.\"\"\"\n  means = (spec.maximum + spec.minimum) / 2.0\n  magnitudes = (spec.maximum - spec.minimum) / 2.0\n\n  return means + magnitudes * tf.tanh(inputs)\n\n\nclass PPOActorNetwork():\n  \"\"\"Contains the actor network structure.\"\"\"\n\n  def __init__(self,\n               seed_stream_class=tfp.util.SeedStream,\n               is_sparse=False,\n               sparse_output_layer=False,\n               weight_decay=0.0,\n               width=1.0):\n    self.seed_stream_class = seed_stream_class\n    self._is_sparse = is_sparse\n    self._sparse_output_layer = sparse_output_layer\n    self._weight_decay = weight_decay\n    self._width = width\n\n  def create_sequential_actor_net(self,\n                                  fc_layer_units,\n                                  action_tensor_spec,\n                                  input_dim,\n                                  seed=None):\n    
\"\"\"Helper method for creating the actor network.\"\"\"\n    self._seed_stream = self.seed_stream_class(\n        seed=seed, salt='tf_agents_sequential_layers')\n\n    def _get_seed():\n      seed = self._seed_stream()\n      if seed is not None:\n        seed = seed % sys.maxsize\n      return seed\n\n    def create_dist(loc_and_scale):\n      loc = loc_and_scale['loc']\n      loc = tanh_and_scale_to_spec(loc, action_tensor_spec)\n\n      scale = loc_and_scale['scale']\n      scale = tf.math.softplus(scale)\n\n      return tfp.distributions.MultivariateNormalDiag(\n          loc=loc, scale_diag=scale, validate_args=True)\n\n    def means_layers():\n      layer = tf.keras.layers.Dense(\n          action_tensor_spec.shape.num_elements(),\n          kernel_initializer=tf.keras.initializers.VarianceScaling(\n              scale=0.1, seed=_get_seed()),\n          kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay),\n          name='means_projection_layer')\n\n      return layer\n\n    def std_layers():\n      std_bias_initializer_value = np.log(np.exp(0.35) - 1)\n      return bias_layer.BiasLayer(\n          bias_initializer=tf.constant_initializer(\n              value=std_bias_initializer_value))\n\n    def no_op_layers():\n      return tf.keras.layers.Lambda(lambda x: x)\n\n    def dense_layer(num_units):\n      layer = tf.keras.layers.Dense(\n          tf_sparse_utils.scale_width(num_units, self._width),\n          activation=tf.nn.tanh,\n          kernel_initializer=tf.keras.initializers.Orthogonal(seed=_get_seed()),\n          kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay),\n          )\n      return layer\n\n    all_layers = [dense_layer(n) for n in fc_layer_units]\n    all_layers.append(means_layers())\n    if self._is_sparse:\n      if self._sparse_output_layer:\n        all_layers = tf_sparse_utils.wrap_all_layers(all_layers, input_dim)\n      else:\n        new_layers = tf_sparse_utils.wrap_all_layers(all_layers[:-1], 
input_dim)\n        all_layers = new_layers + all_layers[-1:]\n\n    return sequential.Sequential(\n        all_layers +\n        [tf.keras.layers.Lambda(\n            lambda x: {'loc': x, 'scale': tf.zeros_like(x)})] +\n        [nest_map.NestMap({\n            'loc': no_op_layer(),\n            'scale': std_layer(),\n        })] +\n        # Create the output distribution from the mean and standard deviation.\n        [tf.keras.layers.Lambda(create_dist)])\n"
  },
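The `tanh_and_scale_to_spec` helper above squashes unbounded network outputs into the action spec's `[minimum, maximum]` interval: `tanh` maps to (-1, 1), which is then shifted to the spec's midpoint and scaled by its half-range. A pure-Python sketch of the same arithmetic, assuming scalar bounds for illustration (the real version operates on tensors and a `BoundedTensorSpec`):

```python
import math

def tanh_and_scale(x, minimum, maximum):
  """Map an unbounded input into (minimum, maximum) using tanh.

  Computes mean + half_range * tanh(x), matching tanh_and_scale_to_spec
  in the actor network above, but on plain Python floats.
  """
  mean = (maximum + minimum) / 2.0
  half_range = (maximum - minimum) / 2.0
  return mean + half_range * math.tanh(x)

# x = 0 lands exactly on the midpoint of the bounds.
assert tanh_and_scale(0.0, -2.0, 6.0) == 2.0
# Large |x| approaches, but never exactly reaches, the bounds.
assert -2.0 < tanh_and_scale(-3.0, -2.0, 6.0) < -1.9
```

This saturation is also why the projection-network docstring later notes that values near the spec bounds cannot be returned.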
  {
    "path": "rigl/rl/tfagents/sparse_ppo_discrete_actor_network.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"Sparse Discrete Sequential Actor Network for PPO.\"\"\"\n\nimport functools\nimport sys\nimport numpy as np\nfrom rigl.rl.tfagents import tf_sparse_utils\n\nimport tensorflow.compat.v2 as tf\nimport tensorflow_probability as tfp\n\nfrom tf_agents.networks import sequential\nfrom tf_agents.specs import distribution_spec\nfrom tf_agents.specs import tensor_spec\n\n\ndef tanh_and_scale_to_spec(inputs, spec):\n  \"\"\"Maps inputs with arbitrary range to range defined by spec using `tanh`.\"\"\"\n  mean = (spec.maximum + spec.minimum) / 2.0\n  magnitude = spec.maximum - spec.minimum\n\n  return mean + (magnitude * tf.tanh(inputs)) / 2.0\n\n\nclass PPODiscreteActorNetwork():\n  \"\"\"Contains the actor network structure.\"\"\"\n\n  def __init__(self, seed_stream_class=tfp.util.SeedStream,\n               is_sparse=False,\n               sparse_output_layer=False,\n               weight_decay=0,\n               width=1.0):\n    if is_sparse:\n      raise ValueError('This functionality is not enabled. 
wrap_all_layers '\n                       'functionality needs to be implemented.')\n    self.seed_stream_class = seed_stream_class\n    # Sparse params.\n    self._is_sparse = is_sparse\n    self._sparse_output_layer = sparse_output_layer\n    self._width = width\n    self._weight_decay = weight_decay\n\n  def create_sequential_actor_net(self,\n                                  fc_layer_units,\n                                  action_tensor_spec,\n                                  logits_init_output_factor=0.1,\n                                  seed=None):\n    \"\"\"Helper method for creating the actor network.\"\"\"\n\n    self._seed_stream = self.seed_stream_class(\n        seed=seed, salt='tf_agents_sequential_layers')\n    # action_tensor_spec is a BoundedArraySpec which is an array with defined\n    # bounds. Maximum and minimum are arrays with the same shape as the\n    # main array.\n    unique_num_actions = np.unique(action_tensor_spec.maximum -\n                                   action_tensor_spec.minimum + 1)\n    if len(unique_num_actions) > 1 or np.any(unique_num_actions <= 0):\n      raise ValueError('Bounds on discrete actions must be the same for all '\n                       'dimensions and have at least 1 action. Projection '\n                       'Network requires num_actions to be equal across '\n                       'action dimensions. 
Implement a more general '\n                       'categorical projection if you need more flexibility.')\n\n    output_shape = action_tensor_spec.shape.concatenate(\n        [int(unique_num_actions)])\n\n    def _get_seed():\n      seed = self._seed_stream()\n      if seed is not None:\n        seed = seed % sys.maxsize\n      return seed\n\n    def create_dist(logits):\n      input_param_spec = {\n          'logits': tensor_spec.TensorSpec(\n              shape=(1,) + output_shape, dtype=tf.float32)\n      }\n      dist_spec = distribution_spec.DistributionSpec(\n          tfp.distributions.Categorical,\n          input_param_spec,\n          sample_spec=action_tensor_spec,\n          dtype=action_tensor_spec.dtype)\n      logits = tf.reshape(logits, [-1] + output_shape.as_list())\n      return dist_spec.build_distribution(logits=logits)\n\n    def dense_layer(num_units):\n      dense = functools.partial(\n          tf.keras.layers.Dense,\n          activation=tf.nn.tanh,\n          kernel_initializer=tf.keras.initializers.Orthogonal(seed=_get_seed()),\n          kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay))\n      layer = dense(tf_sparse_utils.scale_width(num_units, self._width))\n      if self._is_sparse:\n        return tf_sparse_utils.wrap_layer(layer)\n      else:\n        return layer\n\n    output_layer = tf.keras.layers.Dense(\n        output_shape.num_elements(),\n        kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(\n            scale=logits_init_output_factor, seed=_get_seed()),\n        kernel_regularizer=tf.keras.regularizers.L2(self._weight_decay),\n        bias_initializer=tf.keras.initializers.Zeros(),\n        name='logits',\n        dtype=tf.float32)\n    if self._is_sparse and self._sparse_output_layer:\n      output_layer = tf_sparse_utils.wrap_layer(output_layer)\n\n    return sequential.Sequential(\n        [dense_layer(num_units) for num_units in fc_layer_units] +\n        [output_layer] +\n        
[tf.keras.layers.Lambda(create_dist)])\n"
  },
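The bounds check in `create_sequential_actor_net` above relies on `np.unique` collapsing the per-dimension action counts (`maximum - minimum + 1`) to a single value when every dimension offers the same number of actions. A small standalone sketch of that validation logic (the helper name is ours, for illustration only):

```python
import numpy as np

def num_actions_from_bounds(minimum, maximum):
  """Return the shared per-dimension action count, or raise if they differ.

  Mirrors the check in PPODiscreteActorNetwork.create_sequential_actor_net:
  all action dimensions must have the same positive number of actions.
  """
  unique = np.unique(np.asarray(maximum) - np.asarray(minimum) + 1)
  if len(unique) > 1 or np.any(unique <= 0):
    raise ValueError('Bounds on discrete actions must be the same for all '
                     'dimensions and have at least 1 action.')
  return int(unique[0])

# Bounds [0, 3] in every dimension give 4 actions per dimension.
assert num_actions_from_bounds([0, 0], [3, 3]) == 4
```

Mismatched bounds such as `([0, 0], [3, 2])` produce two distinct counts and trigger the `ValueError`, matching the behavior of the network constructor.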
  {
    "path": "rigl/rl/tfagents/sparse_ppo_discrete_actor_network_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for sparse_ppo_discrete_actor_network.\"\"\"\n\nfrom absl import flags\nfrom absl.testing import parameterized\n\nfrom rigl.rl.tfagents import sparse_ppo_discrete_actor_network\nimport tensorflow as tf\nfrom tf_agents.distributions import utils as distribution_utils\nfrom tf_agents.specs import tensor_spec\nfrom tf_agents.utils import test_utils\n\nfrom tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper\n\nFLAGS = flags.FLAGS\n\n\nclass DeterministicSeedStream(object):\n  \"\"\"A fake seed stream class that always generates a deterministic seed.\"\"\"\n\n  def __init__(self, seed, salt=''):\n    del salt\n    self._seed = seed\n\n  def __call__(self):\n    return self._seed\n\n\nclass PpoActorNetworkTest(parameterized.TestCase, test_utils.TestCase):\n\n  def setUp(self):\n    super(PpoActorNetworkTest, self).setUp()\n    # Run in full eager mode in order to inspect the content of tensors.\n    tf.config.experimental_run_functions_eagerly(True)\n    self.observation_tensor_spec = tf.TensorSpec(shape=[3], dtype=tf.float32)\n    self.action_tensor_spec = tensor_spec.BoundedTensorSpec((), tf.int32, 0, 3)\n\n  def tearDown(self):\n    tf.config.experimental_run_functions_eagerly(False)\n    super(PpoActorNetworkTest, self).tearDown()\n\n  def _init_network(\n      self, is_sparse=False, 
sparse_output_layer=False,\n      width=1.0, weight_decay=0):\n    actor_net_lib = sparse_ppo_discrete_actor_network.PPODiscreteActorNetwork(\n        is_sparse=is_sparse, sparse_output_layer=sparse_output_layer,\n        width=width, weight_decay=weight_decay)\n    actor_net_lib.seed_stream_class = DeterministicSeedStream\n    return actor_net_lib.create_sequential_actor_net(\n        fc_layer_units=(1,), action_tensor_spec=self.action_tensor_spec, seed=1)\n\n  def test_no_mismatched_shape(self):\n    actor_net = self._init_network()\n    actor_output_spec = actor_net.create_variables(self.observation_tensor_spec)\n    distribution_utils.assert_specs_are_compatible(\n        actor_output_spec, self.action_tensor_spec,\n        'actor_network output spec does not match action spec')\n\n  @parameterized.named_parameters(\n      ('dense-output-F', False, False,\n       (tf.keras.layers.Dense, tf.keras.layers.Dense)),\n      ('dense-output-T', False, True,\n       (tf.keras.layers.Dense, tf.keras.layers.Dense)),\n      ('sparse-all', True, True,\n       (pruning_wrapper.PruneLowMagnitude, pruning_wrapper.PruneLowMagnitude)),\n      ('sparse-outp-dense', True, False,\n       (pruning_wrapper.PruneLowMagnitude, tf.keras.layers.Dense)),\n      )\n  def test_is_sparse(self, is_sparse, sparse_output_layer, expected_layers):\n    expected_units = (1, 4)\n    actor_net = self._init_network(\n        is_sparse=is_sparse, sparse_output_layer=sparse_output_layer)\n    for i, (expected_layer, exp_units) in enumerate(\n        zip(expected_layers, expected_units)):\n      layer = actor_net.layers[i]\n      self.assertIsInstance(layer, expected_layer)\n      if isinstance(layer, pruning_wrapper.PruneLowMagnitude):\n        self.assertEqual(layer.layer.units, exp_units)\n      else:\n        self.assertEqual(layer.units, exp_units)\n\n  def test_width_scaling(self):\n    with self.subTest('dense'):\n      actor_net = self._init_network(width=2.0)\n      
self.assertEqual(actor_net.layers[0].units, 2)\n      self.assertEqual(actor_net.layers[1].units, 4)\n\n    with self.subTest('sparse'):\n      actor_net = self._init_network(\n          is_sparse=True, sparse_output_layer=True, width=2.0)\n      self.assertEqual(actor_net.layers[0].layer.units, 2)\n      self.assertEqual(actor_net.layers[1].layer.units, 4)\n\n  @parameterized.named_parameters(\n      ('no-wd-d-d', False, False, 0),\n      ('no-wd-s-d', True, False, 0),\n      ('no-wd-s-s', True, True, 0),\n      ('wd-d-d', False, False, 0.1),\n      ('wd-s-d', True, False, 0.1),\n      ('wd-s-s', True, True, 0.1))\n  def test_weight_decay(self, is_sparse, sparse_output_layer,\n                        expected_weight_decay):\n    actor_net = self._init_network(is_sparse=is_sparse,\n                                   sparse_output_layer=sparse_output_layer,\n                                   weight_decay=expected_weight_decay)\n    for i in range(2):\n      layer = actor_net.layers[i]\n      if isinstance(layer, pruning_wrapper.PruneLowMagnitude):\n        l2_weight_decay = layer.layer.kernel_regularizer.get_config()['l2']\n      else:\n        l2_weight_decay = layer.kernel_regularizer.get_config()['l2']\n      self.assertAlmostEqual(l2_weight_decay, expected_weight_decay)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "rigl/rl/tfagents/sparse_tanh_normal_projection_network.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Project inputs to a tanh-squashed MultivariateNormalDiag distribution.\n\nThis network reproduces Soft Actor-Critic refererence implementation in:\nhttps://github.com/rail-berkeley/softlearning/\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom typing import Callable, Optional, Text\n\nimport gin\nimport tensorflow as tf\nfrom tf_agents.agents.sac import tanh_normal_projection_network\nfrom tf_agents.typing import types\n\n\n@gin.configurable\nclass SparseTanhNormalProjectionNetwork(\n    tanh_normal_projection_network.TanhNormalProjectionNetwork):\n  \"\"\"Generates a tanh-squashed MultivariateNormalDiag distribution.\n\n  Note: Due to the nature of the `tanh` function, values near the spec bounds\n  cannot be returned.\n  \"\"\"\n\n  def __init__(self,\n               sample_spec,\n               activation_fn = None,\n               std_transform = tf.exp,\n               name = 'SparseTanhNormalProjectionNetwork',\n               weight_decay=0.0):\n    \"\"\"Creates an instance of SparseTanhNormalProjectionNetwork.\n\n    Args:\n      sample_spec: A `tensor_spec.BoundedTensorSpec` detailing the shape and\n        dtypes of samples pulled from the output distribution.\n      activation_fn: Activation function to use in dense layer.\n      std_transform: 
Transformation function to apply to the stddevs.\n      name: A string representing name of the network.\n      weight_decay: Weight decay for L2 regularization.\n    \"\"\"\n    super(SparseTanhNormalProjectionNetwork, self).__init__(\n        sample_spec=sample_spec,\n        activation_fn=activation_fn,\n        std_transform=std_transform,\n        name=name)\n\n    # We reinitialize the projection layer with L2 regularization and also\n    # optionally sparsify it.\n    self._projection_layer = tf.keras.layers.Dense(\n        sample_spec.shape.num_elements() * 2,\n        activation=activation_fn,\n        kernel_regularizer=tf.keras.regularizers.L2(weight_decay),\n        name='projection_layer')\n"
  },
  {
    "path": "rigl/rl/tfagents/sparse_value_network.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Sample Keras Value Network.\n\nImplements a network that will generate the following layers:\n\n  [optional]: preprocessing_layers  # preprocessing_layers\n  [optional]: (Add | Concat(axis=-1) | ...)  # preprocessing_combiner\n  [optional]: Conv2D # conv_layer_params\n  Flatten\n  [optional]: Dense  # fc_layer_params\n  Dense -> 1         # Value output\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport gin\n\nfrom rigl.rl.tfagents import sparse_encoding_network\nfrom rigl.rl.tfagents import tf_sparse_utils\nimport tensorflow as tf\n\nfrom tf_agents.networks import network\n\n\n@gin.configurable\nclass ValueNetwork(network.Network):\n  \"\"\"Feed Forward value network. 
Reduces to 1 value output per batch item.\"\"\"\n\n  def __init__(self,\n               input_tensor_spec,\n               preprocessing_combiner=None,\n               conv_layer_params=None,\n               fc_layer_params=(75, 40),\n               dropout_layer_params=None,\n               weight_decay=0.0,\n               activation_fn=tf.keras.activations.relu,\n               kernel_initializer=None,\n               batch_squash=True,\n               dtype=tf.float32,\n               name='ValueNetwork',\n               is_sparse=False,\n               sparse_output_layer=False,\n               width=1.0):\n    \"\"\"Creates an instance of `ValueNetwork`.\n\n    Network supports calls with shape outer_rank + observation_spec.shape. Note\n    outer_rank must be at least 1.\n\n    Args:\n      input_tensor_spec: A `tensor_spec.TensorSpec` or a tuple of specs\n        representing the input observations.\n      preprocessing_combiner: (Optional.) A keras layer that takes a flat list\n        of tensors and combines them. Good options include\n        `tf.keras.layers.Add` and `tf.keras.layers.Concatenate(axis=-1)`.\n        This layer must not be already built. For more details see\n        the documentation of `networks.EncodingNetwork`.\n      conv_layer_params: Optional list of convolution layers parameters, where\n        each item is a length-three tuple indicating (filters, kernel_size,\n        stride).\n      fc_layer_params: Optional list of fully_connected parameters, where each\n        item is the number of units in the layer.\n      dropout_layer_params: Optional list of dropout layer parameters, each item\n        is the fraction of input units to drop or a dictionary of parameters\n        according to the keras.Dropout documentation. The additional parameter\n        `permanent`, if set to True, allows to apply dropout at inference for\n        approximated Bayesian inference. 
The dropout layers are interleaved with\n        the fully connected layers; there is a dropout layer after each fully\n        connected layer, except if the entry in the list is None. This list must\n        have the same length as fc_layer_params, or be None.\n      weight_decay: L2 weight decay regularization parameter.\n      activation_fn: Activation function, e.g. tf.keras.activations.relu.\n      kernel_initializer: Initializer to use for the kernels of the conv and\n        dense layers. If none is provided, a default\n        variance_scaling_initializer is used.\n      batch_squash: If True, the outer_ranks of the observation are squashed\n        into the batch dimension. This allows encoding networks to be used with\n        observations with shape [BxTx...].\n      dtype: The dtype to use by the convolution and fully connected layers.\n      name: A string representing the name of the network.\n      is_sparse: Whether the network is sparse.\n      sparse_output_layer: Whether the output layer should be sparse. 
Only\n        applied when is_sparse=True.\n      width: Scaling factor to apply to the layers.\n\n    Raises:\n      ValueError: If input_tensor_spec is not an instance of network.InputSpec.\n    \"\"\"\n    super(ValueNetwork, self).__init__(\n        input_tensor_spec=input_tensor_spec,\n        state_spec=(),\n        name=name)\n\n    self._is_sparse = is_sparse\n    self._sparse_output_layer = sparse_output_layer\n    self._width = width\n\n    if not kernel_initializer:\n      kernel_initializer = tf.compat.v1.keras.initializers.glorot_uniform()\n\n    self._encoder = sparse_encoding_network.EncodingNetwork(\n        input_tensor_spec,\n        preprocessing_layers=None,\n        preprocessing_combiner=preprocessing_combiner,\n        conv_layer_params=conv_layer_params,\n        fc_layer_params=fc_layer_params,\n        dropout_layer_params=dropout_layer_params,\n        activation_fn=activation_fn,\n        weight_decay_params=[weight_decay] * len(fc_layer_params),\n        kernel_initializer=kernel_initializer,\n        batch_squash=batch_squash,\n        dtype=dtype,\n        width=self._width)\n\n    self._postprocessing_layers = tf.keras.layers.Dense(\n        1,\n        activation=None,\n        kernel_initializer=tf.random_uniform_initializer(\n            minval=-0.03, maxval=0.03),\n        kernel_regularizer=tf.keras.regularizers.L2(weight_decay))\n\n    if is_sparse:\n      layers_to_wrap = [l for l in self._encoder._postprocessing_layers\n                        if tf_sparse_utils.is_valid_layer_to_wrap(l)]\n      input_dim = input_tensor_spec.shape[0]\n      if sparse_output_layer:\n        layers_to_wrap.append(self._postprocessing_layers)\n        wrapped_layers = tf_sparse_utils.wrap_all_layers(\n            layers_to_wrap, input_dim)\n        self._postprocessing_layers = wrapped_layers[-1]\n        wrapped_layers = wrapped_layers[:-1]\n      else:\n        wrapped_layers = tf_sparse_utils.wrap_all_layers(\n            layers_to_wrap, 
input_dim)\n      # We need to recreate the original layer list after wrapping the layers.\n      new_layer_list = []\n      i = 0\n      for unwrapped_layer in self._encoder._postprocessing_layers:\n        if tf_sparse_utils.is_valid_layer_to_wrap(unwrapped_layer):\n          new_layer_list.append(wrapped_layers[i])\n          i += 1\n        else:\n          new_layer_list.append(unwrapped_layer)\n      self._encoder._postprocessing_layers = new_layer_list\n\n  def call(self, observation, step_type=None, network_state=(), training=False):\n    state, network_state = self._encoder(\n        observation, step_type=step_type, network_state=network_state,\n        training=training)\n    value = self._postprocessing_layers(state, training=training)\n    return tf.squeeze(value, -1), network_state\n"
  },
  {
    "path": "rigl/rl/tfagents/tf_sparse_utils.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Utility functions for sparse tf agents training.\"\"\"\n\nimport re\nfrom absl import logging\nimport gin\nfrom rigl import sparse_utils as sparse_utils_rigl\nfrom rigl.rl import sparse_utils\n\nimport tensorflow.compat.v2 as tf\n\nfrom tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule\nfrom tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper\n\nPRUNING_WRAPPER = pruning_wrapper.PruneLowMagnitude\n_LAYER_TYPES_TO_WRAP = (tf.keras.layers.Dense, tf.keras.layers.Conv2D,\n                        tf.keras.layers.Conv1D)\n\n\ndef log_total_params(networks):\n  total_params = 0\n  for net in networks:\n    total_net_params, _ = sparse_utils.get_total_params(net)\n    total_params += total_net_params\n  with tf.name_scope('Params/'):\n    tf.compat.v2.summary.scalar('total', total_params)\n\n\ndef scale_width(num_units, width):\n  assert width > 0\n  return int(max(1, num_units * width))\n\n\n@gin.configurable\ndef wrap_all_layers(layers,\n                    input_dim,\n                    mode='constant',\n                    mask_init_method='erdos_renyi_kernel',\n                    initial_sparsity=0.0,\n                    final_sparsity=0.9,\n                    begin_step=200000,\n                    end_step=600000,\n                    frequency=10000):\n  \"\"\"Wraps a list of dense 
keras layers to be used by sparse training.\"\"\"\n  # We only need to define static masks here; they are updated later by the\n  # mask updater.\n  new_layers = []\n  if mode == 'constant':\n    for layer in layers:\n      schedule = pruning_schedule.ConstantSparsity(\n          target_sparsity=0, begin_step=1000000000)\n      new_layers.append(PRUNING_WRAPPER(layer, pruning_schedule=schedule))\n  elif mode == 'prune':\n    logging.info('Pruning schedule: initial sparsity: %f', initial_sparsity)\n    logging.info('Pruning schedule: mask_init_method: %s', mask_init_method)\n    logging.info('Pruning schedule: final sparsity: %f', final_sparsity)\n    logging.info('Pruning schedule: begin step: %d', begin_step)\n    logging.info('Pruning schedule: end step: %d', end_step)\n    logging.info('Pruning schedule: frequency: %d', frequency)\n\n    # Create dummy masks to get layer-wise sparsities. This is because the\n    # get_sparsities function expects mask variables to calculate the\n    # sparsities.\n    dummy_masks_dict = {}\n    layer_input_dim = input_dim\n    for layer in layers:\n      mask = tf.Variable(tf.ones([layer_input_dim, layer.units]),\n                         trainable=False, name=f'dummymask_{layer.name}')\n      layer_input_dim = layer.units\n      dummy_masks_dict[layer.name] = mask\n\n    # Get layer-wise sparsities.\n    extract_name_fn = lambda x: re.findall('(.+):0', x)[0]\n    reverse_dict = {v.name: k\n                    for k, v in dummy_masks_dict.items()}\n    sparsity_dict = sparse_utils_rigl.get_sparsities(\n        list(dummy_masks_dict.values()),\n        mask_init_method,\n        final_sparsity,\n        custom_sparsity_map={},\n        extract_name_fn=extract_name_fn)\n    # This dict maps layer_name -> layer_sparsity.\n    renamed_sparsity_dict = {reverse_dict[k]: float(v)\n                             for k, v in sparsity_dict.items()}\n    # Wrap layers with a possibly non-uniform pruning schedule.\n    for layer in 
layers:\n      sparsity = renamed_sparsity_dict[layer.name]\n      logging.info('Layer: %s, sparsity: %f', layer.name, sparsity)\n      schedule = pruning_schedule.PolynomialDecay(\n          initial_sparsity=initial_sparsity,\n          final_sparsity=sparsity,\n          begin_step=begin_step,\n          end_step=end_step,\n          frequency=frequency)\n      new_layers.append(PRUNING_WRAPPER(layer, pruning_schedule=schedule))\n  else:\n    raise ValueError(f'Unknown mode: {mode}')\n\n  return new_layers\n\n\n@gin.configurable\ndef wrap_layer(layer,\n               mode='constant',\n               initial_sparsity=0.0,\n               final_sparsity=0.9,\n               begin_step=200000,\n               end_step=600000,\n               frequency=10000):\n  \"\"\"Wraps a keras layer to be used by sparse training.\"\"\"\n  # We only need to define static masks here; they are updated later by the\n  # mask updater.\n  if mode == 'constant':\n    schedule = pruning_schedule.ConstantSparsity(\n        target_sparsity=0, begin_step=1000000000)\n  elif mode == 'prune':\n    logging.info('Pruning schedule: initial sparsity: %f', initial_sparsity)\n    logging.info('Pruning schedule: final sparsity: %f', final_sparsity)\n    logging.info('Pruning schedule: begin step: %d', begin_step)\n    logging.info('Pruning schedule: end step: %d', end_step)\n    logging.info('Pruning schedule: frequency: %d', frequency)\n    schedule = pruning_schedule.PolynomialDecay(\n        initial_sparsity=initial_sparsity,\n        final_sparsity=final_sparsity,\n        begin_step=begin_step,\n        end_step=end_step,\n        frequency=frequency)\n  else:\n    raise ValueError(f'Unknown mode: {mode}')\n\n  return PRUNING_WRAPPER(layer, pruning_schedule=schedule)\n\n\ndef is_valid_layer_to_wrap(layer):\n  for layer_type in _LAYER_TYPES_TO_WRAP:\n    if isinstance(layer, layer_type):\n      return True\n\n  return False\n\n\n@gin.configurable\ndef log_sparsities(model, model_name='q_net', log_images=False):\n  \"\"\"Logs relevant sparsity stats to TensorBoard.\"\"\"\n  for layer in 
sparse_utils.get_all_pruning_layers(model):\n    for _, mask, threshold in layer.pruning_vars:\n      if log_images:\n        reshaped_mask = tf.expand_dims(tf.expand_dims(mask, 0), -1)\n        with tf.name_scope('Masks/'):\n          tf.compat.v2.summary.image(f'{model_name}/{mask.name}', reshaped_mask)\n      with tf.name_scope('Sparsity/'):\n        sparsity = 1 - tf.reduce_mean(mask)\n        tf.compat.v2.summary.scalar(f'{model_name}/{mask.name}', sparsity)\n      with tf.name_scope('Threshold/'):\n        tf.compat.v2.summary.scalar(f'{model_name}/{threshold.name}', threshold)\n\n  total_params, nparam_dict = sparse_utils.get_total_params(model)\n  with tf.name_scope('Params/'):\n    tf.compat.v2.summary.scalar(f'{model_name}/total', total_params)\n    for k, val in nparam_dict.items():\n      tf.compat.v2.summary.scalar(f'{model_name}/' + k, val)\n\n\ndef update_prune_step(model, step):\n  for layer in sparse_utils.get_all_pruning_layers(model):\n    # Assign iteration count to the layer pruning_step.\n    layer.pruning_step.assign(step)\n\n\ndef flatten_list_of_vars(var_list):\n  flat_vars = [tf.reshape(v, [-1]) for v in var_list]\n  return tf.concat(flat_vars, axis=-1)\n\n\n@gin.configurable\ndef log_snr(tape, loss, step, variables_to_train, freq=1000):\n  \"\"\"Given a gradient tape and loss, logs the gradient signal-to-noise ratio.\"\"\"\n\n  def true_fn():\n    grads_per_sample = tape.jacobian(loss, variables_to_train)\n    list_of_snrs = []\n    for grad in grads_per_sample:\n      if grad is None:\n        # Skip variables that received no gradient.\n        continue\n      if isinstance(grad, tf.IndexedSlices):\n        grad_values = grad.values\n      else:\n        grad_values = grad\n      grad_mean = tf.math.reduce_mean(grad_values, axis=0)\n      grad_std = tf.math.reduce_std(grad_values, axis=0)\n      list_of_snrs.append(tf.abs(grad_mean / (grad_std + 1e-10)))\n\n    snr_mean = tf.reduce_mean(flatten_list_of_vars(list_of_snrs))\n    snr_std = tf.math.reduce_std(flatten_list_of_vars(list_of_snrs))\n    with 
tf.name_scope('SNR/'):\n      tf.compat.v2.summary.scalar(name='mean', data=snr_mean, step=step)\n      tf.compat.v2.summary.scalar(name='std', data=snr_std, step=step)\n\n  tf.cond(step % freq == 0, true_fn, lambda: None)\n"
  },
  {
    "path": "rigl/rl/train.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nr\"\"\"The entry point for training a sparse DQN agent.\"\"\"\n\nimport os\n\nfrom absl import app\nfrom absl import flags\nimport gin\nfrom rigl.rl import run_experiment\nimport tensorflow as tf\n\n\n\nflags.DEFINE_string('base_dir', None,\n                    'Base directory to host all required sub-directories.')\nflags.DEFINE_multi_string(\n    'gin_files', [], 'List of paths to gin configuration files.')\nflags.DEFINE_multi_string(\n    'gin_bindings', [],\n    'Gin bindings to override the values set in the config files '\n    '(e.g. \"DQNAgent.epsilon_train=0.1\",'\n    '      \"create_atari_environment.game_name=\"Pong\"\").')\n\nFLAGS = flags.FLAGS\n\n\ndef create_sparsetrain_runner(base_dir):\n  assert base_dir is not None\n  return run_experiment.SparseTrainRunner(base_dir)\n\n\ndef main(unused_argv):\n  gin.parse_config_files_and_bindings(FLAGS.gin_files, FLAGS.gin_bindings)\n\n  runner = create_sparsetrain_runner(FLAGS.base_dir)\n  runner.run_experiment()\n\n  logconfigfile_path = os.path.join(FLAGS.base_dir, 'operative_config.gin')\n  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:\n    f.write('# Gin-Config:\\n %s' % gin.config.operative_config_str())\n\n\nif __name__ == '__main__':\n  flags.mark_flag_as_required('base_dir')\n  app.run(main)\n"
  },
  {
    "path": "rigl/sparse_optimizers.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"This module implements some common and new sparse training algorithms.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\n\nimport numpy as np\nfrom rigl import sparse_optimizers_base as sparse_opt_base\nfrom rigl import sparse_utils\n\n\nfrom tensorflow.contrib.model_pruning.python import pruning\nfrom tensorflow.python.framework import dtypes\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import array_ops\nfrom tensorflow.python.ops import control_flow_ops\nfrom tensorflow.python.ops import init_ops\nfrom tensorflow.python.ops import math_ops\nfrom tensorflow.python.ops import nn_ops\nfrom tensorflow.python.ops import state_ops\nfrom tensorflow.python.ops import variable_scope\nfrom tensorflow.python.ops import variables\nfrom tensorflow.python.tpu.ops import tpu_ops\nfrom tensorflow.python.training import moving_averages\nfrom tensorflow.python.training import optimizer as tf_optimizer\nfrom tensorflow.python.training import training_util\n\n\nclass PruningGetterTf1Mixin:\n  \"\"\"Tf1 model_pruning library specific variable retrieval.\"\"\"\n\n  def get_weights(self):\n    return pruning.get_weights()\n\n  def get_masks(self):\n    return pruning.get_masks()\n\n  def get_masked_weights(self):\n    return 
pruning.get_masked_weights()\n\n\nclass SparseSETOptimizer(PruningGetterTf1Mixin,\n                         sparse_opt_base.SparseSETOptimizerBase):\n  pass\n\n\nclass SparseRigLOptimizer(PruningGetterTf1Mixin,\n                          sparse_opt_base.SparseRigLOptimizerBase):\n  pass\n\n\nclass SparseStaticOptimizer(SparseSETOptimizer):\n  \"\"\"Sparse optimizer that re-initializes weak connections during training.\n\n  Attributes:\n    optimizer: tf.train.Optimizer\n    begin_step: int, first iteration where masks are updated.\n    end_step: int, iteration after which no mask is updated.\n    frequency: int, of mask update operations.\n    drop_fraction: float, of connections to drop during each update.\n    drop_fraction_anneal: str or None, if supplied used to anneal the drop\n      fraction.\n    use_locking: bool, passed to the super.\n    grow_init: str, name of the method used to initialize new connections.\n    name: str, passed to the super.\n  \"\"\"\n\n  def __init__(self,\n               optimizer,\n               begin_step,\n               end_step,\n               frequency,\n               drop_fraction=0.1,\n               drop_fraction_anneal='constant',\n               use_locking=False,\n               grow_init='zeros',\n               name='SparseStaticOptimizer',\n               stateless_seed_offset=0):\n    super(SparseStaticOptimizer, self).__init__(\n        optimizer,\n        begin_step,\n        end_step,\n        frequency,\n        drop_fraction=drop_fraction,\n        drop_fraction_anneal=drop_fraction_anneal,\n        grow_init=grow_init,\n        use_locking=use_locking,\n        name=name,\n        stateless_seed_offset=stateless_seed_offset)\n\n  def generic_mask_update(self, mask, weights, noise_std=1e-5):\n    \"\"\"True branch of the condition, updates the mask.\"\"\"\n    # Ensure that the weights are masked.\n    masked_weights = mask * weights\n    score_drop = 
math_ops.abs(masked_weights)\n    # Add noise for a slight bit of randomness.\n    score_drop += self._random_normal(\n        score_drop.shape,\n        stddev=noise_std,\n        dtype=score_drop.dtype,\n        seed=hash(weights.name + 'drop'))\n    # Grow scores equal the current mask, so the same connections are selected\n    # and dropped ones are re-initialized in place (static topology).\n    score_grow = mask\n    return self._get_update_op(\n        score_drop, score_grow, mask, weights, reinit_when_same=True)\n\n\nclass SparseMomentumOptimizer(SparseSETOptimizer):\n  \"\"\"Sparse optimizer that grows connections with the expected gradients.\n\n  A simplified implementation of the momentum-based sparse optimizer. No\n  redistribution of sparsity.\n  Original implementation:\n  https://github.com/TimDettmers/sparse_learning/blob/master/mnist_cifar/main.py\n\n  Attributes:\n    optimizer: tf.train.Optimizer\n    begin_step: int, first iteration where masks are updated.\n    end_step: int, iteration after which no mask is updated.\n    frequency: int, of mask update operations.\n    drop_fraction: float, of connections to drop during each update.\n    drop_fraction_anneal: str or None, if supplied used to anneal the drop\n      fraction.\n    use_locking: bool, passed to the super.\n    grow_init: str, name of the method used to initialize new connections.\n    momentum: float, for the exponential moving average.\n    use_tpu: bool, if true the masked_gradients are aggregated.\n    name: str, passed to the super.\n  \"\"\"\n\n  def __init__(self,\n               optimizer,\n               begin_step,\n               end_step,\n               frequency,\n               drop_fraction=0.1,\n               drop_fraction_anneal='constant',\n               use_locking=False,\n               grow_init='zeros',\n               momentum=0.9,\n               use_tpu=False,\n               name='SparseMomentumOptimizer',\n               stateless_seed_offset=0):\n    super(SparseMomentumOptimizer, self).__init__(\n        optimizer,\n        begin_step,\n        
end_step,\n        frequency,\n        drop_fraction=drop_fraction,\n        drop_fraction_anneal=drop_fraction_anneal,\n        grow_init=grow_init,\n        use_locking=use_locking,\n        name=name,\n        stateless_seed_offset=stateless_seed_offset)\n    self._ema_grads = moving_averages.ExponentialMovingAverage(decay=momentum)\n    self._use_tpu = use_tpu\n\n  def set_masked_grads(self, grads, weights):\n    if self._use_tpu:\n      grads = [tpu_ops.cross_replica_sum(g) for g in grads]\n    self._masked_grads = grads\n    # Keyed by variable name, since names are stable and hashable.\n    self._weight2masked_grads = {w.name: m for w, m in zip(weights, grads)}\n\n  def compute_gradients(self, loss, **kwargs):\n    \"\"\"Wraps the compute gradient of passed optimizer.\"\"\"\n    grads_and_vars = self._optimizer.compute_gradients(loss, **kwargs)\n    # Need to update the EMA of the masked_weights. This is a bit hacky and\n    # might not work as expected if the gradients are not applied after every\n    # calculation. 
However, it should be fine if only the .minimize() call is used.\n    masked_grads_vars = self._optimizer.compute_gradients(\n        loss, var_list=self.get_masked_weights())\n    masked_grads = [g for g, _ in masked_grads_vars]\n    self.set_masked_grads(masked_grads, self.get_weights())\n    return grads_and_vars\n\n  def _before_apply_gradients(self, grads_and_vars):\n    \"\"\"Updates momentum before updating the weights with gradient.\"\"\"\n    return self._ema_grads.apply(self._masked_grads)\n\n  def generic_mask_update(self, mask, weights, noise_std=1e-5):\n    \"\"\"True branch of the condition, updates the mask.\"\"\"\n    # Ensure that the weights are masked.\n    casted_mask = math_ops.cast(mask, dtypes.float32)\n    masked_weights = casted_mask * weights\n    score_drop = math_ops.abs(masked_weights)\n    # Add noise for a slight bit of randomness.\n    score_drop += self._random_normal(\n        score_drop.shape,\n        stddev=noise_std,\n        dtype=score_drop.dtype,\n        seed=hash(weights.name + 'drop'))\n    # Revive n_prune many connections using momentum.\n    masked_grad = self._weight2masked_grads[weights.name]\n    score_grow = math_ops.abs(self._ema_grads.average(masked_grad))\n    return self._get_update_op(score_drop, score_grow, mask, weights)\n\n\nclass SparseSnipOptimizer(tf_optimizer.Optimizer):\n  \"\"\"Implementation of the SNIP one-shot pruning optimizer.\n\n  Implementation of SNIP\n  https://arxiv.org/abs/1810.02340\n\n  Attributes:\n    optimizer: tf.train.Optimizer\n    default_sparsity: float, between 0 and 1.\n    mask_init_method: str, used to determine mask initializations.\n    custom_sparsity_map: dict of <str, float> key/value pairs; the mask whose\n      name is '{key}/mask:0' is set to the corresponding sparsity value.\n    use_locking: bool, passed to the super.\n    use_tpu: bool, if true the masked_gradients are aggregated.\n    name: str, passed to the super.\n  \"\"\"\n\n  def __init__(self,\n     
          optimizer,\n               default_sparsity,\n               mask_init_method,\n               custom_sparsity_map=None,\n               use_locking=False,\n               use_tpu=False,\n               name='SparseSnipOptimizer'):\n    super(SparseSnipOptimizer, self).__init__(use_locking, name)\n    if not custom_sparsity_map:\n      custom_sparsity_map = {}\n    self._optimizer = optimizer\n    self._use_tpu = use_tpu\n    self._default_sparsity = default_sparsity\n    self._mask_init_method = mask_init_method\n    self._custom_sparsity_map = custom_sparsity_map\n    self.is_snipped = variable_scope.get_variable(\n        'is_snipped', initializer=lambda: False, trainable=False)\n\n  def compute_gradients(self, loss, **kwargs):\n    \"\"\"Wraps the compute gradient of passed optimizer.\"\"\"\n    return self._optimizer.compute_gradients(loss, **kwargs)\n\n  def apply_gradients(self, grads_and_vars, global_step=None, name=None):\n    \"\"\"Wraps the original apply_gradient of the optimizer.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs as returned by\n        `compute_gradients()`.\n      global_step: Optional `Variable` to increment by one after the variables\n        have been updated.\n      name: Optional name for the returned operation.  Default to the name\n        passed to the `Optimizer` constructor.\n\n    Returns:\n      An `Operation` that applies the specified gradients. 
If `global_step`\n      was not None, that operation also increments `global_step`.\n    \"\"\"\n\n    def apply_gradient_op():\n      return self._optimizer.apply_gradients(\n          grads_and_vars, global_step=global_step, name=name)\n\n    maybe_reduce = lambda x: x\n    if self._use_tpu:\n      maybe_reduce = tpu_ops.cross_replica_sum\n    grads_and_vars_dict = {\n        re.findall('(.+)/weights:0', var.name)[0]: (maybe_reduce(grad), var)\n        for grad, var in grads_and_vars\n        if var.name.endswith('weights:0')\n    }\n\n    def snip_fn(mask, sparsity, dtype):\n      \"\"\"Creates a random sparse mask with deterministic sparsity.\n\n      Args:\n        mask: tf.Tensor, used to obtain correct corresponding gradient.\n        sparsity: float, between 0 and 1.\n        dtype: tf.dtype, type of the return value.\n\n      Returns:\n        tf.Tensor\n      \"\"\"\n      del dtype\n      var_name = sparse_utils.mask_extract_name_fn(mask.name)\n      g, v = grads_and_vars_dict[var_name]\n      score_drop = math_ops.abs(g * v)\n      n_total = np.prod(score_drop.shape.as_list())\n      n_prune = sparse_utils.get_n_zeros(n_total, sparsity)\n      n_keep = n_total - n_prune\n\n      # Sort the entire array since the k needs to be constant for TPU.\n      _, sorted_indices = nn_ops.top_k(\n          array_ops.reshape(score_drop, [-1]), k=n_total)\n      sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)\n      # We will have zeros after having `n_keep` many ones.\n      new_values = array_ops.where(\n          math_ops.range(n_total) < n_keep,\n          array_ops.ones_like(sorted_indices, dtype=mask.dtype),\n          array_ops.zeros_like(sorted_indices, dtype=mask.dtype))\n      new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values,\n                                      new_values.shape)\n      return array_ops.reshape(new_mask, mask.shape)\n\n    def snip_op():\n      all_masks = pruning.get_masks()\n      assigner = 
sparse_utils.get_mask_init_fn(\n          all_masks,\n          self._mask_init_method,\n          self._default_sparsity,\n          self._custom_sparsity_map,\n          mask_fn=snip_fn)\n      with ops.control_dependencies([assigner]):\n        assign_op = state_ops.assign(\n            self.is_snipped, True, name='assign_true_after_snipped')\n      return assign_op\n\n    maybe_snip_op = control_flow_ops.cond(\n        math_ops.logical_and(\n            math_ops.equal(global_step, 0),\n            math_ops.logical_not(self.is_snipped)), snip_op, apply_gradient_op)\n\n    return maybe_snip_op\n\n\nclass SparseDNWOptimizer(tf_optimizer.Optimizer):\n  \"\"\"Implementation of the DNW optimizer.\n\n  See https://arxiv.org/pdf/1906.00586.pdf\n  This optimizer ensures the mask is updated at every iteration, according to\n  the current set of weights. It uses dense gradients to update weights.\n\n  Attributes:\n    optimizer: tf.train.Optimizer\n    default_sparsity: float, between 0 and 1.\n    mask_init_method: str, used to determine mask initializations.\n    custom_sparsity_map: dict of <str, float> key/value pairs; the mask whose\n      name is '{key}/mask:0' is set to the corresponding sparsity value.\n    use_tpu: bool, if true the masked_gradients are aggregated.\n    use_locking: bool, passed to the super.\n    name: str, passed to the super.\n  \"\"\"\n\n  def __init__(self,\n               optimizer,\n               default_sparsity,\n               mask_init_method,\n               custom_sparsity_map=None,\n               use_tpu=False,\n               use_locking=False,\n               name='SparseDNWOptimizer'):\n    super(SparseDNWOptimizer, self).__init__(use_locking, name)\n    if not custom_sparsity_map:\n      custom_sparsity_map = {}\n    self._optimizer = optimizer\n    self._use_tpu = use_tpu\n    self._default_sparsity = default_sparsity\n    self._mask_init_method = mask_init_method\n    self._custom_sparsity_map = custom_sparsity_map\n\n  def 
compute_gradients(self, loss, var_list=None, **kwargs):\n    \"\"\"Wraps the compute gradient of passed optimizer.\"\"\"\n    # Replace masked variables with masked_weights so that the gradient is dense\n    # and not masked.\n    if var_list is None:\n      var_list = (\n          variables.trainable_variables() +\n          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))\n    var_list = self.replace_with_masked_weights(var_list)\n    grads_and_vars = self._optimizer.compute_gradients(\n        loss, var_list=var_list, **kwargs)\n    return self.replace_masked_weights(grads_and_vars)\n\n  def replace_with_masked_weights(self, var_list):\n    \"\"\"Replaces masked variables with masked weights.\"\"\"\n    weight2masked_weights = {\n        w.name: mw\n        for w, mw in zip(self.get_weights(), self.get_masked_weights())\n    }\n    updated_var_list = [weight2masked_weights.get(w.name, w) for w in var_list]\n    return updated_var_list\n\n  def replace_masked_weights(self, grads_and_vars):\n    \"\"\"Replaces masked weight tensors with weight variables.\"\"\"\n    masked_weights2weight = {\n        mw.name: w\n        for w, mw in zip(self.get_weights(), self.get_masked_weights())\n    }\n    updated_grads_and_vars = [\n        (g, masked_weights2weight.get(w.name, w)) for g, w in grads_and_vars\n    ]\n    return updated_grads_and_vars\n\n  def apply_gradients(self, grads_and_vars, global_step=None, name=None):\n    \"\"\"Wraps the original apply_gradient of the optimizer.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs as returned by\n        `compute_gradients()`.\n      global_step: Optional `Variable` to increment by one after the variables\n        have been updated.\n      name: Optional name for the returned operation.  Default to the name\n        passed to the `Optimizer` constructor.\n\n    Returns:\n      An `Operation` that applies the specified gradients. 
If `global_step`\n      was not None, that operation also increments `global_step`.\n    \"\"\"\n    optimizer_update = self._optimizer.apply_gradients(\n        grads_and_vars, global_step=global_step, name=name)\n    vars_dict = {\n        re.findall('(.+)/weights:0', var.name)[0]: var\n        for var in self.get_weights()\n    }\n\n    def dnw_fn(mask, sparsity, dtype):\n      \"\"\"Creates a mask with smallest magnitudes with deterministic sparsity.\n\n      Args:\n        mask: tf.Tensor, used to obtain correct corresponding gradient.\n        sparsity: float, between 0 and 1.\n        dtype: tf.dtype, type of the return value.\n\n      Returns:\n        tf.Tensor\n      \"\"\"\n      del dtype\n      var_name = sparse_utils.mask_extract_name_fn(mask.name)\n      v = vars_dict[var_name]\n      score_drop = math_ops.abs(v)\n      n_total = np.prod(score_drop.shape.as_list())\n      n_prune = sparse_utils.get_n_zeros(n_total, sparsity)\n      n_keep = n_total - n_prune\n\n      # Sort the entire array since the k needs to be constant for TPU.\n      _, sorted_indices = nn_ops.top_k(\n          array_ops.reshape(score_drop, [-1]), k=n_total)\n      sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)\n      # We will have zeros after having `n_keep` many ones.\n      new_values = array_ops.where(\n          math_ops.range(n_total) < n_keep,\n          array_ops.ones_like(sorted_indices, dtype=mask.dtype),\n          array_ops.zeros_like(sorted_indices, dtype=mask.dtype))\n      new_mask = array_ops.scatter_nd(sorted_indices_ex, new_values,\n                                      new_values.shape)\n      return array_ops.reshape(new_mask, mask.shape)\n\n    with ops.control_dependencies([optimizer_update]):\n      all_masks = self.get_masks()\n      mask_update_op = sparse_utils.get_mask_init_fn(\n          all_masks,\n          self._mask_init_method,\n          self._default_sparsity,\n          self._custom_sparsity_map,\n          mask_fn=dnw_fn)\n\n   
 return mask_update_op\n\n  def get_weights(self):\n    return pruning.get_weights()\n\n  def get_masks(self):\n    return pruning.get_masks()\n\n  def get_masked_weights(self):\n    return pruning.get_masked_weights()\n"
  },
  {
    "path": "rigl/sparse_optimizers_base.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"This module implements some common and new sparse training algorithms.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\n\nimport six\n\n\nfrom tensorflow.python.framework import dtypes\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import array_ops\nfrom tensorflow.python.ops import control_flow_ops\nfrom tensorflow.python.ops import init_ops\nfrom tensorflow.python.ops import math_ops\nfrom tensorflow.python.ops import nn_ops\nfrom tensorflow.python.ops import random_ops\nfrom tensorflow.python.ops import state_ops\nfrom tensorflow.python.ops import stateless_random_ops\nfrom tensorflow.python.ops import variable_scope\nfrom tensorflow.python.tpu.ops import tpu_ops\nfrom tensorflow.python.training import learning_rate_decay\nfrom tensorflow.python.training import optimizer as tf_optimizer\nfrom tensorflow.python.training import training_util\n\n\n\ndef extract_number(token):\n  \"\"\"Strips the number from the end of the token if it exists.\n\n  Args:\n    token: str, s or s_d where d is a number: a float or int. 
`foo_.5`,\n      `foo_foo.5`, `foo_0.5`, `foo_4` are all valid strings.\n\n  Returns:\n    float, d if it exists, otherwise 1.\n  \"\"\"\n  regexp = re.compile(r'.*_(\\d*\\.?\\d*)$')\n  if regexp.search(token):\n    return float(regexp.search(token).group(1))\n  else:\n    return 1.\n\n\nclass SparseSETOptimizerBase(tf_optimizer.Optimizer):\n  \"\"\"Implementation of dynamic sparsity optimizers.\n\n  Implementation of SET.\n  See https://www.nature.com/articles/s41467-018-04316-3\n  This optimizer wraps a regular optimizer and performs updates on the masks\n  according to the given schedule.\n\n  Attributes:\n    optimizer: tf.train.Optimizer\n    begin_step: int, first iteration where masks are updated.\n    end_step: int, iteration after which no mask is updated.\n    frequency: int, of mask update operations.\n    drop_fraction: float, of connections to drop during each update.\n    drop_fraction_anneal: str or None, if supplied used to anneal the drop\n      fraction.\n    use_locking: bool, passed to the super.\n    grow_init: str, name of the method used to initialize new connections.\n    name: str, passed to the super.\n    use_stateless: bool, if True stateless operations are used. This is\n      important for multi-worker jobs not to diverge.\n    stateless_seed_offset: int, added to the seed of stateless operations. 
Use\n      this to create randomness without divergence across workers.\n  \"\"\"\n\n  def __init__(self,\n               optimizer,\n               begin_step,\n               end_step,\n               frequency,\n               drop_fraction=0.1,\n               drop_fraction_anneal='constant',\n               use_locking=False,\n               grow_init='zeros',\n               name='SparseSETOptimizer',\n               use_stateless=True,\n               stateless_seed_offset=0):\n    super(SparseSETOptimizerBase, self).__init__(use_locking, name)\n    self._optimizer = optimizer\n    self._grow_init = grow_init\n    self._drop_fraction_anneal = drop_fraction_anneal\n    self._drop_fraction_initial_value = ops.convert_to_tensor(\n        float(drop_fraction),\n        name='%s_drop_fraction' % self._drop_fraction_anneal)\n    self._begin_step = ops.convert_to_tensor(begin_step, name='begin_step')\n    self._end_step = ops.convert_to_tensor(end_step, name='end_step')\n    self._frequency = ops.convert_to_tensor(frequency, name='frequency')\n    self._frequency_val = frequency\n    self._use_stateless = use_stateless\n    self._stateless_seed_offset = stateless_seed_offset\n\n  def compute_gradients(self, loss, **kwargs):\n    \"\"\"Wraps the compute gradient of passed optimizer.\"\"\"\n    result = self._optimizer.compute_gradients(loss, **kwargs)\n    return result\n\n  def apply_gradients(self, grads_and_vars, global_step=None, name=None):\n    \"\"\"Wraps the original apply_gradient of the optimizer.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs as returned by\n        `compute_gradients()`.\n      global_step: Optional `Variable` to increment by one after the variables\n        have been updated.\n      name: Optional name for the returned operation.  Default to the name\n        passed to the `Optimizer` constructor.\n\n    Returns:\n      An `Operation` that applies the specified gradients. 
If `global_step`\n      was not None, that operation also increments `global_step`.\n    \"\"\"\n    pre_op = self._before_apply_gradients(grads_and_vars)\n    with ops.control_dependencies([pre_op]):\n      optimizer_update = self._optimizer.apply_gradients(\n          grads_and_vars, global_step=global_step, name=name)\n    # We get the default one after calling the super.apply_gradient(), since\n    # we want to preserve original behavior of the optimizer: don't increment\n    # anything if no global_step is passed. But we need the global step for\n    # the mask_update.\n    global_step = (\n        global_step if global_step is not None else\n        training_util.get_or_create_global_step())\n    self._global_step = global_step\n    with ops.control_dependencies([optimizer_update]):\n      return self.cond_mask_update_op(global_step, control_flow_ops.no_op)\n\n  def _before_apply_gradients(self, grads_and_vars):\n    \"\"\"Called before applying gradients.\"\"\"\n    return control_flow_ops.no_op('before_apply_grad')\n\n  def cond_mask_update_op(self, global_step, false_branch):\n    \"\"\"Creates the conditional mask update operation.\n\n    All masks are updated when it is an update iteration\n    (checked by self.is_mask_update_iter()).\n    Arguments:\n      global_step: tf.Variable, current training iteration.\n      false_branch: function, called when it is not a mask update iteration.\n\n    Returns:\n      conditional update operation\n    \"\"\"\n    # Initializing to -freq so that last_update_step+freq=0. 
This enables early\n    # mask_updates.\n    last_update_step = variable_scope.get_variable(\n        'last_mask_update_step', [],\n        initializer=init_ops.constant_initializer(\n            -self._frequency_val, dtype=global_step.dtype),\n        trainable=False,\n        dtype=global_step.dtype)\n\n    def mask_update_op():\n      update_ops = []\n      for mask, weights in zip(self.get_masks(), self.get_weights()):\n        update_ops.append(self.generic_mask_update(mask, weights))\n\n      with ops.control_dependencies(update_ops):\n        assign_op = state_ops.assign(\n            last_update_step, global_step, name='last_mask_update_step_assign')\n        with ops.control_dependencies([assign_op]):\n          return control_flow_ops.no_op('mask_update')\n\n    maybe_update = control_flow_ops.cond(\n        self.is_mask_update_iter(global_step, last_update_step), mask_update_op,\n        false_branch)\n    return maybe_update\n\n  def get_weights(self):\n    raise NotImplementedError\n\n  def get_masks(self):\n    raise NotImplementedError\n\n  def get_masked_weights(self):\n    raise NotImplementedError\n\n  def is_mask_update_iter(self, global_step, last_update_step):\n    \"\"\"Function for checking if the current step is a mask update step.\n\n    It also creates the drop_fraction op and assigns it to the self object.\n\n    Args:\n      global_step: tf.Variable(int), current training step.\n      last_update_step: tf.Variable(int), holding the last iteration the mask is\n        updated. 
Used to determine whether current iteration is a mask update\n        step.\n\n    Returns:\n      bool, whether the current iteration is a mask_update step.\n    \"\"\"\n    gs_dtype = global_step.dtype\n    self._begin_step = math_ops.cast(self._begin_step, gs_dtype)\n    self._end_step = math_ops.cast(self._end_step, gs_dtype)\n    self._frequency = math_ops.cast(self._frequency, gs_dtype)\n    is_step_within_update_range = math_ops.logical_and(\n        math_ops.greater_equal(global_step, self._begin_step),\n        math_ops.logical_or(\n            math_ops.less_equal(global_step, self._end_step),\n            # If _end_step is negative, we never stop updating the mask.\n            # In other words we update the mask with given frequency until the\n            # training ends.\n            math_ops.less(self._end_step, 0)))\n    is_update_step = math_ops.less_equal(\n        math_ops.add(last_update_step, self._frequency), global_step)\n    is_mask_update_iter_op = math_ops.logical_and(is_step_within_update_range,\n                                                  is_update_step)\n    self.drop_fraction = self.get_drop_fraction(global_step,\n                                                is_mask_update_iter_op)\n    return is_mask_update_iter_op\n\n  def get_drop_fraction(self, global_step, is_mask_update_iter_op):\n    \"\"\"Returns a constant or annealing drop_fraction op.\"\"\"\n    if self._drop_fraction_anneal == 'constant':\n      drop_frac = self._drop_fraction_initial_value\n    elif self._drop_fraction_anneal == 'cosine':\n      decay_steps = self._end_step - self._begin_step\n      drop_frac = learning_rate_decay.cosine_decay(\n          self._drop_fraction_initial_value,\n          global_step,\n          decay_steps,\n          name='cosine_drop_fraction')\n    elif self._drop_fraction_anneal.startswith('exponential'):\n      exponent = extract_number(self._drop_fraction_anneal)\n      div_dtype = self._drop_fraction_initial_value.dtype\n      
power = math_ops.divide(\n          math_ops.cast(global_step - self._begin_step, div_dtype),\n          math_ops.cast(self._end_step - self._begin_step, div_dtype),\n      )\n      drop_frac = math_ops.multiply(\n          self._drop_fraction_initial_value,\n          math_ops.pow(1 - power, exponent),\n          name='%s_drop_fraction' % self._drop_fraction_anneal)\n    else:\n      raise ValueError('drop_fraction_anneal: %s is not valid' %\n                       self._drop_fraction_anneal)\n    return array_ops.where(is_mask_update_iter_op, drop_frac,\n                           array_ops.zeros_like(drop_frac))\n\n  def generic_mask_update(self, mask, weights, noise_std=1e-5):\n    \"\"\"True branch of the condition, updates the mask.\"\"\"\n    # Ensure that the weights are masked.\n    masked_weights = mask * weights\n    score_drop = math_ops.abs(masked_weights)\n    # Add noise for a slight bit of randomness.\n    score_drop += self._random_normal(\n        score_drop.shape,\n        stddev=noise_std,\n        dtype=score_drop.dtype,\n        seed=(hash(weights.name + 'drop')))\n    # Randomly revive `n_prune` connections among the disabled ones.\n    score_grow = self._random_uniform(\n        weights.shape, seed=hash(weights.name + 'grow'))\n    return self._get_update_op(score_drop, score_grow, mask, weights)\n\n  def _get_update_op(self,\n                     score_drop,\n                     score_grow,\n                     mask,\n                     weights,\n                     reinit_when_same=False):\n    \"\"\"Prunes and grows connections; all tensors have the same shape.\"\"\"\n    old_dtype = mask.dtype\n    mask_casted = math_ops.cast(mask, dtypes.float32)\n    n_total = array_ops.size(score_drop)\n    n_ones = math_ops.cast(math_ops.reduce_sum(mask_casted), dtype=dtypes.int32)\n    n_prune = math_ops.cast(\n        math_ops.cast(n_ones, dtype=dtypes.float32) * self.drop_fraction,\n        dtypes.int32)\n    n_keep = n_ones - n_prune\n\n    # Sort 
the entire array since `k` needs to be constant for TPU.\n    _, sorted_indices = nn_ops.top_k(\n        array_ops.reshape(score_drop, [-1]), k=n_total)\n    sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)\n    # We will have zeros after having `n_keep` many ones.\n    new_values = array_ops.where(\n        math_ops.range(n_total) < n_keep,\n        array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),\n        array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))\n    mask1 = array_ops.scatter_nd(sorted_indices_ex, new_values,\n                                 new_values.shape)\n    # Flatten the scores.\n    score_grow = array_ops.reshape(score_grow, [-1])\n    # Set scores of the enabled connections (ones) to min(s) - 1, so that they\n    # have the lowest scores.\n    score_grow_lifted = array_ops.where(\n        math_ops.equal(mask1, 1),\n        array_ops.ones_like(mask1) * (math_ops.reduce_min(score_grow) - 1),\n        score_grow)\n    _, sorted_indices = nn_ops.top_k(score_grow_lifted, k=n_total)\n    sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)\n    new_values = array_ops.where(\n        math_ops.range(n_total) < n_prune,\n        array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),\n        array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))\n    mask2 = array_ops.scatter_nd(sorted_indices_ex, new_values,\n                                 new_values.shape)\n    # Ensure masks are disjoint.\n    assert_op = control_flow_ops.Assert(\n        math_ops.equal(math_ops.reduce_sum(mask1 * mask2), 0.), [mask1, mask2])\n\n    with ops.control_dependencies([assert_op]):\n      # Set the weights of the grown connections.\n      mask2_reshaped = array_ops.reshape(mask2, mask.shape)\n    # Set the values of the new connections.\n    grow_tensor = self.get_grow_tensor(weights, self._grow_init)\n    if reinit_when_same:\n      # If dropped and grown, we re-initialize.\n      new_connections = 
math_ops.equal(mask2_reshaped, 1)\n    else:\n      new_connections = math_ops.logical_and(\n          math_ops.equal(mask2_reshaped, 1), math_ops.equal(mask_casted, 0))\n    new_weights = array_ops.where(new_connections, grow_tensor, weights)\n    weights_update = state_ops.assign(weights, new_weights)\n    # Ensure there is no momentum value for new connections.\n    reset_op = self.reset_momentum(weights, new_connections)\n\n    with ops.control_dependencies([weights_update, reset_op]):\n      mask_combined = array_ops.reshape(mask1 + mask2, mask.shape)\n    mask_combined = math_ops.cast(mask_combined, dtype=old_dtype)\n    new_mask = state_ops.assign(mask, mask_combined)\n    return new_mask\n\n  def reset_momentum(self, weights, new_connections):\n    reset_ops = []\n    for s_name in self._optimizer.get_slot_names():\n      # For momentum-like slot variables, reset the aggregated values to zero.\n      optim_var = self._optimizer.get_slot(weights, s_name)\n      new_values = array_ops.where(new_connections,\n                                   array_ops.zeros_like(optim_var), optim_var)\n      reset_ops.append(state_ops.assign(optim_var, new_values))\n    return control_flow_ops.group(reset_ops)\n\n  def get_grow_tensor(self, weights, method):\n    \"\"\"Different ways to initialize new connections.\n\n    Args:\n      weights: tf.Tensor or Variable.\n      method: str, available options: 'zeros', 'random_normal',\n        'random_uniform' and 'initial_dist'.\n\n    Returns:\n      tf.Tensor of the same shape and type as weights.\n\n    Raises:\n      ValueError: when the method is not valid.\n    \"\"\"\n    if not isinstance(method, six.string_types):\n      raise ValueError('Grow-Init: %s is not a string' % method)\n\n    if method == 'zeros':\n      grow_tensor = array_ops.zeros_like(weights, dtype=weights.dtype)\n    elif method.startswith('initial_dist'):\n      original_shape = weights.initial_value.shape\n      divisor = extract_number(method)\n      
grow_tensor = array_ops.reshape(\n          random_ops.random_shuffle(\n              array_ops.reshape(weights.initial_value, [-1])),\n          original_shape) / divisor\n    elif method.startswith('random_normal'):\n      stddev = math_ops.reduce_std(weights)\n      divisor = extract_number(method)\n      grow_tensor = self._random_normal(\n          weights.shape,\n          stddev=stddev,\n          dtype=weights.dtype,\n          seed=hash(weights.name + 'grow_init_n')) / divisor\n    elif method.startswith('random_uniform'):\n      mean = math_ops.reduce_mean(math_ops.abs(weights))\n      divisor = extract_number(method)\n      grow_tensor = self._random_uniform(\n          weights.shape,\n          minval=-mean,\n          maxval=mean,\n          dtype=weights.dtype,\n          seed=hash(weights.name + 'grow_init_u')) / divisor\n    else:\n      raise ValueError('Grow-Init: %s is not a valid option.' % method)\n    return grow_tensor\n\n  def _random_uniform(self, *args, **kwargs):\n    if self._use_stateless:\n      c_seed = self._stateless_seed_offset + kwargs['seed']\n      kwargs['seed'] = math_ops.cast(\n          array_ops.stack([c_seed, self._global_step]), dtypes.int32)\n      return stateless_random_ops.stateless_random_uniform(*args, **kwargs)\n    else:\n      return random_ops.random_uniform(*args, **kwargs)\n\n  def _random_normal(self, *args, **kwargs):\n    if self._use_stateless:\n      c_seed = self._stateless_seed_offset + kwargs['seed']\n      kwargs['seed'] = math_ops.cast(\n          array_ops.stack([c_seed, self._global_step]), dtypes.int32)\n      return stateless_random_ops.stateless_random_normal(*args, **kwargs)\n    else:\n      return random_ops.random_normal(*args, **kwargs)\n\n\nclass SparseRigLOptimizerBase(SparseSETOptimizerBase):\n  \"\"\"Sparse optimizer that grows connections with the pre-removal gradients.\n\n  Attributes:\n    optimizer: tf.train.Optimizer\n    begin_step: int, first iteration where masks are updated.\n  
  end_step: int, iteration after which no mask is updated.\n    frequency: int, interval between mask update operations.\n    drop_fraction: float, fraction of connections to drop during each update.\n    drop_fraction_anneal: str or None, if supplied used to anneal the drop\n      fraction.\n    use_locking: bool, passed to the super.\n    grow_init: str, name of the method used to initialize new connections.\n    initial_acc_scale: float, used to scale the gradient when initializing the\n      momentum values of new connections. We hope this will improve training\n      compared to starting from 0 for the new connections. Set this to something\n      between 0 and 1 / (1 - momentum). This is because in the current\n      implementation of MomentumOptimizer, aggregated values converge to 1 / (1\n      - momentum) with constant gradients.\n    use_tpu: bool, if True the masked gradients are aggregated across replicas.\n    name: str, passed to the super.\n  \"\"\"\n\n  def __init__(self,\n               optimizer,\n               begin_step,\n               end_step,\n               frequency,\n               drop_fraction=0.1,\n               drop_fraction_anneal='constant',\n               use_locking=False,\n               grow_init='zeros',\n               initial_acc_scale=0.,\n               use_tpu=False,\n               name='SparseRigLOptimizer',\n               stateless_seed_offset=0):\n    super(SparseRigLOptimizerBase, self).__init__(\n        optimizer,\n        begin_step,\n        end_step,\n        frequency,\n        drop_fraction=drop_fraction,\n        drop_fraction_anneal=drop_fraction_anneal,\n        grow_init=grow_init,\n        use_locking=use_locking,\n        name=name,\n        stateless_seed_offset=stateless_seed_offset)\n    self._initial_acc_scale = initial_acc_scale\n    self._use_tpu = use_tpu\n\n  def set_masked_grads(self, grads, weights):\n    if self._use_tpu:\n      grads = [tpu_ops.cross_replica_sum(g) for g in grads]\n    self._masked_grads = 
grads\n    # Use variable names as keys, since they hash reliably.\n    self._weight2masked_grads = {w.name: m for w, m in zip(weights, grads)}\n\n  def compute_gradients(self, loss, **kwargs):\n    \"\"\"Wraps compute_gradients of the passed optimizer.\"\"\"\n    grads_and_vars = self._optimizer.compute_gradients(loss, **kwargs)\n    masked_grads_vars = self._optimizer.compute_gradients(\n        loss, var_list=self.get_masked_weights())\n    masked_grads = [g for g, _ in masked_grads_vars]\n    self.set_masked_grads(masked_grads, self.get_weights())\n    return grads_and_vars\n\n  def apply_gradients(self, grads_and_vars, global_step=None, name=None):\n    \"\"\"Wraps the original apply_gradients of the optimizer.\n\n    Args:\n      grads_and_vars: List of (gradient, variable) pairs as returned by\n        `compute_gradients()`.\n      global_step: Optional `Variable` to increment by one after the variables\n        have been updated.\n      name: Optional name for the returned operation.  Defaults to the name\n        passed to the `Optimizer` constructor.\n\n    Returns:\n      An `Operation` that applies the specified gradients. If `global_step`\n      was not None, that operation also increments `global_step`.\n    \"\"\"\n    pre_op = self._before_apply_gradients(grads_and_vars)\n    with ops.control_dependencies([pre_op]):\n      # Call this to create slots.\n      _ = self._optimizer.apply_gradients(\n          grads_and_vars, global_step=global_step, name=name)\n\n      def apply_gradient_op():\n        optimizer_update = self._optimizer.apply_gradients(\n            grads_and_vars, global_step=global_step, name=name)\n        return optimizer_update\n\n      # We fetch the default global step after calling apply_gradients(), since\n      # we want to preserve the original behavior of the optimizer: don't\n      # increment anything if no global_step is passed. 
But we need the global step for\n      # the mask update.\n      global_step = (\n          global_step if global_step is not None else\n          training_util.get_or_create_global_step())\n      self._global_step = global_step\n      return self.cond_mask_update_op(global_step, apply_gradient_op)\n\n  def generic_mask_update(self, mask, weights, noise_std=1e-5):\n    \"\"\"True branch of the condition, updates the mask.\"\"\"\n    # Ensure that the weights are masked.\n    casted_mask = math_ops.cast(mask, dtype=dtypes.float32)\n    masked_weights = casted_mask * weights\n    score_drop = math_ops.abs(masked_weights)\n    # Add noise for a slight bit of randomness.\n    score_drop += self._random_normal(\n        score_drop.shape,\n        stddev=noise_std,\n        dtype=score_drop.dtype,\n        seed=hash(weights.name + 'drop'))\n    # Revive `n_prune` connections using the gradient magnitude.\n    score_grow = math_ops.abs(self._weight2masked_grads[weights.name])\n    with ops.control_dependencies([score_grow]):\n      return self._get_update_op(score_drop, score_grow, mask, weights)\n\n  def get_grow_tensor(self, weights, method):\n    \"\"\"Returns initialization for grown weights.\"\"\"\n    if method.startswith('grad_scale'):\n      masked_grad = self._weight2masked_grads[weights.name]\n      divisor = extract_number(method)\n      grow_tensor = masked_grad / divisor\n    elif method.startswith('grad_sign'):\n      masked_grad_sign = math_ops.sign(self._weight2masked_grads[weights.name])\n      divisor = extract_number(method)\n      grow_tensor = masked_grad_sign / divisor\n    else:\n      grow_tensor = super(SparseRigLOptimizerBase,\n                          self).get_grow_tensor(weights, method)\n    return grow_tensor\n\n  def reset_momentum(self, weights, new_connections):\n    reset_ops = []\n    for s_name in self._optimizer.get_slot_names():\n      # For momentum-like slot variables, reinitialize the aggregated values\n      # of the new connections.\n      optim_var = 
self._optimizer.get_slot(weights, s_name)\n      accum_grad = (\n          self._weight2masked_grads[weights.name] * self._initial_acc_scale)\n      new_values = array_ops.where(new_connections, accum_grad, optim_var)\n      reset_ops.append(state_ops.assign(optim_var, new_values))\n    return control_flow_ops.group(reset_ops)\n"
  },
  {
    "path": "rigl/sparse_optimizers_test.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for the sparse_optimizers file.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport itertools\n\nfrom absl import flags\nfrom absl.testing import parameterized\nimport numpy as np\nfrom rigl import sparse_optimizers\nfrom rigl import sparse_utils\nimport tensorflow.compat.v1 as tf  # tf\n\nfrom tensorflow.contrib.model_pruning.python import pruning\nfrom tensorflow.contrib.model_pruning.python.layers import layers\n\n\nFLAGS = flags.FLAGS\n\n\nclass SparseSETOptimizerTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,\n                   freq_iter=2):\n    \"\"\"Sets up a trivial training procedure for sparse training.\"\"\"\n    tf.reset_default_graph()\n    optim = tf.train.GradientDescentOptimizer(0.1)\n    sparse_optim = sparse_optimizers.SparseSETOptimizer(\n        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)\n    x = tf.random.uniform((1, n_inp))\n    y = layers.masked_fully_connected(x, n_out, activation_fn=None)\n    global_step = tf.train.get_or_create_global_step()\n    weight = pruning.get_weights()[0]\n    # There is one masked layer to be trained.\n    mask = pruning.get_masks()[0]\n    # Around half of the values of the mask are set to zero with 
`mask_update`.\n    mask_update = tf.assign(\n        mask,\n        tf.constant(\n            np.random.choice([0, 1], size=(n_inp, n_out), p=[1./2, 1./2]),\n            dtype=tf.float32))\n    loss = tf.reduce_mean(y)\n    train_op = sparse_optim.minimize(loss, global_step)\n\n    # Init\n    sess = tf.Session()\n    init = tf.global_variables_initializer()\n    sess.run(init)\n    sess.run([mask_update])\n\n    return sess, train_op, mask, weight, global_step\n\n  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.2), (3, 5, 0.2))\n  def testMaskNonUpdateIterations(self, n_inp, n_out, drop_frac):\n    \"\"\"Trains a layer for 5 iterations and checks the mask is kept intact.\n\n    The mask should be updated only in iterations 1 and 3 (since start_iter=1,\n    end_iter=4, freq_iter=2).\n\n    Args:\n      n_inp: int, number of input channels.\n      n_out: int, number of output channels.\n      drop_frac: float, passed to the sparse optimizer.\n    \"\"\"\n    sess, train_op, mask, _, _ = self._setup_graph(\n        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)\n    expected_updates = [1, 3]\n    # Running 5 times to make sure the mask is not updated after end_iter.\n    for i in range(1, 6):\n      c_mask, = sess.run([mask])\n      sess.run([train_op])\n      c_mask2, = sess.run([mask])\n      if i not in expected_updates:\n        self.assertAllEqual(c_mask, c_mask2)\n\n  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.7), (30, 10, 0.9))\n  def testUpdateIterations(self, n_inp, n_out, drop_frac):\n    \"\"\"Checks whether the mask is updated during the correct iterations.\n\n    The mask should be updated only in iterations 1 and 3 (since start_iter=1,\n    end_iter=4, freq_iter=2). 
Number of 1's in the mask should be equal.\n\n    Args:\n      n_inp: int, number of input channels.\n      n_out: int, number of output channels.\n      drop_frac: float, passed to the sparse optimizer.\n    \"\"\"\n    sess, train_op, mask, _, _ = self._setup_graph(\n        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)\n    expected_updates = [1, 3]\n    # Running 4 times since last update is at 3.\n    for i in range(1, 5):\n      c_mask, = sess.run([mask])\n      sess.run([train_op])\n      c_mask2, = sess.run([mask])\n      if i in expected_updates:\n        # Number of ones (connections) should be the same.\n        self.assertEqual(c_mask.sum(), c_mask2.sum())\n        # Assert there is some change in the mask.\n        self.assertNotAllClose(c_mask, c_mask2)\n\n  @parameterized.parameters((3, 7, 2), (1, 5, 3), (0, 4, 1))\n  def testNoDrop(self, start_iter, end_iter, freq_iter):\n    \"\"\"Checks that when the drop fraction is 0, no update is made.\n\n    Since nothing is dropped, the mask should remain unchanged at every\n    iteration, including mask-update iterations.\n\n    Args:\n      start_iter: int, start iteration for sparse training.\n      end_iter: int, final iteration for sparse training.\n      freq_iter: int, mask update frequency.\n    \"\"\"\n    # Setting drop_fraction to 0, so nothing is dropped and nothing changes.\n    sess, train_op, mask, _, _ = self._setup_graph(\n        3, 5, 0, start_iter=start_iter, end_iter=end_iter, freq_iter=freq_iter)\n    for _ in range(end_iter+2):\n      c_mask, = sess.run([mask])\n      sess.run([train_op])\n      c_mask2, = sess.run([mask])\n      self.assertAllEqual(c_mask, c_mask2)\n\n  def testNewConnectionZeroInit(self):\n    \"\"\"Checks whether the new connections are initialized correctly to zeros.\n    \"\"\"\n    end_iter = 4\n    sess, train_op, mask, weight, _ = self._setup_graph(\n        n_inp=3, n_out=5, drop_frac=0.5, start_iter=0, end_iter=end_iter,\n        freq_iter=1)\n    # Let's iterate until the mask updates are done.\n    for _ in range(end_iter + 1):\n      mask_tensor, = sess.run([mask])\n      sess.run([train_op])\n      new_mask_tensor, new_weight_tensor = sess.run([mask, weight])\n      # Gather the values of the newly grown connections.\n      new_weights = new_weight_tensor[np.logical_and(mask_tensor == 0,\n                                                     new_mask_tensor == 1)]\n      self.assertTrue(np.all(new_weights == 0))\n\n  @parameterized.parameters(itertools.product(\n      ((3, 7, 2), (5, 3), (1,)), ('zeros', 'random_normal', 'random_uniform')))\n  def testShapeOfGetGrowTensor(self, shape, init_type):\n    \"\"\"Checks whether the new tensor is created with the correct shape.\"\"\"\n    optim = tf.train.GradientDescentOptimizer(0.1)\n    sparse_optim = sparse_optimizers.SparseSETOptimizer(optim, 0, 0, 1,\n                                                        use_stateless=False)\n    weights = tf.random_uniform(shape)\n    grow_tensor = sparse_optim.get_grow_tensor(weights, init_type)\n    
self.assertAllEqual(weights.shape, grow_tensor.shape)\n\n  @parameterized.parameters(itertools.product(\n      (tf.float32, tf.float64),\n      ('zeros', 'random_normal', 'random_uniform')))\n  def testDtypeOfGetGrowTensor(self, dtype, init_type):\n    \"\"\"Checks whether the new tensor is created with the correct data type.\"\"\"\n    optim = tf.train.GradientDescentOptimizer(0.1)\n    sparse_optim = sparse_optimizers.SparseSETOptimizer(optim, 0, 0, 1,\n                                                        use_stateless=False)\n    weights = tf.random_uniform((3, 4), dtype=dtype, maxval=5)\n    grow_tensor = sparse_optim.get_grow_tensor(weights, init_type)\n    self.assertEqual(grow_tensor.dtype, weights.dtype)\n\n  @parameterized.parameters('ones', 'zero', None, 0)\n  def testValueErrorOfGetGrowTensor(self, method):\n    \"\"\"Checks that an invalid grow-init method raises a ValueError.\"\"\"\n    optim = tf.train.GradientDescentOptimizer(0.1)\n    sparse_optim = sparse_optimizers.SparseSETOptimizer(optim, 0, 0, 1,\n                                                        use_stateless=False)\n    weights = tf.random_uniform((3, 4))\n    with self.assertRaises(ValueError):\n      sparse_optim.get_grow_tensor(weights, method)\n\n\nclass SparseStaticOptimizerTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,\n                   freq_iter=2):\n    \"\"\"Sets up a trivial training procedure for sparse training.\"\"\"\n    tf.reset_default_graph()\n    optim = tf.train.GradientDescentOptimizer(0.1)\n    sparse_optim = sparse_optimizers.SparseStaticOptimizer(\n        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)\n    x = tf.random.uniform((1, n_inp))\n    y = layers.masked_fully_connected(x, n_out, activation_fn=None)\n    global_step = tf.train.get_or_create_global_step()\n    weight = pruning.get_weights()[0]\n    # There is one masked layer to be trained.\n    mask 
= pruning.get_masks()[0]\n    # Around half of the values of the mask are set to zero with `mask_update`.\n    mask_update = tf.assign(\n        mask,\n        tf.constant(\n            np.random.choice([0, 1], size=(n_inp, n_out), p=[1./2, 1./2]),\n            dtype=tf.float32))\n    loss = tf.reduce_mean(y)\n    train_op = sparse_optim.minimize(loss, global_step)\n\n    # Init\n    sess = tf.Session()\n    init = tf.global_variables_initializer()\n    sess.run(init)\n    sess.run([mask_update])\n\n    return sess, train_op, mask, weight, global_step\n\n  @parameterized.parameters((15, 25, 0.5), (15, 25, 0.2), (3, 5, 0.2))\n  def testMaskStatic(self, n_inp, n_out, drop_frac):\n    \"\"\"Trains a layer for 5 iterations and checks the mask never changes.\n\n    For the static optimizer the connection pattern stays fixed, even during\n    update iterations (start_iter=1, end_iter=4, freq_iter=2).\n\n    Args:\n      n_inp: int, number of input channels.\n      n_out: int, number of output channels.\n      drop_frac: float, passed to the sparse optimizer.\n    \"\"\"\n    sess, train_op, mask, _, _ = self._setup_graph(\n        n_inp, n_out, drop_frac, start_iter=1, end_iter=4, freq_iter=2)\n    # Running 5 times to make sure the mask is not updated after end_iter.\n    for _ in range(5):\n      c_mask, = sess.run([mask])\n      sess.run([train_op])\n      c_mask2, = sess.run([mask])\n      self.assertAllEqual(c_mask, c_mask2)\n\n\nclass SparseMomentumOptimizerTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,\n                   freq_iter=2, momentum=0.5):\n    \"\"\"Sets up a trivial training procedure for sparse training.\"\"\"\n    tf.reset_default_graph()\n    optim = tf.train.GradientDescentOptimizer(0.1)\n    sparse_optim = sparse_optimizers.SparseMomentumOptimizer(\n        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac,\n        
momentum=momentum)\n    x = tf.ones((1, n_inp))\n    y = layers.masked_fully_connected(x, n_out, activation_fn=None)\n    # Multiplying the output by a range of constants to have constant but\n    # different gradients at the masked weights.\n    y = y * tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape)\n    loss = tf.reduce_sum(y)\n    global_step = tf.train.get_or_create_global_step()\n    train_op = sparse_optim.minimize(loss, global_step)\n    weight = pruning.get_weights()[0]\n    masked_grad = sparse_optim._weight2masked_grads[weight.name]\n    masked_grad_ema = sparse_optim._ema_grads.average(masked_grad)\n    # Init\n    sess = tf.Session()\n    init = tf.global_variables_initializer()\n    sess.run(init)\n\n    return sess, train_op, masked_grad_ema\n\n  @parameterized.parameters((3, 4, 0.5), (5, 2, 0.), (2, 5, 1.))\n  def testMomentumUpdate(self, n_inp, n_out, momentum):\n    \"\"\"Checks whether momentum is applied correctly.\"\"\"\n    sess, train_op, masked_grad_ema = self._setup_graph(\n        n_inp, n_out, 0.5, start_iter=1, end_iter=4, freq_iter=2,\n        momentum=momentum)\n\n    # Running 6 times to make sure the momentum is always updated.\n    current_momentum = np.zeros((n_inp, n_out))\n    for _ in range(6):\n      ema_masked_grad, = sess.run([masked_grad_ema])\n      self.assertAllEqual(ema_masked_grad, current_momentum)\n      sess.run([train_op])\n      # This follows since we multiply the output values by range(n_out).\n      # Note the broadcast from n_out vector to (n_inp, n_out) matrix.\n      current_momentum = (current_momentum * momentum +\n                          (1 - momentum) * np.arange(n_out))\n\n      ema_masked_grad, = sess.run([masked_grad_ema])\n      self.assertAllEqual(ema_masked_grad, current_momentum)\n\n\nclass SparseRigLOptimizerTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,\n                   freq_iter=2):\n    
\"\"\"Sets up a trivial training procedure for sparse training.\"\"\"\n    tf.reset_default_graph()\n    optim = tf.train.GradientDescentOptimizer(1e-3)\n    global_step = tf.train.get_or_create_global_step()\n    sparse_optim = sparse_optimizers.SparseRigLOptimizer(\n        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)\n    x = tf.ones((1, n_inp))\n    y = layers.masked_fully_connected(x, n_out, activation_fn=None)\n    # Multiplying the output by a range of constants to have constant but\n    # different gradients at the masked weights. We also multiply the loss by\n    # global_step to increase the gradient linearly with time.\n    scale_vector = (\n        tf.reshape(tf.cast(tf.range(tf.size(y)), dtype=y.dtype), y.shape) *\n        tf.cast(global_step, dtype=y.dtype))\n    y = y * scale_vector\n    loss = tf.reduce_sum(y)\n    train_op = sparse_optim.minimize(loss, global_step)\n    weight = pruning.get_weights()[0]\n    expected_gradient = tf.broadcast_to(scale_vector, weight.shape)\n    masked_grad = sparse_optim._weight2masked_grads[weight.name]\n\n    # Init\n    sess = tf.Session()\n    init = tf.global_variables_initializer()\n    sess.run(init)\n\n    return sess, train_op, masked_grad, expected_gradient\n\n  @parameterized.parameters((3, 4), (5, 2), (2, 5))\n  def testMaskedGradientCalculation(self, n_inp, n_out):\n    \"\"\"Checking whether masked_grad is calculated after apply_gradients.\"\"\"\n    # No drop since we don't want to change the mask; we only check that the\n    # gradient is calculated after the gradient step.\n    sess, train_op, masked_grad, expected_gradient = self._setup_graph(\n        n_inp, n_out, 0., start_iter=0, end_iter=3, freq_iter=1)\n    # Since we only update the mask every 2 iterations, we will iterate 6 times.\n\n    for i in range(6):\n      is_mask_update = i % 2 == 0\n      if is_mask_update:\n        expected_gradient_tensor, = 
sess.run([expected_gradient])\n        _, masked_grad_tensor = sess.run([train_op, masked_grad])\n        self.assertAllEqual(masked_grad_tensor,\n                            expected_gradient_tensor)\n      else:\n        sess.run([train_op])\n\n  @parameterized.parameters(\n      (3, 7, 2, [1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1]),\n      (1, 5, 3, [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1]),\n      (0, 4, 1, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]))\n  def testApplyGradients(self, start_iter, end_iter, freq_iter, is_incremented):\n    \"\"\"Checking apply_gradients is called in non-mask-update iterations.\"\"\"\n    sess, train_op, _, _ = self._setup_graph(\n        3, 5, .5, start_iter=start_iter, end_iter=end_iter, freq_iter=freq_iter)\n    global_step = tf.train.get_or_create_global_step()\n    # global_step should be incremented on every non-mask-update iteration.\n    for one_if_incremented in is_incremented:\n      before, = sess.run([global_step])\n      sess.run([train_op])\n      after, = sess.run([global_step])\n      if one_if_incremented == 1:\n        self.assertEqual(before + 1, after)\n      else:\n        # Mask update step.\n        self.assertEqual(before, after)\n\n\nclass SparseSnipOptimizerTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_graph(self, default_sparsity, mask_init_method,\n                   custom_sparsity_map, n_inp=3, n_out=5):\n    \"\"\"Sets up a trivial training procedure for sparse training.\"\"\"\n    tf.reset_default_graph()\n    optim = tf.train.GradientDescentOptimizer(1e-3)\n    sparse_optim = sparse_optimizers.SparseSnipOptimizer(\n        optim, default_sparsity, mask_init_method,\n        custom_sparsity_map=custom_sparsity_map)\n\n    inp_values = np.arange(1, n_inp+1)\n    scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5\n    # The gradient is the outer product of the input and the output gradients.\n    # Since the loss is the sample sum, the output gradient is equal to the\n    # scale vector.\n    
expected_grads = np.outer(inp_values, scale_vector_values)\n\n    x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))\n    y = layers.masked_fully_connected(x, n_out, activation_fn=None)\n    scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)\n\n    y = y * scale_vector\n    loss = tf.reduce_sum(y)\n\n    global_step = tf.train.get_or_create_global_step()\n    train_op = sparse_optim.minimize(loss, global_step)\n\n    # Init\n    sess = tf.Session()\n    init = tf.global_variables_initializer()\n    sess.run(init)\n    mask = pruning.get_masks()[0]\n    weights = pruning.get_weights()[0]\n    return sess, train_op, expected_grads, sparse_optim, mask, weights\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testSnipSparsity(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether the mask has the requested sparsity after snip.\"\"\"\n    sess, train_op, _, _, mask, _ = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    _ = sess.run([train_op])\n    snipped_mask, = sess.run([mask])\n    n_ones = np.sum(snipped_mask)\n    n_zeros = snipped_mask.size - n_ones\n    n_zeros_expected = sparse_utils.get_n_zeros(snipped_mask.size,\n                                                default_sparsity)\n    self.assertEqual(n_zeros, n_zeros_expected)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testGradientUsed(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether snip keeps the connections with highest saliency.\"\"\"\n    sess, train_op, expected_grads, _, mask, weights = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    # Calculate 
sensitivity scores.\n    weights, = sess.run([weights])\n    expected_scores = np.abs(expected_grads*weights)\n    _ = sess.run([train_op])\n    snipped_mask, = sess.run([mask])\n    kept_connection_scores = expected_scores[snipped_mask == 1]\n    min_score_kept = np.min(kept_connection_scores)\n\n    snipped_connection_scores = expected_scores[snipped_mask == 0]\n    max_score_snipped = np.max(snipped_connection_scores)\n    self.assertLessEqual(max_score_snipped, min_score_kept)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testInitialMaskIsDense(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether the mask is dense before the snip update.\"\"\"\n    sess, _, _, _, mask, _ = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    mask_start, = sess.run([mask])\n    self.assertEqual(np.sum(mask_start), mask_start.size)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testAfterSnipTraining(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether the mask stays fixed after the snip update.\"\"\"\n    sess, train_op, _, sparse_optim, mask, _ = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    global_step = tf.train.get_or_create_global_step()\n    is_snip_iter = sess.run([train_op])\n    self.assertTrue(is_snip_iter)\n    # On later iterations the mask should stay the same. 
Let's do 3 more iterations.\n    for i in range(3):\n      mask_before, c_iter = sess.run([mask, global_step])\n      self.assertEqual(i, c_iter)\n      is_snip_iter, is_snipped = sess.run([train_op, sparse_optim.is_snipped])\n      self.assertTrue(is_snipped)\n      self.assertFalse(is_snip_iter)\n      mask_after, = sess.run([mask])\n      self.assertAllEqual(mask_after, mask_before)\n\n\nclass SparseDNWOptimizerTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_graph(self,\n                   default_sparsity,\n                   mask_init_method,\n                   custom_sparsity_map,\n                   n_inp=3,\n                   n_out=5):\n    \"\"\"Sets up a trivial training procedure for sparse training.\"\"\"\n    tf.reset_default_graph()\n    optim = tf.train.GradientDescentOptimizer(1e-3)\n    sparse_optim = sparse_optimizers.SparseDNWOptimizer(\n        optim,\n        default_sparsity,\n        mask_init_method,\n        custom_sparsity_map=custom_sparsity_map)\n\n    inp_values = np.arange(1, n_inp + 1)\n    scale_vector_values = np.random.uniform(size=(n_out,)) - 0.5\n    # The gradient is the outer product of the input and the output gradients.\n    # Since the loss is the sample sum, the output gradient is equal to the\n    # scale vector.\n    expected_grads = np.outer(inp_values, scale_vector_values)\n\n    x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))\n    y = layers.masked_fully_connected(x, n_out, activation_fn=None)\n    scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)\n\n    y = y * scale_vector\n    loss = tf.reduce_sum(y)\n\n    global_step = tf.train.get_or_create_global_step()\n    grads_and_vars = sparse_optim.compute_gradients(loss)\n    train_op = sparse_optim.apply_gradients(\n        grads_and_vars, global_step=global_step)\n    # Init\n    sess = tf.Session()\n    init = tf.global_variables_initializer()\n    sess.run(init)\n    mask = pruning.get_masks()[0]\n    weights = 
pruning.get_weights()[0]\n    return (sess, train_op, (expected_grads, grads_and_vars), mask, weights)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testDNWSparsity(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether the mask has the requested sparsity after one step.\"\"\"\n    sess, train_op, _, mask, _ = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    _ = sess.run([train_op])\n    dnw_mask, = sess.run([mask])\n    n_ones = np.sum(dnw_mask)\n    n_zeros = dnw_mask.size - n_ones\n    n_zeros_expected = sparse_utils.get_n_zeros(dnw_mask.size, default_sparsity)\n    self.assertEqual(n_zeros, n_zeros_expected)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testWeightsUsed(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether the mask keeps the largest-magnitude weights.\"\"\"\n    sess, train_op, _, mask, weights = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    # Calculate sensitivity scores.\n    weights, = sess.run([weights])\n    expected_scores = np.abs(weights)\n    _ = sess.run([train_op])\n    dnw_mask, = sess.run([mask])\n    kept_connection_scores = expected_scores[dnw_mask == 1]\n    min_score_kept = np.min(kept_connection_scores)\n\n    dnw_mask_connection_scores = expected_scores[dnw_mask == 0]\n    max_score_removed = np.max(dnw_mask_connection_scores)\n    self.assertLessEqual(max_score_removed, min_score_kept)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testGradientIsDense(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether calculated gradients are dense.\"\"\"\n    sess, _, 
grad_info, _, _ = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    expected_grad, grads_and_vars = grad_info\n    grad, = sess.run([grads_and_vars[0][0]])\n    self.assertAllClose(expected_grad, grad)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testDNWUpdates(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether mask is updated correctly.\"\"\"\n    sess, train_op, _, mask, weights = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    # On all iterations mask should have least magnitude connections.\n    for _ in range(5):\n      sess.run([train_op])\n      mask_after, weights_after = sess.run([mask, weights])\n\n      kept_connection_magnitudes = np.abs(weights_after[mask_after == 1])\n      min_score_kept = np.min(kept_connection_magnitudes)\n\n      removed_connection_magnitudes = np.abs(weights_after[mask_after == 0])\n      max_score_removed = np.max(removed_connection_magnitudes)\n      self.assertLessEqual(max_score_removed, min_score_kept)\n\n  @parameterized.parameters((3, 4, 0.5), (5, 3, 0.8), (8, 5, 0.8))\n  def testSparsityAfterDNWUpdates(self, n_inp, n_out, default_sparsity):\n    \"\"\"Checking whether mask is updated correctly.\"\"\"\n    sess, train_op, _, mask, _ = self._setup_graph(\n        default_sparsity, 'random', {}, n_inp=n_inp, n_out=n_out)\n    # On all iterations mask should have least magnitude connections.\n    for _ in range(5):\n      sess.run([train_op])\n      dnw_mask, = sess.run([mask])\n      n_ones = np.sum(dnw_mask)\n      n_zeros = dnw_mask.size - n_ones\n      n_zeros_expected = sparse_utils.get_n_zeros(dnw_mask.size,\n                                                  default_sparsity)\n      self.assertEqual(n_zeros, n_zeros_expected)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "rigl/sparse_utils.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"This module has helper functions for the interpolation experiments.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\nimport numpy as np\nfrom rigl import str_sparsities\nimport tensorflow.compat.v1 as tf\nfrom google_research.micronet_challenge import counting\n\nDEFAULT_ERK_SCALE = 1.0\n\n\ndef mask_extract_name_fn(mask_name):\n  return re.findall('(.+)/mask:0', mask_name)[0]\n\n\ndef get_n_zeros(size, sparsity):\n  return int(np.floor(sparsity * size))\n\n\ndef calculate_sparsity(masks):\n  dense_params = tf.constant(0.)\n  sparse_params = tf.constant(0.)\n  for mask in masks:\n    dense_params += tf.cast(tf.size(mask), dtype=dense_params.dtype)\n    sparse_params += tf.cast(tf.reduce_sum(mask), dtype=sparse_params.dtype)\n  return 1. 
- sparse_params / dense_params\n\n\ndef get_mask_random_numpy(mask_shape, sparsity, random_state=None):\n  \"\"\"Creates a random sparse mask with deterministic sparsity.\n\n  Args:\n    mask_shape: list, used to obtain shape of the random mask.\n    sparsity: float, between 0 and 1.\n    random_state: np.random.RandomState, if given the shuffle call is made using\n      the RandomState\n\n  Returns:\n    numpy.ndarray\n  \"\"\"\n  flat_ones = np.ones(mask_shape).flatten()\n  n_zeros = get_n_zeros(flat_ones.size, sparsity)\n  flat_ones[:n_zeros] = 0\n  if random_state:\n    random_state.shuffle(flat_ones)\n  else:\n    np.random.shuffle(flat_ones)\n  new_mask = flat_ones.reshape(mask_shape)\n  return new_mask\n\n\ndef get_mask_random(mask, sparsity, dtype, random_state=None):\n  \"\"\"Creates a random sparse mask with deterministic sparsity.\n\n  Args:\n    mask: tf.Tensor, used to obtain shape of the random mask.\n    sparsity: float, between 0 and 1.\n    dtype: tf.dtype, type of the return value.\n    random_state: np.random.RandomState, if given the shuffle call is made using\n      the RandomState\n\n  Returns:\n    tf.Tensor\n  \"\"\"\n  new_mask_numpy = get_mask_random_numpy(\n      mask.shape.as_list(), sparsity, random_state=random_state)\n  new_mask = tf.constant(new_mask_numpy, dtype=dtype)\n  return new_mask\n\n\ndef get_sparsities_erdos_renyi(all_masks,\n                               default_sparsity,\n                               custom_sparsity_map,\n                               include_kernel,\n                               extract_name_fn=mask_extract_name_fn,\n                               erk_power_scale=DEFAULT_ERK_SCALE):\n  \"\"\"Given the method, returns the sparsity of individual layers as a dict.\n\n  It ensures that the non-custom layers have a total parameter count as the one\n  with uniform sparsities. 
In other words, for the layers that are not in the\n  custom_sparsity_map the following equation should be satisfied:\n\n  # eps * (p_1 * N_1 + p_2 * N_2) = (1 - default_sparsity) * (N_1 + N_2)\n  Args:\n    all_masks: list, of all mask Variables.\n    default_sparsity: float, between 0 and 1.\n    custom_sparsity_map: dict, <str, float> key/value pairs where the mask\n      whose name is '{key}/mask:0' is set to the corresponding sparsity value.\n    include_kernel: bool, if True kernel dimensions are included in the\n      scaling.\n    extract_name_fn: function, extracts the variable name.\n    erk_power_scale: float, if given, used to take a power of the ratio. Use\n      scale < 1 to make the erdos_renyi distribution softer.\n\n  Returns:\n    sparsities, dict where keys() are equal to all_masks and individual\n      masks are mapped to their sparsities.\n  \"\"\"\n  # We have to enforce custom sparsities and then find the correct scaling\n  # factor.\n\n  is_eps_valid = False\n  # The following loop will terminate, in the worst case, when all masks are\n  # in the custom_sparsity_map. This should probably never happen though, since\n  # once we have one or more variables with the same constant, we have a valid\n  # epsilon. Note that for each iteration we add at least one variable to the\n  # custom_sparsity_map and therefore this while loop should terminate.\n  dense_layers = set()\n  while not is_eps_valid:\n    # We will start with all layers and try to find the right epsilon. However,\n    # if any probability exceeds 1, we will make that layer dense and repeat\n    # the process (finding epsilon) with the non-dense layers.\n    # We want the total number of connections to be the same. Let's say we have\n    # four layers with N_1, ..., N_4 parameters each. Let's say after some\n    # iterations the probabilities of some layers (3, 4) exceeded 1 and\n    # therefore we added them to the dense_layers set. 
Those layers will not\n    # scale with erdos_renyi; however, we need to count them so that the\n    # target parameter count is achieved. See below.\n    # eps * (p_1 * N_1 + p_2 * N_2) + (N_3 + N_4) =\n    #    (1 - default_sparsity) * (N_1 + N_2 + N_3 + N_4)\n    # eps * (p_1 * N_1 + p_2 * N_2) =\n    #    (1 - default_sparsity) * (N_1 + N_2) - default_sparsity * (N_3 + N_4)\n    # eps = rhs / (\\sum_i p_i * N_i) = rhs / divisor.\n\n    divisor = 0\n    rhs = 0\n    raw_probabilities = {}\n    for mask in all_masks:\n      var_name = extract_name_fn(mask.name)\n      shape_list = mask.shape.as_list()\n      n_param = np.prod(shape_list)\n      n_zeros = get_n_zeros(n_param, default_sparsity)\n      if var_name in dense_layers:\n        # See `- default_sparsity * (N_3 + N_4)` part of the equation above.\n        rhs -= n_zeros\n      elif var_name in custom_sparsity_map:\n        # We ignore custom_sparsities in erdos-renyi calculations.\n        pass\n      else:\n        # Corresponds to `(1 - default_sparsity) * (N_1 + N_2)` part of the\n        # equation above.\n        n_ones = n_param - n_zeros\n        rhs += n_ones\n        # Erdos-Renyi probability: epsilon * ((n_in + n_out) / (n_in * n_out)).\n        if include_kernel:\n          raw_probabilities[mask.name] = (np.sum(shape_list) /\n                                          np.prod(shape_list))**erk_power_scale\n        else:\n          n_in, n_out = shape_list[-2:]\n          raw_probabilities[mask.name] = (n_in + n_out) / (n_in * n_out)\n        # Note that raw_probabilities[mask] * n_param gives the individual\n        # elements of the divisor.\n        divisor += raw_probabilities[mask.name] * n_param\n    # By multiplying individual probabilities with epsilon, we should get the\n    # number of parameters per layer correctly.\n    eps = rhs / divisor\n    # If eps * raw_probabilities[mask.name] > 1. 
We set the sparsity of that\n    # mask to 0., so it becomes part of the dense_layers set.\n    max_prob = np.max(list(raw_probabilities.values()))\n    max_prob_one = max_prob * eps\n    if max_prob_one > 1:\n      is_eps_valid = False\n      for mask_name, mask_raw_prob in raw_probabilities.items():\n        if mask_raw_prob == max_prob:\n          var_name = extract_name_fn(mask_name)\n          tf.logging.info('Sparsity of var: %s had to be set to 0.', var_name)\n          dense_layers.add(var_name)\n    else:\n      is_eps_valid = True\n\n  sparsities = {}\n  # With the valid epsilon, we can set sparsities of the remaining layers.\n  for mask in all_masks:\n    var_name = extract_name_fn(mask.name)\n    shape_list = mask.shape.as_list()\n    n_param = np.prod(shape_list)\n    if var_name in custom_sparsity_map:\n      sparsities[mask.name] = custom_sparsity_map[var_name]\n      tf.logging.info('layer: %s has custom sparsity: %f', var_name,\n                      sparsities[mask.name])\n    elif var_name in dense_layers:\n      sparsities[mask.name] = 0.\n    else:\n      probability_one = eps * raw_probabilities[mask.name]\n      sparsities[mask.name] = 1. 
- probability_one\n    tf.logging.info('layer: %s, shape: %s, sparsity: %f', var_name, mask.shape,\n                    sparsities[mask.name])\n  return sparsities\n\n\ndef get_sparsities_uniform(all_masks,\n                           default_sparsity,\n                           custom_sparsity_map,\n                           extract_name_fn=mask_extract_name_fn):\n  \"\"\"Given the method, returns the sparsity of individual layers as a dict.\n\n  Args:\n    all_masks: list, of all mask Variables.\n    default_sparsity: float, between 0 and 1.\n    custom_sparsity_map: dict, <str, float> key/value pairs where the mask\n      whose name is '{key}/mask:0' is set to the corresponding sparsity value.\n    extract_name_fn: function, extracts the variable name.\n\n  Returns:\n    sparsities, dict where keys() are equal to all_masks and individual\n      masks are mapped to their sparsities.\n  \"\"\"\n  sparsities = {}\n  for mask in all_masks:\n    var_name = extract_name_fn(mask.name)\n    if var_name in custom_sparsity_map:\n      sparsities[mask.name] = custom_sparsity_map[var_name]\n    else:\n      sparsities[mask.name] = default_sparsity\n  return sparsities\n\n\ndef get_sparsities_str(all_masks, default_sparsity):\n  \"\"\"Given the method, returns the sparsity of individual layers as a dict.\n\n  Args:\n    all_masks: list, of all mask Variables.\n    default_sparsity: float, between 0 and 1.\n\n  Returns:\n    sparsities, dict where keys() are equal to all_masks and individual\n      masks are mapped to their sparsities.\n  \"\"\"\n  str_sparsities_parsed = str_sparsities.read_all()\n  if default_sparsity in str_sparsities_parsed:\n    sprsts = str_sparsities_parsed[default_sparsity]\n    sparsities = {mask.name: sprsts[mask.name] for mask in all_masks}\n  else:\n    raise ValueError('sparsity: %f is not defined' % default_sparsity)\n  return sparsities\n\n\ndef get_sparsities(all_masks,\n                   method,\n        
           default_sparsity,\n                   custom_sparsity_map,\n                   extract_name_fn=mask_extract_name_fn,\n                   erk_power_scale=DEFAULT_ERK_SCALE):\n  \"\"\"Given the method, returns the sparsity of individual layers as a dict.\n\n  Args:\n    all_masks: list, of all mask Variables.\n    method: str, one of 'random', 'erdos_renyi', 'erdos_renyi_kernel' or 'str'.\n    default_sparsity: float, between 0 and 1.\n    custom_sparsity_map: dict, <str, float> key/value pairs where the mask\n      whose name is '{key}/mask:0' is set to the corresponding sparsity value.\n    extract_name_fn: function, extracts the variable name.\n    erk_power_scale: float, passed to the erdos_renyi function.\n\n  Returns:\n    sparsities, dict where keys() are equal to all_masks and individual\n      masks are mapped to their sparsities.\n\n  Raises:\n    ValueError: when a key from custom_sparsity_map is not found in all_masks.\n    ValueError: when an invalid initialization option is given.\n  \"\"\"\n  # (1) Ensure all keys are valid and processed.\n  keys_found = set()\n  for mask in all_masks:\n    var_name = extract_name_fn(mask.name)\n    if var_name in custom_sparsity_map:\n      keys_found.add(var_name)\n  keys_given = set(custom_sparsity_map.keys())\n  if keys_found != keys_given:\n    diff = keys_given - keys_found\n    raise ValueError('No masks are found for the following names: %s' %\n                     str(diff))\n\n  if method in ('erdos_renyi', 'erdos_renyi_kernel'):\n    include_kernel = method == 'erdos_renyi_kernel'\n    sparsities = get_sparsities_erdos_renyi(\n        all_masks,\n        default_sparsity,\n        custom_sparsity_map,\n        include_kernel=include_kernel,\n        extract_name_fn=extract_name_fn,\n        erk_power_scale=erk_power_scale)\n  elif method == 'random':\n    sparsities = get_sparsities_uniform(\n        all_masks,\n        default_sparsity,\n        custom_sparsity_map,\n        extract_name_fn=extract_name_fn)\n  
elif method == 'str':\n    sparsities = get_sparsities_str(all_masks, default_sparsity)\n  else:\n    raise ValueError('Method: %s is not a valid mask initialization method' %\n                     method)\n  return sparsities\n\n\ndef get_mask_init_fn(all_masks,\n                     method,\n                     default_sparsity,\n                     custom_sparsity_map,\n                     mask_fn=get_mask_random,\n                     erk_power_scale=DEFAULT_ERK_SCALE,\n                     extract_name_fn=mask_extract_name_fn):\n  \"\"\"Returns an op that initializes the masks with the given sparsities.\n\n  Args:\n    all_masks: list, of all masks to be updated.\n    method: str, method to initialize the masks, passed to the\n      sparse_utils.get_mask() function.\n    default_sparsity: float, if 0 the mask is left intact; if greater than\n      zero, that fraction of the ones in each mask is flipped to 0.\n    custom_sparsity_map: dict, sparsity of individual variables can be\n      overridden here. Key should point to the correct variable name, and value\n      should be in [0, 1].\n    mask_fn: function, to initialize masks with given sparsity.\n    erk_power_scale: float, passed to get_sparsities.\n    extract_name_fn: function, used to grab names from the variable.\n\n  Returns:\n    A grouped op that assigns each mask its sparse initial value; run it\n    after the init op. See `init_fn` of\n    `tf.train.Scaffold`. 
\n  Raises:\n    ValueError: when there is no mask corresponding to a key in the\n      custom_sparsity_map.\n  \"\"\"\n  sparsities = get_sparsities(\n      all_masks,\n      method,\n      default_sparsity,\n      custom_sparsity_map,\n      erk_power_scale=erk_power_scale,\n      extract_name_fn=extract_name_fn)\n  tf.logging.info('Per-layer sparsities are as follows: %s',\n                  str(sparsities))\n  assign_ops = []\n  for mask in all_masks:\n    new_mask = mask_fn(mask, sparsities[mask.name], mask.dtype)\n    assign_op = tf.assign(mask, new_mask)\n    assign_ops.append(assign_op)\n\n  return tf.group(assign_ops)\n\n\n## Calculating flops and parameters using a list of Keras layers.\ndef _get_kernel(layer):\n  \"\"\"Given a Keras layer, returns its kernel weights.\"\"\"\n  if isinstance(layer, tf.keras.layers.DepthwiseConv2D):\n    return layer.depthwise_kernel\n  else:\n    return layer.kernel\n\n\ndef get_stats(masked_layers,\n              default_sparsity=0.8,\n              method='erdos_renyi',\n              custom_sparsities=None,\n              is_debug=False,\n              width=1.,\n              first_layer_name='conv1',\n              last_layer_name='conv_preds',\n              param_size=32,\n              erk_power_scale=DEFAULT_ERK_SCALE):\n  \"\"\"Given the Keras layers, returns the size and FLOPS of the model.\n\n  Args:\n    masked_layers: list, of tf.keras.Layer.\n    default_sparsity: float, if 0 the mask is left intact; if greater than\n      zero, that fraction of the ones in each mask is flipped to 0.\n    method: str, passed to the `.get_sparsities()` functions.\n    custom_sparsities: dict or None, sparsity of individual variables can be\n      overridden here. 
Key should point to the correct variable name, and value\n      should be in [0, 1].\n    is_debug: bool, if True prints individual stats for given layers.\n    width: float, multiplier for the individual layer widths.\n    first_layer_name: str, to scale the width correctly.\n    last_layer_name: str, to scale the width correctly.\n    param_size: int, number of bits to represent a single parameter.\n    erk_power_scale: float, passed to the get_sparsities function.\n\n  Returns:\n    total_flops, sum of multiply and add operations.\n    total_param_bits, total bits to represent the model during the inference.\n    real_sparsity, calculated independently omitting bias parameters.\n  \"\"\"\n  if custom_sparsities is None:\n    custom_sparsities = {}\n  sparsities = get_sparsities([_get_kernel(l) for l in masked_layers],\n                              method,\n                              default_sparsity,\n                              custom_sparsities,\n                              lambda a: a,\n                              erk_power_scale=erk_power_scale)\n  total_flops = 0\n  total_param_bits = 0\n  total_params = 0.\n  n_zeros = 0.\n  for layer in masked_layers:\n    kernel = _get_kernel(layer)\n    k_shape = kernel.shape.as_list()\n    d_in, d_out = 2, 3\n    # If fully connected change indices.\n    if len(k_shape) == 2:\n      d_in, d_out = 0, 1\n    # and  k_shape[d_in] != 1 since depthwise\n    if not kernel.name.startswith(first_layer_name) and k_shape[d_in] != 1:\n      k_shape[d_in] = int(k_shape[d_in] * width)\n    if not kernel.name.startswith(last_layer_name) and k_shape[d_out] != 1:\n      k_shape[d_out] = int(k_shape[d_out] * width)\n    if is_debug:\n      print(kernel.name, layer.input_shape, k_shape, sparsities[kernel.name])\n\n    if isinstance(layer, tf.keras.layers.Conv2D):\n      layer_op = counting.Conv2D(layer.input_shape[1], k_shape, layer.strides,\n                                 'same', True, 'relu')\n    elif isinstance(layer, 
tf.keras.layers.DepthwiseConv2D):\n      layer_op = counting.DepthWiseConv2D(layer.input_shape[1], k_shape,\n                                          layer.strides, 'same', True, 'relu')\n    elif isinstance(layer, tf.keras.layers.Dense):\n      layer_op = counting.FullyConnected(k_shape, True, 'relu')\n    else:\n      raise ValueError('Should not happen.')\n    param_count, n_mults, n_adds = counting.count_ops(layer_op,\n                                                      sparsities[kernel.name],\n                                                      param_size)\n    total_param_bits += param_count\n    total_flops += n_mults + n_adds\n    n_param = np.prod(k_shape)\n    total_params += n_param\n    n_zeros += int(n_param * sparsities[kernel.name])\n\n  return total_flops, total_param_bits, n_zeros / total_params\n"
  },
  {
    "path": "rigl/sparse_utils_test.py",
"content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Tests for the sparse_utils mask generation and sparsity helpers.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl.testing import parameterized\nimport numpy as np\nfrom rigl import sparse_utils\nimport tensorflow.compat.v1 as tf\n\n\nclass GetMaskRandomTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_session(self):\n    \"\"\"Resets the graph and returns a fresh session.\"\"\"\n    tf.reset_default_graph()\n    sess = tf.Session()\n    return sess\n\n  @parameterized.parameters(((30, 40), 0.5), ((1, 2, 1, 4), 0.8), ((3,), 0.1))\n  def testMaskConnectionDeterminism(self, shape, sparsity):\n    sess = self._setup_session()\n    mask = tf.ones(shape)\n    mask1 = sparse_utils.get_mask_random(mask, sparsity, tf.int32)\n    mask2 = sparse_utils.get_mask_random(mask, sparsity, tf.int32)\n    mask1_array, = sess.run([mask1])\n    mask2_array, = sess.run([mask2])\n    self.assertEqual(np.sum(mask1_array), np.sum(mask2_array))\n\n  @parameterized.parameters(((30, 4), 0.5, 60), ((1, 2, 1, 4), 0.8, 2),\n                            ((30,), 0.1, 27))\n  def testMaskFraction(self, shape, sparsity, expected_ones):\n    sess = self._setup_session()\n    mask = tf.ones(shape)\n    mask1 = sparse_utils.get_mask_random(mask, sparsity, tf.int32)\n
    mask1_array, = sess.run([mask1])\n\n    self.assertEqual(np.sum(mask1_array), expected_ones)\n\n  @parameterized.parameters(tf.int32, tf.float32, tf.int64, tf.float64)\n  def testMaskDtype(self, dtype):\n    _ = self._setup_session()\n    mask = tf.ones((3, 2))\n    mask1 = sparse_utils.get_mask_random(mask, 0.5, dtype)\n    self.assertEqual(mask1.dtype, dtype)\n\n\nclass GetSparsitiesTest(tf.test.TestCase, parameterized.TestCase):\n\n  def _setup_session(self):\n    \"\"\"Resets the graph and returns a fresh session.\"\"\"\n    tf.reset_default_graph()\n    sess = tf.Session()\n    return sess\n\n  @parameterized.parameters(0., 0.4, 0.9)\n  def testSparsityDictRandom(self, default_sparsity):\n    _ = self._setup_session()\n    all_masks = [tf.get_variable(shape=(2, 3), name='var1/mask'),\n                 tf.get_variable(shape=(2, 3), name='var2/mask'),\n                 tf.get_variable(shape=(1, 1, 3), name='var3/mask')]\n    custom_sparsity = {'var1': 0.8}\n    sparsities = sparse_utils.get_sparsities(\n        all_masks, 'random', default_sparsity, custom_sparsity)\n    self.assertEqual(sparsities[all_masks[0].name], 0.8)\n    self.assertEqual(sparsities[all_masks[1].name], default_sparsity)\n    self.assertEqual(sparsities[all_masks[2].name], default_sparsity)\n\n  @parameterized.parameters(0.1, 0.4, 0.9)\n  def testSparsityDictErdosRenyiCustom(self, default_sparsity):\n    _ = self._setup_session()\n    all_masks = [tf.get_variable(shape=(2, 4), name='var1/mask'),\n                 tf.get_variable(shape=(2, 3), name='var2/mask'),\n                 tf.get_variable(shape=(1, 1, 3), name='var3/mask')]\n    custom_sparsity = {'var3': 0.8}\n    sparsities = sparse_utils.get_sparsities(\n        all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)\n    self.assertEqual(sparsities[all_masks[2].name], 0.8)\n\n  @parameterized.parameters(0.1, 0.4, 0.9)\n  def testSparsityDictErdosRenyiError(self, default_sparsity):\n    _ = self._setup_session()\n
    all_masks = [tf.get_variable(shape=(2, 4), name='var1/mask'),\n                 tf.get_variable(shape=(2, 3), name='var2/mask'),\n                 tf.get_variable(shape=(1, 1, 3), name='var3/mask')]\n    custom_sparsity = {'var3': 0.8}\n    sparsities = sparse_utils.get_sparsities(\n        all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)\n    self.assertEqual(sparsities[all_masks[2].name], 0.8)\n\n  @parameterized.parameters(((2, 3), (2, 3), 0.5),\n                            ((1, 1, 2, 3), (1, 1, 2, 3), 0.3),\n                            ((8, 6), (4, 3), 0.7),\n                            ((80, 4), (20, 20), 0.8),\n                            ((2, 6), (2, 3), 0.8))\n  def testSparsityDictErdosRenyiSparsitiesScale(\n      self, shape1, shape2, default_sparsity):\n    _ = self._setup_session()\n    all_masks = [tf.get_variable(shape=shape1, name='var1/mask'),\n                 tf.get_variable(shape=shape2, name='var2/mask')]\n    custom_sparsity = {}\n    sparsities = sparse_utils.get_sparsities(\n        all_masks, 'erdos_renyi', default_sparsity, custom_sparsity)\n    sparsity1 = sparsities[all_masks[0].name]\n    size1 = np.prod(shape1)\n    sparsity2 = sparsities[all_masks[1].name]\n    size2 = np.prod(shape2)\n    # Total number of zeros if the default sparsity were applied uniformly.\n    expected_zeros_uniform = (\n        sparse_utils.get_n_zeros(size1, default_sparsity) +\n        sparse_utils.get_n_zeros(size2, default_sparsity))\n    # Total number of zeros under the Erdos-Renyi per-layer sparsities.\n    expected_zeros_current = (\n        sparse_utils.get_n_zeros(size1, sparsity1) +\n        sparse_utils.get_n_zeros(size2, sparsity2))\n    # Due to rounding we can have some difference. This is expected but should\n    # be less than the number of rounding operations we make.\n
    diff = abs(expected_zeros_uniform - expected_zeros_current)\n    tolerance = 2\n    self.assertLessEqual(diff, tolerance)\n\n    # Ensure that ErdosRenyi proportions are preserved.\n    factor1 = (shape1[-1] + shape1[-2]) / float(shape1[-1] * shape1[-2])\n    factor2 = (shape2[-1] + shape2[-2]) / float(shape2[-1] * shape2[-2])\n    self.assertAlmostEqual((1 - sparsity1) / factor1,\n                           (1 - sparsity2) / factor2)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "rigl/str_sparsities.py",
    "content": "# coding=utf-8\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Reads ResNet-50 sparsity distributions found by STR.\n\n[STR]: https://arxiv.org/abs/2002.03231\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport re\n\nREPORTED_SPARSITIES = \"\"\"\nOverall - Overall 25502912 4089284608 79.55 81.27 87.70 90.23 90.55 94.80 95.03 95.15 96.11 96.53 97.78 98.05 98.22 98.79 98.98 99.10\nLayer 1 - conv1 9408 118013952 51.46 51.40 63.02 59.80 59.83 64.87 67.36 66.96 72.11 69.46 73.29 73.47 72.05 75.12 76.12 77.75\nLayer 2 - layer1.0.conv1 4096 12845056 69.36 73.24 87.57 83.28 85.18 89.60 91.41 91.11 92.38 91.75 94.46 94.51 94.60 95.95 96.53 96.51\nLayer 3 - layer1.0.conv2 36864 115605504 77.85 76.26 90.87 89.48 87.31 94.79 94.27 95.04 95.69 96.07 97.36 97.77 98.35 98.51 98.59 98.84\nLayer 4 - layer1.0.conv3 16384 51380224 74.81 74.65 86.52 85.80 85.25 91.85 92.78 93.67 94.13 94.69 96.61 97.03 97.37 98.04 98.21 98.47\nLayer 5 - layer1.0.downsample.0 16384 51380224 70.95 72.96 83.53 83.34 82.56 89.13 90.62 90.17 91.83 92.69 95.48 94.89 95.68 96.98 97.56 97.72\nLayer 6 - layer1.1.conv1 16384 51380224 80.27 79.58 89.82 89.89 88.51 94.56 96.64 95.78 95.81 96.81 98.79 98.90 98.98 99.13 99.62 99.47\nLayer 7 - layer1.1.conv2 36864 115605504 81.36 80.95 91.75 90.60 89.61 94.70 95.78 96.18 96.42 97.26 98.65 99.07 99.40 99.11 
99.31 99.56\nLayer 8 - layer1.1.conv3 16384 51380224 84.45 80.11 91.22 91.70 90.21 95.17 97.05 95.81 96.34 97.23 98.68 98.76 98.90 99.16 99.57 99.46\nLayer 9 - layer1.2.conv1 16384 51380224 78.23 79.79 90.12 88.07 89.36 94.62 95.94 94.74 96.23 96.75 97.96 98.41 98.72 99.38 99.35 99.46\nLayer 10 - layer1.2.conv2 36864 115605504 76.01 81.53 91.06 87.03 88.27 93.90 95.63 94.26 96.24 96.11 97.54 98.27 98.44 99.32 99.19 99.39\nLayer 11 - layer1.2.conv3 16384 51380224 84.47 83.28 94.95 90.99 92.64 95.76 96.95 96.01 96.87 97.31 98.38 98.60 98.72 99.38 99.27 99.51\nLayer 12 - layer2.0.conv1 32768 102760448 73.74 73.96 86.78 85.95 85.90 92.32 94.79 93.86 94.62 95.64 97.19 98.22 98.52 98.48 98.84 98.92\nLayer 13 - layer2.0.conv2 147456 115605504 82.56 85.70 91.31 93.91 94.03 97.54 97.43 97.65 98.38 98.62 99.24 99.23 99.40 99.61 99.67 99.63\nLayer 14 - layer2.0.conv3 65536 51380224 84.70 83.55 93.04 93.13 92.13 96.61 97.37 97.21 97.59 98.14 98.80 98.95 99.18 99.29 99.47 99.43\nLayer 15 - layer2.0.downsample.0 131072 102760448 85.10 87.66 92.78 94.96 95.13 98.07 97.97 98.15 98.70 98.88 99.37 99.35 99.40 99.69 99.68 99.71\nLayer 16 - layer2.1.conv1 65536 51380224 85.42 85.79 94.04 95.31 94.94 97.92 98.53 98.21 98.84 99.06 99.46 99.53 99.72 99.78 99.81 99.80\nLayer 17 - layer2.1.conv2 147456 115605504 76.95 82.75 87.63 91.50 91.76 95.59 97.22 96.07 97.32 97.80 98.24 98.24 98.60 99.24 99.66 99.33\nLayer 18 - layer2.1.conv3 65536 51380224 84.76 84.71 93.10 93.66 93.23 97.00 98.18 97.35 98.06 98.41 98.96 99.21 99.32 99.55 99.58 99.59\nLayer 19 - layer2.2.conv1 65536 51380224 84.30 85.34 92.70 94.61 94.76 97.72 97.91 98.21 98.54 98.98 99.24 99.35 99.50 99.62 99.63 99.77\nLayer 20 - layer2.2.conv2 147456 115605504 84.28 85.43 92.99 94.86 94.90 97.52 97.21 98.11 98.19 99.04 99.28 99.37 99.46 99.63 99.59 99.72\nLayer 21 - layer2.2.conv3 65536 51380224 82.19 84.21 91.12 93.38 93.53 96.89 97.14 97.59 97.77 98.66 98.96 99.15 99.25 99.49 99.51 99.57\nLayer 22 - layer2.3.conv1 65536 
51380224 83.37 84.41 90.46 93.26 93.50 96.71 97.89 96.99 98.14 98.36 99.10 99.23 99.33 99.53 99.75 99.60\nLayer 23 - layer2.3.conv2 147456 115605504 82.83 84.03 91.44 93.21 93.25 96.83 98.02 96.96 98.45 98.30 98.97 99.06 99.26 99.31 99.81 99.68\nLayer 24 - layer2.3.conv3 65536 51380224 82.93 85.65 91.02 94.14 93.56 97.20 97.97 97.04 98.16 98.36 98.88 98.97 99.20 99.32 99.67 99.62\nLayer 25 - layer3.0.conv1 131072 102760448 76.63 77.98 85.99 88.85 88.60 94.26 95.07 94.97 96.21 96.59 97.75 98.04 98.30 98.72 99.11 99.06\nLayer 26 - layer3.0.conv2 589824 115605504 87.35 88.68 94.39 96.14 96.19 98.51 98.77 98.72 99.11 99.23 99.53 99.59 99.64 99.73 99.80 99.81\nLayer 27 - layer3.0.conv3 262144 51380224 81.22 83.22 90.58 93.19 93.05 96.82 97.38 97.32 97.98 98.28 98.88 99.03 99.16 99.39 99.55 99.53\nLayer 28 - layer3.0.downsample.0 524288 102760448 89.75 90.99 96.05 97.20 97.16 98.96 99.21 99.20 99.50 99.58 99.78 99.82 99.86 99.91 99.94 99.93\nLayer 29 - layer3.1.conv1 262144 51380224 85.88 87.35 93.43 95.36 96.12 98.64 98.77 98.87 99.22 99.33 99.64 99.67 99.72 99.82 99.88 99.84\nLayer 30 - layer3.1.conv2 589824 115605504 85.06 86.24 92.74 95.06 95.30 98.09 98.28 98.36 98.75 99.08 99.46 99.48 99.54 99.69 99.76 99.76\nLayer 31 - layer3.1.conv3 262144 51380224 84.34 86.79 92.15 94.84 94.90 97.75 98.15 98.11 98.56 98.94 99.30 99.36 99.45 99.65 99.79 99.70\nLayer 32 - layer3.2.conv1 262144 51380224 87.51 89.15 94.15 96.77 96.46 98.81 98.83 98.96 99.19 99.44 99.67 99.71 99.74 99.82 99.85 99.89\nLayer 33 - layer3.2.conv2 589824 115605504 87.15 88.67 94.09 95.59 96.14 98.86 98.69 98.91 99.21 99.20 99.64 99.72 99.76 99.85 99.84 99.90\nLayer 34 - layer3.2.conv3 262144 51380224 84.86 86.90 92.40 94.99 94.99 98.19 98.19 98.42 98.76 98.97 99.42 99.56 99.62 99.76 99.75 99.88\nLayer 35 - layer3.3.conv1 262144 51380224 86.62 89.46 94.06 96.08 95.88 98.70 98.71 98.77 99.01 99.27 99.58 99.66 99.69 99.83 99.87 99.87\nLayer 36 - layer3.3.conv2 589824 115605504 86.52 87.97 93.56 96.10 96.11 
98.70 98.82 98.89 99.19 99.31 99.68 99.73 99.77 99.88 99.87 99.93\nLayer 37 - layer3.3.conv3 262144 51380224 84.19 86.81 92.32 94.94 94.91 98.20 98.37 98.43 98.82 99.00 99.51 99.57 99.64 99.81 99.81 99.87\nLayer 38 - layer3.4.conv1 262144 51380224 85.85 88.40 93.55 95.49 95.86 98.35 98.44 98.55 98.79 98.96 99.54 99.59 99.60 99.82 99.86 99.87\nLayer 39 - layer3.4.conv2 589824 115605504 85.96 87.38 93.27 95.66 95.63 98.41 98.58 98.56 99.19 99.26 99.64 99.69 99.67 99.87 99.90 99.92\nLayer 40 - layer3.4.conv3 262144 51380224 83.45 85.76 91.75 94.49 94.35 97.67 98.09 97.99 98.65 98.94 99.49 99.52 99.48 99.77 99.86 99.85\nLayer 41 - layer3.5.conv1 262144 51380224 83.33 85.77 91.79 95.09 94.24 97.46 97.89 97.92 98.71 98.90 99.35 99.52 99.58 99.76 99.79 99.83\nLayer 42 - layer3.5.conv2 589824 115605504 84.98 86.67 92.48 94.92 95.13 97.88 98.14 98.32 98.91 99.00 99.44 99.58 99.69 99.80 99.83 99.87\nLayer 43 - layer3.5.conv3 262144 51380224 79.78 82.23 89.39 93.14 92.76 96.59 97.04 97.30 98.10 98.41 99.03 99.25 99.44 99.61 99.71 99.75\nLayer 44 - layer4.0.conv1 524288 102760448 77.83 79.61 87.11 90.32 90.64 95.39 95.84 95.92 97.17 97.35 98.36 98.60 98.83 99.20 99.37 99.42\nLayer 45 - layer4.0.conv2 2359296 115605504 86.18 88.00 93.53 95.66 95.78 98.31 98.47 98.55 99.08 99.16 99.54 99.63 99.69 99.81 99.85 99.86\nLayer 46 - layer4.0.conv3 1048576 51380224 78.43 80.48 87.85 91.14 91.27 96.00 96.40 96.47 97.53 97.92 98.81 99.00 99.15 99.45 99.57 99.61\nLayer 47 - layer4.0.downsample.0 2097152 102760448 88.49 89.98 95.03 96.79 96.90 98.91 99.06 99.11 99.45 99.51 99.77 99.82 99.85 99.92 99.94 99.94\nLayer 48 - layer4.1.conv1 1048576 51380224 82.07 84.02 90.34 93.69 93.72 97.15 97.56 97.76 98.45 98.75 99.27 99.36 99.54 99.67 99.76 99.80\nLayer 49 - layer4.1.conv2 2359296 115605504 83.42 85.23 91.16 93.98 93.93 97.26 97.58 97.71 98.36 98.67 99.25 99.34 99.50 99.68 99.76 99.80\nLayer 50 - layer4.1.conv3 1048576 51380224 78.08 79.96 86.66 90.48 90.22 95.22 95.76 95.89 96.88 97.65 
98.70 98.85 99.13 99.45 99.58 99.66\nLayer 51 - layer4.2.conv1 1048576 51380224 76.34 77.93 84.98 87.57 88.47 93.90 93.87 94.16 95.55 95.91 97.66 97.97 98.15 98.88 99.08 99.22\nLayer 52 - layer4.2.conv2 2359296 115605504 73.57 74.97 82.32 84.37 86.01 91.92 91.66 92.22 94.02 94.16 96.65 97.13 97.29 98.44 98.74 99.00\nLayer 53 - layer4.2.conv3 1048576 51380224 68.78 70.38 78.11 80.29 81.73 89.64 89.43 89.65 91.40 92.65 96.02 96.72 96.93 98.47 98.83 99.15\nLayer 54 - fc 2048000 2048000 50.65 52.46 60.48 64.50 65.12 75.20 75.73 75.80 78.57 80.69 85.96 87.26 88.03 91.11 92.15 92.87\"\"\"\n\n\ndef _name_map_str(k):\n  \"\"\"Maps the naming of the layers.\"\"\"\n  if k == 'conv1':\n    new_key = 'initial_conv'\n  elif k == 'fc':\n    new_key = 'final_dense'\n  else:\n    if 'downsample' in k:\n      group_id = re.search(r'layer(\\d)\\.0\\.downsample\\.0', k).group(1)\n      new_key = 'bottleneck_projection_block_group_projection_block_group%s' % group_id\n    else:\n      res = re.search(r'layer(\\d)\\.(\\d)\\.conv(\\d)', k)\n      group_id, block_id, layer_id = (int(res.group(1)), int(res.group(2)),\n                                      int(res.group(3)))\n      if block_id == 0:\n        new_key = 'bottleneck_%d_block_group_projection_block_group%d' % (\n            layer_id, group_id)\n      else:\n        new_key = 'bottleneck_%d_block_group%d_%d_1' % (layer_id, group_id,\n                                                        block_id)\n  return 'resnet_model/%s/mask:0' % new_key\n\n\ndef read_all():\n  \"\"\"Reads and returns sparsity distributions.\"\"\"\n  str_sparsities_parsed = collections.defaultdict(dict)\n  for l in REPORTED_SPARSITIES.strip().split('\\n'):\n    l = l.split('-')[1].strip().split(' ')\n    if l[0] == 'Overall':\n      overall_sparsities = list(map(float, l[3:]))\n    else:\n
    else:\n      # The mapped layer name does not depend on the column index.\n      new_key = _name_map_str(l[0])\n      for i, ls in enumerate(l[3:]):\n        # Reported overall sparsities are percentages, so divide by 100.\n        s = overall_sparsities[i] / 100\n        # Per-layer sparsities are percentages too, so divide by 100.\n        str_sparsities_parsed[s][new_key] = float(ls) / 100.\n  return str_sparsities_parsed\n"
  },
  {
    "path": "run.sh",
"content": "#!/bin/bash\n# Copyright 2022 RigL Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nset -e\nset -x\n\nvirtualenv -p python3 env\nsource env/bin/activate\n\npip install -r rigl/requirements.txt\npython -m rigl.sparse_optimizers_test\npython -m rigl.sparse_utils_test\n"
  }
]